diff --git a/.landscape.yml b/.landscape.yml new file mode 100644 index 000000000..a27bbb034 --- /dev/null +++ b/.landscape.yml @@ -0,0 +1,22 @@ +pylint: + disable: + # We use this a lot (e.g. via document._meta) + - protected-access + + options: + additional-builtins: + # add xrange and long as valid built-ins. In Python 3, xrange is + # translated into range and long is translated into int via 2to3 (see + # "use_2to3" in setup.py). This should be removed when we drop Python + # 2 support (which probably won't happen any time soon). + - xrange + - long + +pyflakes: + disable: + # undefined variables are already covered by pylint (and exclude + # xrange & long) + - F821 + +ignore-paths: + - benchmark.py diff --git a/.travis.yml b/.travis.yml index 4721a0868..cb6c97e69 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: python python: -- '2.6' # TODO remove in v0.11.0 - '2.7' - '3.3' - '3.4' @@ -43,7 +42,11 @@ before_script: script: - tox -e $(echo py$TRAVIS_PYTHON_VERSION-mg$PYMONGO | tr -d . | sed -e 's/pypypy/pypy/') -- --with-coverage -after_script: coveralls --verbose +# For now only submit coveralls for Python v2.7. Python v3.x currently shows +# 0% coverage. That's caused by 'use_2to3', which builds the py3-compatible +# code in a separate dir and runs tests on that. +after_script: +- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then coveralls --verbose; fi notifications: irc: irc.freenode.org#mongoengine diff --git a/benchmark.py b/benchmark.py index 53ecf32cb..8e93ee40c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,118 +1,41 @@ #!/usr/bin/env python -import timeit - - -def cprofile_main(): - from pymongo import Connection - connection = Connection() - connection.drop_database('timeit_test') - connection.disconnect() - - from mongoengine import Document, DictField, connect - connect("timeit_test") - - class Noddy(Document): - fields = DictField() +""" +Simple benchmark comparing PyMongo and MongoEngine. + +Sample run on a mid 2015 MacBook Pro (commit b282511): + +Benchmarking... +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - Pymongo +2.58979988098 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - Pymongo write_concern={"w": 0} +1.26657605171 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - MongoEngine +8.4351580143 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries without continual assign - MongoEngine +7.20191693306 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade = True +6.31104588509 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True +6.07083487511 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False +5.97704291344 +---------------------------------------------------------------------------------------------------- +Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False +5.9111430645 +""" - for i in range(1): - noddy = Noddy() - for j in range(20): - noddy.fields["key" + str(j)] = "value " + str(j) - noddy.save() +import timeit def main(): - """ - 0.4 Performance Figures ... - - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - Pymongo - 3.86744189262 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine - 6.23374891281 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False - 5.33027005196 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False - pass - No Cascade - - 0.5.X - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - Pymongo - 3.89597702026 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine - 21.7735359669 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False - 19.8670389652 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False - pass - No Cascade - - 0.6.X - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - Pymongo - 3.81559205055 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine - 10.0446798801 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False - 9.51354718208 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False - 9.02567505836 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, force=True - 8.44933390617 - - 0.7.X - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - Pymongo - 3.78801012039 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine - 9.73050498962 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False - 8.33456707001 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, safe=False, validate=False, cascade=False - 8.37778115273 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, force=True - 8.36906409264 - 0.8.X - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - Pymongo - 3.69964408875 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - Pymongo write_concern={"w": 0} - 3.5526599884 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine - 7.00959801674 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries without continual assign - MongoEngine - 5.60943293571 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine - write_concern={"w": 0}, cascade=True - 6.715102911 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False, cascade=True - 5.50644683838 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, write_concern={"w": 0}, validate=False - 4.69851183891 - ---------------------------------------------------------------------------------------------------- - Creating 10000 dictionaries - MongoEngine, force_insert=True, write_concern={"w": 0}, validate=False - 4.68946313858 - ---------------------------------------------------------------------------------------------------- - """ print("Benchmarking...") setup = """ @@ -131,7 +54,7 @@ def main(): for i in range(10000): example = {'fields': {}} for j in range(20): - example['fields']["key"+str(j)] = "value "+str(j) + example['fields']['key' + str(j)] = 'value ' + str(j) noddy.save(example) @@ -146,9 +69,10 @@ def main(): stmt = """ from pymongo import MongoClient +from pymongo.write_concern import WriteConcern connection = MongoClient() -db = connection.timeit_test +db = connection.get_database('timeit_test', write_concern=WriteConcern(w=0)) noddy = db.noddy for i in range(10000): @@ -156,7 +80,7 @@ def main(): for j in range(20): example['fields']["key"+str(j)] = "value "+str(j) - noddy.save(example, write_concern={"w": 0}) + noddy.save(example) myNoddys = noddy.find() [n for n in myNoddys] # iterate @@ -171,10 +95,10 @@ def main(): from pymongo import MongoClient connection = MongoClient() connection.drop_database('timeit_test') -connection.disconnect() +connection.close() from mongoengine import Document, DictField, connect -connect("timeit_test") +connect('timeit_test') class Noddy(Document): fields = DictField() diff --git a/docs/changelog.rst b/docs/changelog.rst index 3407d9cea..e04c48fbc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,13 @@ Changelog Development =========== +- (Fill this out as you fix issues and develop you features). + +Changes in 0.11.0 +================= +- BREAKING CHANGE: Renamed `ConnectionError` to `MongoEngineConnectionError` since the former is a built-in exception name in Python v3.x. #1428 +- BREAKING CHANGE: Dropped Python 2.6 support. #1428 +- BREAKING CHANGE: `from mongoengine.base import ErrorClass` won't work anymore for any error from `mongoengine.errors` (e.g. `ValidationError`). Use `from mongoengine.errors import ErrorClass instead`. #1428 - Fixed absent rounding for DecimalField when `force_string` is set. #1103 Changes in 0.10.8 diff --git a/docs/upgrade.rst b/docs/upgrade.rst index a363eae35..c0ae72051 100644 --- a/docs/upgrade.rst +++ b/docs/upgrade.rst @@ -2,6 +2,32 @@ Upgrading ######### +0.11.0 +****** +This release includes a major rehaul of MongoEngine's code quality and +introduces a few breaking changes. It also touches many different parts of +the package and although all the changes have been tested and scrutinized, +you're encouraged to thorougly test the upgrade. + +First breaking change involves renaming `ConnectionError` to `MongoEngineConnectionError`. +If you import or catch this exception, you'll need to rename it in your code. + +Second breaking change drops Python v2.6 support. If you run MongoEngine on +that Python version, you'll need to upgrade it first. + +Third breaking change drops an old backward compatibility measure where +`from mongoengine.base import ErrorClass` would work on top of +`from mongoengine.errors import ErrorClass` (where `ErrorClass` is e.g. +`ValidationError`). If you import any exceptions from `mongoengine.base`, +change it to `mongoengine.errors`. + +0.10.8 +****** +This version fixed an issue where specifying a MongoDB URI host would override +more information than it should. These changes are minor, but they still +subtly modify the connection logic and thus you're encouraged to test your +MongoDB connection before shipping v0.10.8 in production. + 0.10.7 ****** diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index 1544aa516..cece51265 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -1,25 +1,35 @@ -import connection -from connection import * -import document -from document import * -import errors -from errors import * -import fields -from fields import * -import queryset -from queryset import * -import signals -from signals import * - -__all__ = (list(document.__all__) + fields.__all__ + connection.__all__ + - list(queryset.__all__) + signals.__all__ + list(errors.__all__)) +# Import submodules so that we can expose their __all__ +from mongoengine import connection +from mongoengine import document +from mongoengine import errors +from mongoengine import fields +from mongoengine import queryset +from mongoengine import signals + +# Import everything from each submodule so that it can be accessed via +# mongoengine, e.g. instead of `from mongoengine.connection import connect`, +# users can simply use `from mongoengine import connect`, or even +# `from mongoengine import *` and then `connect('testdb')`. +from mongoengine.connection import * +from mongoengine.document import * +from mongoengine.errors import * +from mongoengine.fields import * +from mongoengine.queryset import * +from mongoengine.signals import * + + +__all__ = (list(document.__all__) + list(fields.__all__) + + list(connection.__all__) + list(queryset.__all__) + + list(signals.__all__) + list(errors.__all__)) + VERSION = (0, 10, 9) def get_version(): - if isinstance(VERSION[-1], basestring): - return '.'.join(map(str, VERSION[:-1])) + VERSION[-1] + """Return the VERSION as a string, e.g. for VERSION == (0, 10, 7), + return '0.10.7'. + """ return '.'.join(map(str, VERSION)) diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py index e8d4b6ad9..da31b9227 100644 --- a/mongoengine/base/__init__.py +++ b/mongoengine/base/__init__.py @@ -1,8 +1,28 @@ +# Base module is split into several files for convenience. Files inside of +# this module should import from a specific submodule (e.g. +# `from mongoengine.base.document import BaseDocument`), but all of the +# other modules should import directly from the top-level module (e.g. +# `from mongoengine.base import BaseDocument`). This approach is cleaner and +# also helps with cyclical import errors. from mongoengine.base.common import * from mongoengine.base.datastructures import * from mongoengine.base.document import * from mongoengine.base.fields import * from mongoengine.base.metaclasses import * -# Help with backwards compatibility -from mongoengine.errors import * +__all__ = ( + # common + 'UPDATE_OPERATORS', '_document_registry', 'get_document', + + # datastructures + 'BaseDict', 'BaseList', 'EmbeddedDocumentList', + + # document + 'BaseDocument', + + # fields + 'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField', + + # metaclasses + 'DocumentMetaclass', 'TopLevelDocumentMetaclass' +) diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index 3a966c792..da2b8b68b 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -1,13 +1,18 @@ from mongoengine.errors import NotRegistered -__all__ = ('ALLOW_INHERITANCE', 'get_document', '_document_registry') +__all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry') + + +UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', + 'push_all', 'pull', 'pull_all', 'add_to_set', + 'set_on_insert', 'min', 'max']) -ALLOW_INHERITANCE = False _document_registry = {} def get_document(name): + """Get a document class by name.""" doc = _document_registry.get(name, None) if not doc: # Possible old style name diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 2c6ebc2a5..5e90a2e55 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,14 +1,16 @@ import itertools import weakref +import six + from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned -__all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList") +__all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList') class BaseDict(dict): - """A special dict so we can watch any changes""" + """A special dict so we can watch any changes.""" _dereferenced = False _instance = None @@ -93,8 +95,7 @@ def _mark_as_changed(self, key=None): class BaseList(list): - """A special list so we can watch any changes - """ + """A special list so we can watch any changes.""" _dereferenced = False _instance = None @@ -209,17 +210,22 @@ def _mark_as_changed(self, key=None): class EmbeddedDocumentList(BaseList): @classmethod - def __match_all(cls, i, kwargs): - items = kwargs.items() - return all([ - getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items - ]) + def __match_all(cls, embedded_doc, kwargs): + """Return True if a given embedded doc matches all the filter + kwargs. If it doesn't return False. + """ + for key, expected_value in kwargs.items(): + doc_val = getattr(embedded_doc, key) + if doc_val != expected_value and six.text_type(doc_val) != expected_value: + return False + return True @classmethod - def __only_matches(cls, obj, kwargs): + def __only_matches(cls, embedded_docs, kwargs): + """Return embedded docs that match the filter kwargs.""" if not kwargs: - return obj - return filter(lambda i: cls.__match_all(i, kwargs), obj) + return embedded_docs + return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)] def __init__(self, list_items, instance, name): super(EmbeddedDocumentList, self).__init__(list_items, instance, name) @@ -285,18 +291,18 @@ def get(self, **kwargs): values = self.__only_matches(self, kwargs) if len(values) == 0: raise DoesNotExist( - "%s matching query does not exist." % self._name + '%s matching query does not exist.' % self._name ) elif len(values) > 1: raise MultipleObjectsReturned( - "%d items returned, instead of 1" % len(values) + '%d items returned, instead of 1' % len(values) ) return values[0] def first(self): - """ - Returns the first embedded document in the list, or ``None`` if empty. + """Return the first embedded document in the list, or ``None`` + if empty. """ if len(self) > 0: return self[0] @@ -438,7 +444,7 @@ class SpecificStrictDict(cls): __slots__ = allowed_keys_tuple def __repr__(self): - return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) + return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items()) cls._classes[allowed_keys] = SpecificStrictDict return cls._classes[allowed_keys] diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 59f5aebc7..03dc75626 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -1,6 +1,5 @@ import copy import numbers -import operator from collections import Hashable from functools import partial @@ -8,30 +7,27 @@ from bson.dbref import DBRef from bson.son import SON import pymongo +import six from mongoengine import signals -from mongoengine.base.common import ALLOW_INHERITANCE, get_document -from mongoengine.base.datastructures import ( - BaseDict, - BaseList, - EmbeddedDocumentList, - SemiStrictDict, - StrictDict -) +from mongoengine.base.common import get_document +from mongoengine.base.datastructures import (BaseDict, BaseList, + EmbeddedDocumentList, + SemiStrictDict, StrictDict) from mongoengine.base.fields import ComplexBaseField from mongoengine.common import _import_class from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, - LookUpError, ValidationError) -from mongoengine.python_support import PY3, txt_type + LookUpError, OperationError, ValidationError) -__all__ = ('BaseDocument', 'NON_FIELD_ERRORS') +__all__ = ('BaseDocument',) NON_FIELD_ERRORS = '__all__' class BaseDocument(object): __slots__ = ('_changed_fields', '_initialised', '_created', '_data', - '_dynamic_fields', '_auto_id_field', '_db_field_map', '__weakref__') + '_dynamic_fields', '_auto_id_field', '_db_field_map', + '__weakref__') _dynamic = False _dynamic_lock = True @@ -57,15 +53,15 @@ def __init__(self, *args, **values): name = next(field) if name in values: raise TypeError( - "Multiple values for keyword argument '" + name + "'") + 'Multiple values for keyword argument "%s"' % name) values[name] = value - __auto_convert = values.pop("__auto_convert", True) + __auto_convert = values.pop('__auto_convert', True) # 399: set default values only to fields loaded from DB - __only_fields = set(values.pop("__only_fields", values)) + __only_fields = set(values.pop('__only_fields', values)) - _created = values.pop("_created", True) + _created = values.pop('_created', True) signals.pre_init.send(self.__class__, document=self, values=values) @@ -76,7 +72,7 @@ def __init__(self, *args, **values): self._fields.keys() + ['id', 'pk', '_cls', '_text_score']) if _undefined_fields: msg = ( - "The fields '{0}' do not exist on the document '{1}'" + 'The fields "{0}" do not exist on the document "{1}"' ).format(_undefined_fields, self._class_name) raise FieldDoesNotExist(msg) @@ -95,7 +91,7 @@ def __init__(self, *args, **values): value = getattr(self, key, None) setattr(self, key, value) - if "_cls" not in values: + if '_cls' not in values: self._cls = self._class_name # Set passed values after initialisation @@ -150,7 +146,7 @@ def __setattr__(self, name, value): if self._dynamic and not self._dynamic_lock: if not hasattr(self, name) and not name.startswith('_'): - DynamicField = _import_class("DynamicField") + DynamicField = _import_class('DynamicField') field = DynamicField(db_field=name) field.name = name self._dynamic_fields[name] = field @@ -169,11 +165,13 @@ def __setattr__(self, name, value): except AttributeError: self__created = True - if (self._is_document and not self__created and - name in self._meta.get('shard_key', tuple()) and - self._data.get(name) != value): - OperationError = _import_class('OperationError') - msg = "Shard Keys are immutable. Tried to update %s" % name + if ( + self._is_document and + not self__created and + name in self._meta.get('shard_key', tuple()) and + self._data.get(name) != value + ): + msg = 'Shard Keys are immutable. Tried to update %s' % name raise OperationError(msg) try: @@ -197,8 +195,8 @@ def __getstate__(self): return data def __setstate__(self, data): - if isinstance(data["_data"], SON): - data["_data"] = self.__class__._from_son(data["_data"])._data + if isinstance(data['_data'], SON): + data['_data'] = self.__class__._from_son(data['_data'])._data for k in ('_changed_fields', '_initialised', '_created', '_data', '_dynamic_fields'): if k in data: @@ -212,7 +210,7 @@ def __setstate__(self, data): dynamic_fields = data.get('_dynamic_fields') or SON() for k in dynamic_fields.keys(): - setattr(self, k, data["_data"].get(k)) + setattr(self, k, data['_data'].get(k)) def __iter__(self): return iter(self._fields_ordered) @@ -254,12 +252,13 @@ def __repr__(self): return repr_type('<%s: %s>' % (self.__class__.__name__, u)) def __str__(self): + # TODO this could be simpler? if hasattr(self, '__unicode__'): - if PY3: + if six.PY3: return self.__unicode__() else: - return unicode(self).encode('utf-8') - return txt_type('%s object' % self.__class__.__name__) + return six.text_type(self).encode('utf-8') + return six.text_type('%s object' % self.__class__.__name__) def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: @@ -308,7 +307,7 @@ def to_mongo(self, use_db_field=True, fields=None): fields = [] data = SON() - data["_id"] = None + data['_id'] = None data['_cls'] = self._class_name # only root fields ['test1.a', 'test2'] => ['test1', 'test2'] @@ -351,18 +350,8 @@ def to_mongo(self, use_db_field=True, fields=None): else: data[field.name] = value - # If "_id" has not been set, then try and set it - Document = _import_class("Document") - if isinstance(self, Document): - if data["_id"] is None: - data["_id"] = self._data.get("id", None) - - if data['_id'] is None: - data.pop('_id') - # Only add _cls if allow_inheritance is True - if (not hasattr(self, '_meta') or - not self._meta.get('allow_inheritance', ALLOW_INHERITANCE)): + if not self._meta.get('allow_inheritance'): data.pop('_cls') return data @@ -376,16 +365,16 @@ def validate(self, clean=True): if clean: try: self.clean() - except ValidationError, error: + except ValidationError as error: errors[NON_FIELD_ERRORS] = error # Get a list of tuples of field names and their current values fields = [(self._fields.get(name, self._dynamic_fields.get(name)), self._data.get(name)) for name in self._fields_ordered] - EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + EmbeddedDocumentField = _import_class('EmbeddedDocumentField') GenericEmbeddedDocumentField = _import_class( - "GenericEmbeddedDocumentField") + 'GenericEmbeddedDocumentField') for field, value in fields: if value is not None: @@ -395,21 +384,21 @@ def validate(self, clean=True): field._validate(value, clean=clean) else: field._validate(value) - except ValidationError, error: + except ValidationError as error: errors[field.name] = error.errors or error - except (ValueError, AttributeError, AssertionError), error: + except (ValueError, AttributeError, AssertionError) as error: errors[field.name] = error elif field.required and not getattr(field, '_auto_gen', False): errors[field.name] = ValidationError('Field is required', field_name=field.name) if errors: - pk = "None" + pk = 'None' if hasattr(self, 'pk'): pk = self.pk elif self._instance and hasattr(self._instance, 'pk'): pk = self._instance.pk - message = "ValidationError (%s:%s) " % (self._class_name, pk) + message = 'ValidationError (%s:%s) ' % (self._class_name, pk) raise ValidationError(message, errors=errors) def to_json(self, *args, **kwargs): @@ -426,33 +415,26 @@ def from_json(cls, json_data, created=False): return cls._from_son(json_util.loads(json_data), created=created) def __expand_dynamic_values(self, name, value): - """expand any dynamic values to their correct types / values""" + """Expand any dynamic values to their correct types / values.""" if not isinstance(value, (dict, list, tuple)): return value - EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') - - is_list = False - if not hasattr(value, 'items'): - is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) - - if not is_list and '_cls' in value: + # If the value is a dict with '_cls' in it, turn it into a document + is_dict = isinstance(value, dict) + if is_dict and '_cls' in value: cls = get_document(value['_cls']) return cls(**value) - data = {} - for k, v in value.items(): - key = name if is_list else k - data[k] = self.__expand_dynamic_values(key, v) - - if is_list: # Convert back to a list - data_items = sorted(data.items(), key=operator.itemgetter(0)) - value = [v for k, v in data_items] + if is_dict: + value = { + k: self.__expand_dynamic_values(k, v) + for k, v in value.items() + } else: - value = data + value = [self.__expand_dynamic_values(name, v) for v in value] # Convert lists / values so we can watch for any changes on them + EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField') if (isinstance(value, (list, tuple)) and not isinstance(value, BaseList)): if issubclass(type(self), EmbeddedDocumentListField): @@ -465,8 +447,7 @@ def __expand_dynamic_values(self, name, value): return value def _mark_as_changed(self, key): - """Marks a key as explicitly changed by the user - """ + """Mark a key as explicitly changed by the user.""" if not key: return @@ -496,10 +477,11 @@ def _mark_as_changed(self, key): remove(field) def _clear_changed_fields(self): - """Using get_changed_fields iterate and remove any fields that are - marked as changed""" + """Using _get_changed_fields iterate and remove any fields that + are marked as changed. + """ for changed in self._get_changed_fields(): - parts = changed.split(".") + parts = changed.split('.') data = self for part in parts: if isinstance(data, list): @@ -511,10 +493,13 @@ def _clear_changed_fields(self): data = data.get(part, None) else: data = getattr(data, part, None) - if hasattr(data, "_changed_fields"): - if hasattr(data, "_is_document") and data._is_document: + + if hasattr(data, '_changed_fields'): + if getattr(data, '_is_document', False): continue + data._changed_fields = [] + self._changed_fields = [] def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): @@ -526,26 +511,27 @@ def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): iterator = data.iteritems() for index, value in iterator: - list_key = "%s%s." % (key, index) + list_key = '%s%s.' % (key, index) # don't check anything lower if this key is already marked # as changed. if list_key[:-1] in changed_fields: continue if hasattr(value, '_get_changed_fields'): changed = value._get_changed_fields(inspected) - changed_fields += ["%s%s" % (list_key, k) + changed_fields += ['%s%s' % (list_key, k) for k in changed if k] elif isinstance(value, (list, tuple, dict)): self._nestable_types_changed_fields( changed_fields, list_key, value, inspected) def _get_changed_fields(self, inspected=None): - """Returns a list of all fields that have explicitly been changed. + """Return a list of all fields that have explicitly been changed. """ - EmbeddedDocument = _import_class("EmbeddedDocument") - DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") - ReferenceField = _import_class("ReferenceField") - SortedListField = _import_class("SortedListField") + EmbeddedDocument = _import_class('EmbeddedDocument') + DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument') + ReferenceField = _import_class('ReferenceField') + SortedListField = _import_class('SortedListField') + changed_fields = [] changed_fields += getattr(self, '_changed_fields', []) @@ -572,7 +558,7 @@ def _get_changed_fields(self, inspected=None): ): # Find all embedded fields that have been changed changed = data._get_changed_fields(inspected) - changed_fields += ["%s%s" % (key, k) for k in changed if k] + changed_fields += ['%s%s' % (key, k) for k in changed if k] elif (isinstance(data, (list, tuple, dict)) and db_field_name not in changed_fields): if (hasattr(field, 'field') and @@ -676,21 +662,25 @@ def _delta(self): @classmethod def _get_collection_name(cls): - """Returns the collection name for this class. None for abstract class + """Return the collection name for this class. None for abstract + class. """ return cls._meta.get('collection', None) @classmethod def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): - """Create an instance of a Document (subclass) from a PyMongo SON. + """Create an instance of a Document (subclass) from a PyMongo + SON. """ if not only_fields: only_fields = [] - # get the class name from the document, falling back to the given + # Get the class name from the document, falling back to the given # class if unavailable class_name = son.get('_cls', cls._class_name) - data = dict(("%s" % key, value) for key, value in son.iteritems()) + + # Convert SON to a dict, making sure each key is a string + data = {str(key): value for key, value in son.iteritems()} # Return correct subclass for document type if class_name != cls._class_name: @@ -712,19 +702,20 @@ def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False) else field.to_python(value)) if field_name != field.db_field: del data[field.db_field] - except (AttributeError, ValueError), e: + except (AttributeError, ValueError) as e: errors_dict[field_name] = e if errors_dict: - errors = "\n".join(["%s - %s" % (k, v) + errors = '\n'.join(['%s - %s' % (k, v) for k, v in errors_dict.items()]) - msg = ("Invalid data to create a `%s` instance.\n%s" + msg = ('Invalid data to create a `%s` instance.\n%s' % (cls._class_name, errors)) raise InvalidDocumentError(msg) + # In STRICT documents, remove any keys that aren't in cls._fields if cls.STRICT: - data = dict((k, v) - for k, v in data.iteritems() if k in cls._fields) + data = {k: v for k, v in data.iteritems() if k in cls._fields} + obj = cls(__auto_convert=False, _created=created, __only_fields=only_fields, **data) obj._changed_fields = changed_fields if not _auto_dereference: @@ -734,37 +725,43 @@ def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False) @classmethod def _build_index_specs(cls, meta_indexes): - """Generate and merge the full index specs - """ - + """Generate and merge the full index specs.""" geo_indices = cls._geo_indices() unique_indices = cls._unique_with_indexes() - index_specs = [cls._build_index_spec(spec) - for spec in meta_indexes] + index_specs = [cls._build_index_spec(spec) for spec in meta_indexes] def merge_index_specs(index_specs, indices): + """Helper method for merging index specs.""" if not indices: return index_specs - spec_fields = [v['fields'] - for k, v in enumerate(index_specs)] - # Merge unique_indexes with existing specs - for k, v in enumerate(indices): - if v['fields'] in spec_fields: - index_specs[spec_fields.index(v['fields'])].update(v) + # Create a map of index fields to index spec. We're converting + # the fields from a list to a tuple so that it's hashable. + spec_fields = { + tuple(index['fields']): index for index in index_specs + } + + # For each new index, if there's an existing index with the same + # fields list, update the existing spec with all data from the + # new spec. + for new_index in indices: + candidate = spec_fields.get(tuple(new_index['fields'])) + if candidate is None: + index_specs.append(new_index) else: - index_specs.append(v) + candidate.update(new_index) + return index_specs + # Merge geo indexes and unique_with indexes into the meta index specs. index_specs = merge_index_specs(index_specs, geo_indices) index_specs = merge_index_specs(index_specs, unique_indices) return index_specs @classmethod def _build_index_spec(cls, spec): - """Build a PyMongo index spec from a MongoEngine index spec. - """ - if isinstance(spec, basestring): + """Build a PyMongo index spec from a MongoEngine index spec.""" + if isinstance(spec, six.string_types): spec = {'fields': [spec]} elif isinstance(spec, (list, tuple)): spec = {'fields': list(spec)} @@ -775,8 +772,7 @@ def _build_index_spec(cls, spec): direction = None # Check to see if we need to include _cls - allow_inheritance = cls._meta.get('allow_inheritance', - ALLOW_INHERITANCE) + allow_inheritance = cls._meta.get('allow_inheritance') include_cls = ( allow_inheritance and not spec.get('sparse', False) and @@ -786,7 +782,7 @@ def _build_index_spec(cls, spec): # 733: don't include cls if index_cls is False unless there is an explicit cls with the index include_cls = include_cls and (spec.get('cls', False) or cls._meta.get('index_cls', True)) - if "cls" in spec: + if 'cls' in spec: spec.pop('cls') for key in spec['fields']: # If inherited spec continue @@ -801,19 +797,19 @@ def _build_index_spec(cls, spec): # GEOHAYSTACK from ) # GEO2D from * direction = pymongo.ASCENDING - if key.startswith("-"): + if key.startswith('-'): direction = pymongo.DESCENDING - elif key.startswith("$"): + elif key.startswith('$'): direction = pymongo.TEXT - elif key.startswith("#"): + elif key.startswith('#'): direction = pymongo.HASHED - elif key.startswith("("): + elif key.startswith('('): direction = pymongo.GEOSPHERE - elif key.startswith(")"): + elif key.startswith(')'): direction = pymongo.GEOHAYSTACK - elif key.startswith("*"): + elif key.startswith('*'): direction = pymongo.GEO2D - if key.startswith(("+", "-", "*", "$", "#", "(", ")")): + if key.startswith(('+', '-', '*', '$', '#', '(', ')')): key = key[1:] # Use real field name, do it manually because we need field @@ -826,7 +822,7 @@ def _build_index_spec(cls, spec): parts = [] for field in fields: try: - if field != "_id": + if field != '_id': field = field.db_field except AttributeError: pass @@ -845,49 +841,53 @@ def _build_index_spec(cls, spec): return spec @classmethod - def _unique_with_indexes(cls, namespace=""): - """ - Find and set unique indexes - """ + def _unique_with_indexes(cls, namespace=''): + """Find unique indexes in the document schema and return them.""" unique_indexes = [] for field_name, field in cls._fields.items(): sparse = field.sparse + # Generate a list of indexes needed by uniqueness constraints if field.unique: unique_fields = [field.db_field] # Add any unique_with fields to the back of the index spec if field.unique_with: - if isinstance(field.unique_with, basestring): + if isinstance(field.unique_with, six.string_types): field.unique_with = [field.unique_with] # Convert unique_with field names to real field names unique_with = [] for other_name in field.unique_with: parts = other_name.split('.') + # Lookup real name parts = cls._lookup_field(parts) name_parts = [part.db_field for part in parts] unique_with.append('.'.join(name_parts)) + # Unique field should be required parts[-1].required = True sparse = (not sparse and parts[-1].name not in cls.__dict__) + unique_fields += unique_with # Add the new index to the list - fields = [("%s%s" % (namespace, f), pymongo.ASCENDING) - for f in unique_fields] + fields = [ + ('%s%s' % (namespace, f), pymongo.ASCENDING) + for f in unique_fields + ] index = {'fields': fields, 'unique': True, 'sparse': sparse} unique_indexes.append(index) - if field.__class__.__name__ == "ListField": + if field.__class__.__name__ == 'ListField': field = field.field # Grab any embedded document field unique indexes - if (field.__class__.__name__ == "EmbeddedDocumentField" and + if (field.__class__.__name__ == 'EmbeddedDocumentField' and field.document_type != cls): - field_namespace = "%s." % field_name + field_namespace = '%s.' % field_name doc_cls = field.document_type unique_indexes += doc_cls._unique_with_indexes(field_namespace) @@ -899,8 +899,9 @@ def _geo_indices(cls, inspected=None, parent_field=None): geo_indices = [] inspected.append(cls) - geo_field_type_names = ["EmbeddedDocumentField", "GeoPointField", - "PointField", "LineStringField", "PolygonField"] + geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField', + 'PointField', 'LineStringField', + 'PolygonField') geo_field_types = tuple([_import_class(field) for field in geo_field_type_names]) @@ -908,32 +909,68 @@ def _geo_indices(cls, inspected=None, parent_field=None): for field in cls._fields.values(): if not isinstance(field, geo_field_types): continue + if hasattr(field, 'document_type'): field_cls = field.document_type if field_cls in inspected: continue + if hasattr(field_cls, '_geo_indices'): geo_indices += field_cls._geo_indices( inspected, parent_field=field.db_field) elif field._geo_index: field_name = field.db_field if parent_field: - field_name = "%s.%s" % (parent_field, field_name) - geo_indices.append({'fields': - [(field_name, field._geo_index)]}) + field_name = '%s.%s' % (parent_field, field_name) + geo_indices.append({ + 'fields': [(field_name, field._geo_index)] + }) + return geo_indices @classmethod def _lookup_field(cls, parts): - """Lookup a field based on its attribute and return a list containing - the field's parents and the field. + """Given the path to a given field, return a list containing + the Field object associated with that field and all of its parent + Field objects. + + Args: + parts (str, list, or tuple) - path to the field. Should be a + string for simple fields existing on this document or a list + of strings for a field that exists deeper in embedded documents. + + Returns: + A list of Field instances for fields that were found or + strings for sub-fields that weren't. + + Example: + >>> user._lookup_field('name') + [] + + >>> user._lookup_field('roles') + [] + + >>> user._lookup_field(['roles', 'role']) + [, + ] + + >>> user._lookup_field('doesnt_exist') + raises LookUpError + + >>> user._lookup_field(['roles', 'doesnt_exist']) + [, + 'doesnt_exist'] + """ + # TODO this method is WAY too complicated. Simplify it. + # TODO don't think returning a string for embedded non-existent fields is desired - ListField = _import_class("ListField") + ListField = _import_class('ListField') DynamicField = _import_class('DynamicField') if not isinstance(parts, (list, tuple)): parts = [parts] + fields = [] field = None @@ -943,16 +980,17 @@ def _lookup_field(cls, parts): fields.append(field_name) continue + # Look up first field from the document if field is None: - # Look up first field from the document if field_name == 'pk': # Deal with "primary key" alias field_name = cls._meta['id_field'] + if field_name in cls._fields: field = cls._fields[field_name] elif cls._dynamic: field = DynamicField(db_field=field_name) - elif cls._meta.get("allow_inheritance", False) or cls._meta.get("abstract", False): + elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False): # 744: in case the field is defined in a subclass for subcls in cls.__subclasses__(): try: @@ -965,35 +1003,55 @@ def _lookup_field(cls, parts): else: raise LookUpError('Cannot resolve field "%s"' % field_name) else: - raise LookUpError('Cannot resolve field "%s"' - % field_name) + raise LookUpError('Cannot resolve field "%s"' % field_name) else: ReferenceField = _import_class('ReferenceField') GenericReferenceField = _import_class('GenericReferenceField') + + # If previous field was a reference, throw an error (we + # cannot look up fields that are on references). if isinstance(field, (ReferenceField, GenericReferenceField)): raise LookUpError('Cannot perform join in mongoDB: %s' % '__'.join(parts)) + + # If the parent field has a "field" attribute which has a + # lookup_member method, call it to find the field + # corresponding to this iteration. if hasattr(getattr(field, 'field', None), 'lookup_member'): new_field = field.field.lookup_member(field_name) + + # If the parent field is a DynamicField or if it's part of + # a DynamicDocument, mark current field as a DynamicField + # with db_name equal to the field name. elif cls._dynamic and (isinstance(field, DynamicField) or getattr(getattr(field, 'document_type', None), '_dynamic', None)): new_field = DynamicField(db_field=field_name) + + # Else, try to use the parent field's lookup_member method + # to find the subfield. + elif hasattr(field, 'lookup_member'): + new_field = field.lookup_member(field_name) + + # Raise a LookUpError if all the other conditions failed. else: - # Look up subfield on the previous field or raise - try: - new_field = field.lookup_member(field_name) - except AttributeError: - raise LookUpError('Cannot resolve subfield or operator {} ' - 'on the field {}'.format( - field_name, field.name)) + raise LookUpError( + 'Cannot resolve subfield or operator {} ' + 'on the field {}'.format(field_name, field.name) + ) + + # If current field still wasn't found and the parent field + # is a ComplexBaseField, add the name current field name and + # move on. if not new_field and isinstance(field, ComplexBaseField): fields.append(field_name) continue elif not new_field: - raise LookUpError('Cannot resolve field "%s"' - % field_name) + raise LookUpError('Cannot resolve field "%s"' % field_name) + field = new_field # update field to the new field type + fields.append(field) + return fields @classmethod diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 037074d10..b13660c1d 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -4,21 +4,17 @@ from bson import DBRef, ObjectId, SON import pymongo +import six -from mongoengine.base.common import ALLOW_INHERITANCE -from mongoengine.base.datastructures import ( - BaseDict, BaseList, EmbeddedDocumentList -) +from mongoengine.base.common import UPDATE_OPERATORS +from mongoengine.base.datastructures import (BaseDict, BaseList, + EmbeddedDocumentList) from mongoengine.common import _import_class from mongoengine.errors import ValidationError -__all__ = ("BaseField", "ComplexBaseField", - "ObjectIdField", "GeoJsonBaseField") - -UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', - 'push_all', 'pull', 'pull_all', 'add_to_set', - 'set_on_insert', 'min', 'max']) +__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', + 'GeoJsonBaseField') class BaseField(object): @@ -73,7 +69,7 @@ def __init__(self, db_field=None, name=None, required=False, default=None, self.db_field = (db_field or name) if not primary_key else '_id' if name: - msg = "Fields' 'name' attribute deprecated in favour of 'db_field'" + msg = 'Field\'s "name" attribute deprecated in favour of "db_field"' warnings.warn(msg, DeprecationWarning) self.required = required or primary_key self.default = default @@ -89,7 +85,7 @@ def __init__(self, db_field=None, name=None, required=False, default=None, # Detect and report conflicts between metadata and base properties. conflicts = set(dir(self)) & set(kwargs) if conflicts: - raise TypeError("%s already has attribute(s): %s" % ( + raise TypeError('%s already has attribute(s): %s' % ( self.__class__.__name__, ', '.join(conflicts))) # Assign metadata to the instance @@ -147,25 +143,21 @@ def __set__(self, instance, value): v._instance = weakref.proxy(instance) instance._data[self.name] = value - def error(self, message="", errors=None, field_name=None): - """Raises a ValidationError. - """ + def error(self, message='', errors=None, field_name=None): + """Raise a ValidationError.""" field_name = field_name if field_name else self.name raise ValidationError(message, errors=errors, field_name=field_name) def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ + """Convert a MongoDB-compatible type to a Python type.""" return value def to_mongo(self, value): - """Convert a Python type to a MongoDB-compatible type. - """ + """Convert a Python type to a MongoDB-compatible type.""" return self.to_python(value) def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): - """A helper method to call to_mongo with proper inputs - """ + """Helper method to call to_mongo with proper inputs.""" f_inputs = self.to_mongo.__code__.co_varnames ex_vars = {} if 'fields' in f_inputs: @@ -177,15 +169,13 @@ def _to_mongo_safe_call(self, value, use_db_field=True, fields=None): return self.to_mongo(value, **ex_vars) def prepare_query_value(self, op, value): - """Prepare a value that is being used in a query for PyMongo. - """ + """Prepare a value that is being used in a query for PyMongo.""" if op in UPDATE_OPERATORS: self.validate(value) return value def validate(self, value, clean=True): - """Perform validation on a value. - """ + """Perform validation on a value.""" pass def _validate_choices(self, value): @@ -200,11 +190,13 @@ def _validate_choices(self, value): if isinstance(value, (Document, EmbeddedDocument)): if not any(isinstance(value, c) for c in choice_list): self.error( - 'Value must be instance of %s' % unicode(choice_list) + 'Value must be an instance of %s' % ( + six.text_type(choice_list) + ) ) # Choices which are types other than Documents elif value not in choice_list: - self.error('Value must be one of %s' % unicode(choice_list)) + self.error('Value must be one of %s' % six.text_type(choice_list)) def _validate(self, value, **kwargs): # Check the Choices Constraint @@ -247,8 +239,7 @@ class ComplexBaseField(BaseField): field = None def __get__(self, instance, owner): - """Descriptor to automatically dereference references. - """ + """Descriptor to automatically dereference references.""" if instance is None: # Document class being used rather than a document object return self @@ -260,7 +251,7 @@ def __get__(self, instance, owner): (self.field is None or isinstance(self.field, (GenericReferenceField, ReferenceField)))) - _dereference = _import_class("DeReference")() + _dereference = _import_class('DeReference')() self._auto_dereference = instance._fields[self.name]._auto_dereference if instance._initialised and dereference and instance._data.get(self.name): @@ -295,9 +286,8 @@ def __get__(self, instance, owner): return value def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ - if isinstance(value, basestring): + """Convert a MongoDB-compatible type to a Python type.""" + if isinstance(value, six.string_types): return value if hasattr(value, 'to_python'): @@ -307,14 +297,14 @@ def to_python(self, value): if not hasattr(value, 'items'): try: is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) + value = {k: v for k, v in enumerate(value)} except TypeError: # Not iterable return the value return value if self.field: self.field._auto_dereference = self._auto_dereference - value_dict = dict([(key, self.field.to_python(item)) - for key, item in value.items()]) + value_dict = {key: self.field.to_python(item) + for key, item in value.items()} else: Document = _import_class('Document') value_dict = {} @@ -337,13 +327,12 @@ def to_python(self, value): return value_dict def to_mongo(self, value, use_db_field=True, fields=None): - """Convert a Python type to a MongoDB-compatible type. - """ - Document = _import_class("Document") - EmbeddedDocument = _import_class("EmbeddedDocument") - GenericReferenceField = _import_class("GenericReferenceField") + """Convert a Python type to a MongoDB-compatible type.""" + Document = _import_class('Document') + EmbeddedDocument = _import_class('EmbeddedDocument') + GenericReferenceField = _import_class('GenericReferenceField') - if isinstance(value, basestring): + if isinstance(value, six.string_types): return value if hasattr(value, 'to_mongo'): @@ -360,13 +349,15 @@ def to_mongo(self, value, use_db_field=True, fields=None): if not hasattr(value, 'items'): try: is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) + value = {k: v for k, v in enumerate(value)} except TypeError: # Not iterable return the value return value if self.field: - value_dict = dict([(key, self.field._to_mongo_safe_call(item, use_db_field, fields)) - for key, item in value.iteritems()]) + value_dict = { + key: self.field._to_mongo_safe_call(item, use_db_field, fields) + for key, item in value.iteritems() + } else: value_dict = {} for k, v in value.iteritems(): @@ -380,9 +371,7 @@ def to_mongo(self, value, use_db_field=True, fields=None): # any _cls data so make it a generic reference allows # us to dereference meta = getattr(v, '_meta', {}) - allow_inheritance = ( - meta.get('allow_inheritance', ALLOW_INHERITANCE) - is True) + allow_inheritance = meta.get('allow_inheritance') if not allow_inheritance and not self.field: value_dict[k] = GenericReferenceField().to_mongo(v) else: @@ -404,8 +393,7 @@ def to_mongo(self, value, use_db_field=True, fields=None): return value_dict def validate(self, value): - """If field is provided ensure the value is valid. - """ + """If field is provided ensure the value is valid.""" errors = {} if self.field: if hasattr(value, 'iteritems') or hasattr(value, 'items'): @@ -415,9 +403,9 @@ def validate(self, value): for k, v in sequence: try: self.field._validate(v) - except ValidationError, error: + except ValidationError as error: errors[k] = error.errors or error - except (ValueError, AssertionError), error: + except (ValueError, AssertionError) as error: errors[k] = error if errors: @@ -443,8 +431,7 @@ def _set_owner_document(self, owner_document): class ObjectIdField(BaseField): - """A field wrapper around MongoDB's ObjectIds. - """ + """A field wrapper around MongoDB's ObjectIds.""" def to_python(self, value): try: @@ -457,10 +444,10 @@ def to_python(self, value): def to_mongo(self, value): if not isinstance(value, ObjectId): try: - return ObjectId(unicode(value)) - except Exception, e: + return ObjectId(six.text_type(value)) + except Exception as e: # e.message attribute has been deprecated since Python 2.6 - self.error(unicode(e)) + self.error(six.text_type(e)) return value def prepare_query_value(self, op, value): @@ -468,7 +455,7 @@ def prepare_query_value(self, op, value): def validate(self, value): try: - ObjectId(unicode(value)) + ObjectId(six.text_type(value)) except Exception: self.error('Invalid Object ID') @@ -480,21 +467,20 @@ class GeoJsonBaseField(BaseField): """ _geo_index = pymongo.GEOSPHERE - _type = "GeoBase" + _type = 'GeoBase' def __init__(self, auto_index=True, *args, **kwargs): """ - :param bool auto_index: Automatically create a "2dsphere" index.\ + :param bool auto_index: Automatically create a '2dsphere' index.\ Defaults to `True`. """ - self._name = "%sField" % self._type + self._name = '%sField' % self._type if not auto_index: self._geo_index = False super(GeoJsonBaseField, self).__init__(*args, **kwargs) def validate(self, value): - """Validate the GeoJson object based on its type - """ + """Validate the GeoJson object based on its type.""" if isinstance(value, dict): if set(value.keys()) == set(['type', 'coordinates']): if value['type'] != self._type: @@ -509,7 +495,7 @@ def validate(self, value): self.error('%s can only accept lists of [x, y]' % self._name) return - validate = getattr(self, "_validate_%s" % self._type.lower()) + validate = getattr(self, '_validate_%s' % self._type.lower()) error = validate(value) if error: self.error(error) @@ -522,7 +508,7 @@ def _validate_polygon(self, value, top_level=True): try: value[0][0][0] except (TypeError, IndexError): - return "Invalid Polygon must contain at least one valid linestring" + return 'Invalid Polygon must contain at least one valid linestring' errors = [] for val in value: @@ -533,12 +519,12 @@ def _validate_polygon(self, value, top_level=True): errors.append(error) if errors: if top_level: - return "Invalid Polygon:\n%s" % ", ".join(errors) + return 'Invalid Polygon:\n%s' % ', '.join(errors) else: - return "%s" % ", ".join(errors) + return '%s' % ', '.join(errors) def _validate_linestring(self, value, top_level=True): - """Validates a linestring""" + """Validate a linestring.""" if not isinstance(value, (list, tuple)): return 'LineStrings must contain list of coordinate pairs' @@ -546,7 +532,7 @@ def _validate_linestring(self, value, top_level=True): try: value[0][0] except (TypeError, IndexError): - return "Invalid LineString must contain at least one valid point" + return 'Invalid LineString must contain at least one valid point' errors = [] for val in value: @@ -555,19 +541,19 @@ def _validate_linestring(self, value, top_level=True): errors.append(error) if errors: if top_level: - return "Invalid LineString:\n%s" % ", ".join(errors) + return 'Invalid LineString:\n%s' % ', '.join(errors) else: - return "%s" % ", ".join(errors) + return '%s' % ', '.join(errors) def _validate_point(self, value): """Validate each set of coords""" if not isinstance(value, (list, tuple)): return 'Points must be a list of coordinate pairs' elif not len(value) == 2: - return "Value (%s) must be a two-dimensional point" % repr(value) + return 'Value (%s) must be a two-dimensional point' % repr(value) elif (not isinstance(value[0], (float, int)) or not isinstance(value[1], (float, int))): - return "Both values (%s) in point must be float or int" % repr(value) + return 'Both values (%s) in point must be float or int' % repr(value) def _validate_multipoint(self, value): if not isinstance(value, (list, tuple)): @@ -577,7 +563,7 @@ def _validate_multipoint(self, value): try: value[0][0] except (TypeError, IndexError): - return "Invalid MultiPoint must contain at least one valid point" + return 'Invalid MultiPoint must contain at least one valid point' errors = [] for point in value: @@ -586,7 +572,7 @@ def _validate_multipoint(self, value): errors.append(error) if errors: - return "%s" % ", ".join(errors) + return '%s' % ', '.join(errors) def _validate_multilinestring(self, value, top_level=True): if not isinstance(value, (list, tuple)): @@ -596,7 +582,7 @@ def _validate_multilinestring(self, value, top_level=True): try: value[0][0][0] except (TypeError, IndexError): - return "Invalid MultiLineString must contain at least one valid linestring" + return 'Invalid MultiLineString must contain at least one valid linestring' errors = [] for linestring in value: @@ -606,9 +592,9 @@ def _validate_multilinestring(self, value, top_level=True): if errors: if top_level: - return "Invalid MultiLineString:\n%s" % ", ".join(errors) + return 'Invalid MultiLineString:\n%s' % ', '.join(errors) else: - return "%s" % ", ".join(errors) + return '%s' % ', '.join(errors) def _validate_multipolygon(self, value): if not isinstance(value, (list, tuple)): @@ -618,7 +604,7 @@ def _validate_multipolygon(self, value): try: value[0][0][0][0] except (TypeError, IndexError): - return "Invalid MultiPolygon must contain at least one valid Polygon" + return 'Invalid MultiPolygon must contain at least one valid Polygon' errors = [] for polygon in value: @@ -627,9 +613,9 @@ def _validate_multipolygon(self, value): errors.append(error) if errors: - return "Invalid MultiPolygon:\n%s" % ", ".join(errors) + return 'Invalid MultiPolygon:\n%s' % ', '.join(errors) def to_mongo(self, value): if isinstance(value, dict): return value - return SON([("type", self._type), ("coordinates", value)]) + return SON([('type', self._type), ('coordinates', value)]) diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 972834517..481408bf0 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -1,10 +1,11 @@ import warnings -from mongoengine.base.common import ALLOW_INHERITANCE, _document_registry +import six + +from mongoengine.base.common import _document_registry from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField from mongoengine.common import _import_class from mongoengine.errors import InvalidDocumentError -from mongoengine.python_support import PY3 from mongoengine.queryset import (DO_NOTHING, DoesNotExist, MultipleObjectsReturned, QuerySetManager) @@ -45,7 +46,8 @@ def __new__(cls, name, bases, attrs): attrs['_meta'] = meta attrs['_meta']['abstract'] = False # 789: EmbeddedDocument shouldn't inherit abstract - if attrs['_meta'].get('allow_inheritance', ALLOW_INHERITANCE): + # If allow_inheritance is True, add a "_cls" string field to the attrs + if attrs['_meta'].get('allow_inheritance'): StringField = _import_class('StringField') attrs['_cls'] = StringField() @@ -87,16 +89,17 @@ def __new__(cls, name, bases, attrs): # Ensure no duplicate db_fields duplicate_db_fields = [k for k, v in field_names.items() if v > 1] if duplicate_db_fields: - msg = ("Multiple db_fields defined for: %s " % - ", ".join(duplicate_db_fields)) + msg = ('Multiple db_fields defined for: %s ' % + ', '.join(duplicate_db_fields)) raise InvalidDocumentError(msg) # Set _fields and db_field maps attrs['_fields'] = doc_fields - attrs['_db_field_map'] = dict([(k, getattr(v, 'db_field', k)) - for k, v in doc_fields.iteritems()]) - attrs['_reverse_db_field_map'] = dict( - (v, k) for k, v in attrs['_db_field_map'].iteritems()) + attrs['_db_field_map'] = {k: getattr(v, 'db_field', k) + for k, v in doc_fields.items()} + attrs['_reverse_db_field_map'] = { + v: k for k, v in attrs['_db_field_map'].items() + } attrs['_fields_ordered'] = tuple(i[1] for i in sorted( (v.creation_counter, v.name) @@ -116,10 +119,8 @@ def __new__(cls, name, bases, attrs): if hasattr(base, '_meta'): # Warn if allow_inheritance isn't set and prevent # inheritance of classes where inheritance is set to False - allow_inheritance = base._meta.get('allow_inheritance', - ALLOW_INHERITANCE) - if (allow_inheritance is not True and - not base._meta.get('abstract')): + allow_inheritance = base._meta.get('allow_inheritance') + if not allow_inheritance and not base._meta.get('abstract'): raise ValueError('Document %s may not be subclassed' % base.__name__) @@ -161,7 +162,7 @@ def __new__(cls, name, bases, attrs): # module continues to use im_func and im_self, so the code below # copies __func__ into im_func and __self__ into im_self for # classmethod objects in Document derived classes. - if PY3: + if six.PY3: for val in new_class.__dict__.values(): if isinstance(val, classmethod): f = val.__get__(new_class) @@ -179,11 +180,11 @@ def __new__(cls, name, bases, attrs): if isinstance(f, CachedReferenceField): if issubclass(new_class, EmbeddedDocument): - raise InvalidDocumentError( - "CachedReferenceFields is not allowed in EmbeddedDocuments") + raise InvalidDocumentError('CachedReferenceFields is not ' + 'allowed in EmbeddedDocuments') if not f.document_type: raise InvalidDocumentError( - "Document is not available to sync") + 'Document is not available to sync') if f.auto_sync: f.start_listener() @@ -195,8 +196,8 @@ def __new__(cls, name, bases, attrs): 'reverse_delete_rule', DO_NOTHING) if isinstance(f, DictField) and delete_rule != DO_NOTHING: - msg = ("Reverse delete rules are not supported " - "for %s (field: %s)" % + msg = ('Reverse delete rules are not supported ' + 'for %s (field: %s)' % (field.__class__.__name__, field.name)) raise InvalidDocumentError(msg) @@ -204,16 +205,16 @@ def __new__(cls, name, bases, attrs): if delete_rule != DO_NOTHING: if issubclass(new_class, EmbeddedDocument): - msg = ("Reverse delete rules are not supported for " - "EmbeddedDocuments (field: %s)" % field.name) + msg = ('Reverse delete rules are not supported for ' + 'EmbeddedDocuments (field: %s)' % field.name) raise InvalidDocumentError(msg) f.document_type.register_delete_rule(new_class, field.name, delete_rule) if (field.name and hasattr(Document, field.name) and EmbeddedDocument not in new_class.mro()): - msg = ("%s is a document method and not a valid " - "field name" % field.name) + msg = ('%s is a document method and not a valid ' + 'field name' % field.name) raise InvalidDocumentError(msg) return new_class @@ -271,6 +272,11 @@ def __new__(cls, name, bases, attrs): 'index_drop_dups': False, 'index_opts': None, 'delete_rules': None, + + # allow_inheritance can be True, False, and None. True means + # "allow inheritance", False means "don't allow inheritance", + # None means "do whatever your parent does, or don't allow + # inheritance if you're a top-level class". 'allow_inheritance': None, } attrs['_is_base_cls'] = True @@ -303,7 +309,7 @@ def __new__(cls, name, bases, attrs): # If parent wasn't an abstract class if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and not parent_doc_cls._meta.get('abstract', True)): - msg = "Trying to set a collection on a subclass (%s)" % name + msg = 'Trying to set a collection on a subclass (%s)' % name warnings.warn(msg, SyntaxWarning) del attrs['_meta']['collection'] @@ -311,7 +317,7 @@ def __new__(cls, name, bases, attrs): if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'): if (parent_doc_cls and not parent_doc_cls._meta.get('abstract', False)): - msg = "Abstract document cannot have non-abstract base" + msg = 'Abstract document cannot have non-abstract base' raise ValueError(msg) return super_new(cls, name, bases, attrs) @@ -334,12 +340,16 @@ def __new__(cls, name, bases, attrs): meta.merge(attrs.get('_meta', {})) # Top level meta - # Only simple classes (direct subclasses of Document) - # may set allow_inheritance to False + # Only simple classes (i.e. direct subclasses of Document) may set + # allow_inheritance to False. If the base Document allows inheritance, + # none of its subclasses can override allow_inheritance to False. simple_class = all([b._meta.get('abstract') for b in flattened_bases if hasattr(b, '_meta')]) - if (not simple_class and meta['allow_inheritance'] is False and - not meta['abstract']): + if ( + not simple_class and + meta['allow_inheritance'] is False and + not meta['abstract'] + ): raise ValueError('Only direct subclasses of Document may set ' '"allow_inheritance" to False') diff --git a/mongoengine/connection.py b/mongoengine/connection.py index ee21ba90c..bb353cffd 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,7 +1,9 @@ from pymongo import MongoClient, ReadPreference, uri_parser -from mongoengine.python_support import (IS_PYMONGO_3, str_types) +import six -__all__ = ['ConnectionError', 'connect', 'register_connection', +from mongoengine.python_support import IS_PYMONGO_3 + +__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', 'DEFAULT_CONNECTION_NAME'] @@ -14,7 +16,10 @@ READ_PREFERENCE = False -class ConnectionError(Exception): +class MongoEngineConnectionError(Exception): + """Error raised when the database connection can't be established or + when a connection with a requested alias can't be retrieved. + """ pass @@ -50,8 +55,6 @@ def register_connection(alias, name=None, host=None, port=None, .. versionchanged:: 0.10.6 - added mongomock support """ - global _connection_settings - conn_settings = { 'name': name or 'test', 'host': host or 'localhost', @@ -66,7 +69,7 @@ def register_connection(alias, name=None, host=None, port=None, # Handle uri style connections conn_host = conn_settings['host'] # host can be a list or a string, so if string, force to a list - if isinstance(conn_host, str_types): + if isinstance(conn_host, six.string_types): conn_host = [conn_host] resolved_hosts = [] @@ -111,9 +114,7 @@ def register_connection(alias, name=None, host=None, port=None, def disconnect(alias=DEFAULT_CONNECTION_NAME): - global _connections - global _dbs - + """Close the connection with a given alias.""" if alias in _connections: get_connection(alias=alias).close() del _connections[alias] @@ -122,71 +123,100 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME): def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): - global _connections + """Return a connection with a given alias.""" + # Connect to the database if not already connected if reconnect: disconnect(alias) - if alias not in _connections: - if alias not in _connection_settings: - msg = 'Connection with alias "%s" has not been defined' % alias - if alias == DEFAULT_CONNECTION_NAME: - msg = 'You have not defined a default connection' - raise ConnectionError(msg) - conn_settings = _connection_settings[alias].copy() - - conn_settings.pop('name', None) - conn_settings.pop('username', None) - conn_settings.pop('password', None) - conn_settings.pop('authentication_source', None) - conn_settings.pop('authentication_mechanism', None) - - is_mock = conn_settings.pop('is_mock', None) - if is_mock: - # Use MongoClient from mongomock - try: - import mongomock - except ImportError: - raise RuntimeError('You need mongomock installed ' - 'to mock MongoEngine.') - connection_class = mongomock.MongoClient - else: - # Use MongoClient from pymongo - connection_class = MongoClient + # If the requested alias already exists in the _connections list, return + # it immediately. + if alias in _connections: + return _connections[alias] + # Validate that the requested alias exists in the _connection_settings. + # Raise MongoEngineConnectionError if it doesn't. + if alias not in _connection_settings: + if alias == DEFAULT_CONNECTION_NAME: + msg = 'You have not defined a default connection' + else: + msg = 'Connection with alias "%s" has not been defined' % alias + raise MongoEngineConnectionError(msg) + + def _clean_settings(settings_dict): + irrelevant_fields = set([ + 'name', 'username', 'password', 'authentication_source', + 'authentication_mechanism' + ]) + return { + k: v for k, v in settings_dict.items() + if k not in irrelevant_fields + } + + # Retrieve a copy of the connection settings associated with the requested + # alias and remove the database name and authentication info (we don't + # care about them at this point). + conn_settings = _clean_settings(_connection_settings[alias].copy()) + + # Determine if we should use PyMongo's or mongomock's MongoClient. + is_mock = conn_settings.pop('is_mock', False) + if is_mock: + try: + import mongomock + except ImportError: + raise RuntimeError('You need mongomock installed to mock ' + 'MongoEngine.') + connection_class = mongomock.MongoClient + else: + connection_class = MongoClient + + # Handle replica set connections if 'replicaSet' in conn_settings: + # Discard port since it can't be used on MongoReplicaSetClient conn_settings.pop('port', None) - # Discard replicaSet if not base string - if not isinstance(conn_settings['replicaSet'], basestring): - conn_settings.pop('replicaSet', None) + + # Discard replicaSet if it's not a string + if not isinstance(conn_settings['replicaSet'], six.string_types): + del conn_settings['replicaSet'] + + # For replica set connections with PyMongo 2.x, use + # MongoReplicaSetClient. + # TODO remove this once we stop supporting PyMongo 2.x. if not IS_PYMONGO_3: connection_class = MongoReplicaSetClient conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) + # Iterate over all of the connection settings and if a connection with + # the same parameters is already established, use it instead of creating + # a new one. + existing_connection = None + connection_settings_iterator = ( + (db_alias, settings.copy()) + for db_alias, settings in _connection_settings.items() + ) + for db_alias, connection_settings in connection_settings_iterator: + connection_settings = _clean_settings(connection_settings) + if conn_settings == connection_settings and _connections.get(db_alias): + existing_connection = _connections[db_alias] + break + + # If an existing connection was found, assign it to the new alias + if existing_connection: + _connections[alias] = existing_connection + else: + # Otherwise, create the new connection for this alias. Raise + # MongoEngineConnectionError if it can't be established. try: - connection = None - # check for shared connections - connection_settings_iterator = ( - (db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems()) - for db_alias, connection_settings in connection_settings_iterator: - connection_settings.pop('name', None) - connection_settings.pop('username', None) - connection_settings.pop('password', None) - connection_settings.pop('authentication_source', None) - connection_settings.pop('authentication_mechanism', None) - if conn_settings == connection_settings and _connections.get(db_alias, None): - connection = _connections[db_alias] - break - - _connections[alias] = connection if connection else connection_class(**conn_settings) - except Exception, e: - raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) + _connections[alias] = connection_class(**conn_settings) + except Exception as e: + raise MongoEngineConnectionError( + 'Cannot connect to database %s :\n%s' % (alias, e)) + return _connections[alias] def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False): - global _dbs if reconnect: disconnect(alias) @@ -217,7 +247,6 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): .. versionchanged:: 0.6 - added multiple database support. """ - global _connections if alias not in _connections: register_connection(alias, db, **kwargs) diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index cc8600660..c477575e8 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -2,12 +2,12 @@ from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -__all__ = ("switch_db", "switch_collection", "no_dereference", - "no_sub_classes", "query_counter") +__all__ = ('switch_db', 'switch_collection', 'no_dereference', + 'no_sub_classes', 'query_counter') class switch_db(object): - """ switch_db alias context manager. + """switch_db alias context manager. Example :: @@ -18,15 +18,14 @@ class switch_db(object): class Group(Document): name = StringField() - Group(name="test").save() # Saves in the default db + Group(name='test').save() # Saves in the default db with switch_db(Group, 'testdb-1') as Group: - Group(name="hello testdb!").save() # Saves in testdb-1 - + Group(name='hello testdb!').save() # Saves in testdb-1 """ def __init__(self, cls, db_alias): - """ Construct the switch_db context manager + """Construct the switch_db context manager :param cls: the class to change the registered db :param db_alias: the name of the specific database to use @@ -34,37 +33,36 @@ def __init__(self, cls, db_alias): self.cls = cls self.collection = cls._get_collection() self.db_alias = db_alias - self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) + self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME) def __enter__(self): - """ change the db_alias and clear the cached collection """ - self.cls._meta["db_alias"] = self.db_alias + """Change the db_alias and clear the cached collection.""" + self.cls._meta['db_alias'] = self.db_alias self.cls._collection = None return self.cls def __exit__(self, t, value, traceback): - """ Reset the db_alias and collection """ - self.cls._meta["db_alias"] = self.ori_db_alias + """Reset the db_alias and collection.""" + self.cls._meta['db_alias'] = self.ori_db_alias self.cls._collection = self.collection class switch_collection(object): - """ switch_collection alias context manager. + """switch_collection alias context manager. Example :: class Group(Document): name = StringField() - Group(name="test").save() # Saves in the default db + Group(name='test').save() # Saves in the default db with switch_collection(Group, 'group1') as Group: - Group(name="hello testdb!").save() # Saves in group1 collection - + Group(name='hello testdb!').save() # Saves in group1 collection """ def __init__(self, cls, collection_name): - """ Construct the switch_collection context manager + """Construct the switch_collection context manager. :param cls: the class to change the registered db :param collection_name: the name of the collection to use @@ -75,7 +73,7 @@ def __init__(self, cls, collection_name): self.collection_name = collection_name def __enter__(self): - """ change the _get_collection_name and clear the cached collection """ + """Change the _get_collection_name and clear the cached collection.""" @classmethod def _get_collection_name(cls): @@ -86,24 +84,23 @@ def _get_collection_name(cls): return self.cls def __exit__(self, t, value, traceback): - """ Reset the collection """ + """Reset the collection.""" self.cls._collection = self.ori_collection self.cls._get_collection_name = self.ori_get_collection_name class no_dereference(object): - """ no_dereference context manager. + """no_dereference context manager. Turns off all dereferencing in Documents for the duration of the context manager:: with no_dereference(Group) as Group: Group.objects.find() - """ def __init__(self, cls): - """ Construct the no_dereference context manager. + """Construct the no_dereference context manager. :param cls: the class to turn dereferencing off on """ @@ -119,103 +116,102 @@ def __init__(self, cls): ComplexBaseField))] def __enter__(self): - """ change the objects default and _auto_dereference values""" + """Change the objects default and _auto_dereference values.""" for field in self.deref_fields: self.cls._fields[field]._auto_dereference = False return self.cls def __exit__(self, t, value, traceback): - """ Reset the default and _auto_dereference values""" + """Reset the default and _auto_dereference values.""" for field in self.deref_fields: self.cls._fields[field]._auto_dereference = True return self.cls class no_sub_classes(object): - """ no_sub_classes context manager. + """no_sub_classes context manager. Only returns instances of this class and no sub (inherited) classes:: with no_sub_classes(Group) as Group: Group.objects.find() - """ def __init__(self, cls): - """ Construct the no_sub_classes context manager. + """Construct the no_sub_classes context manager. :param cls: the class to turn querying sub classes on """ self.cls = cls def __enter__(self): - """ change the objects default and _auto_dereference values""" + """Change the objects default and _auto_dereference values.""" self.cls._all_subclasses = self.cls._subclasses self.cls._subclasses = (self.cls,) return self.cls def __exit__(self, t, value, traceback): - """ Reset the default and _auto_dereference values""" + """Reset the default and _auto_dereference values.""" self.cls._subclasses = self.cls._all_subclasses delattr(self.cls, '_all_subclasses') return self.cls class query_counter(object): - """ Query_counter context manager to get the number of queries. """ + """Query_counter context manager to get the number of queries.""" def __init__(self): - """ Construct the query_counter. """ + """Construct the query_counter.""" self.counter = 0 self.db = get_db() def __enter__(self): - """ On every with block we need to drop the profile collection. """ + """On every with block we need to drop the profile collection.""" self.db.set_profiling_level(0) self.db.system.profile.drop() self.db.set_profiling_level(2) return self def __exit__(self, t, value, traceback): - """ Reset the profiling level. """ + """Reset the profiling level.""" self.db.set_profiling_level(0) def __eq__(self, value): - """ == Compare querycounter. """ + """== Compare querycounter.""" counter = self._get_count() return value == counter def __ne__(self, value): - """ != Compare querycounter. """ + """!= Compare querycounter.""" return not self.__eq__(value) def __lt__(self, value): - """ < Compare querycounter. """ + """< Compare querycounter.""" return self._get_count() < value def __le__(self, value): - """ <= Compare querycounter. """ + """<= Compare querycounter.""" return self._get_count() <= value def __gt__(self, value): - """ > Compare querycounter. """ + """> Compare querycounter.""" return self._get_count() > value def __ge__(self, value): - """ >= Compare querycounter. """ + """>= Compare querycounter.""" return self._get_count() >= value def __int__(self): - """ int representation. """ + """int representation.""" return self._get_count() def __repr__(self): - """ repr query_counter as the number of queries. """ + """repr query_counter as the number of queries.""" return u"%s" % self._get_count() def _get_count(self): - """ Get the number of queries. """ - ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} + """Get the number of queries.""" + ignore_query = {'ns': {'$ne': '%s.system.indexes' % self.db.name}} count = self.db.system.profile.find(ignore_query).count() - self.counter self.counter += 1 return count diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index c16028d2a..59204d4d6 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,14 +1,12 @@ from bson import DBRef, SON +import six -from .base import ( - BaseDict, BaseList, EmbeddedDocumentList, - TopLevelDocumentMetaclass, get_document -) -from .connection import get_db -from .document import Document, EmbeddedDocument -from .fields import DictField, ListField, MapField, ReferenceField -from .python_support import txt_type -from .queryset import QuerySet +from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, + TopLevelDocumentMetaclass, get_document) +from mongoengine.connection import get_db +from mongoengine.document import Document, EmbeddedDocument +from mongoengine.fields import DictField, ListField, MapField, ReferenceField +from mongoengine.queryset import QuerySet class DeReference(object): @@ -25,7 +23,7 @@ def __call__(self, items, max_depth=1, instance=None, name=None): :class:`~mongoengine.base.ComplexBaseField` :param get: A boolean determining if being called by __get__ """ - if items is None or isinstance(items, basestring): + if items is None or isinstance(items, six.string_types): return items # cheapest way to convert a queryset to a list @@ -68,11 +66,11 @@ def _get_items(items): items = _get_items(items) else: - items = dict([ - (k, field.to_python(v)) - if not isinstance(v, (DBRef, Document)) else (k, v) - for k, v in items.iteritems()] - ) + items = { + k: (v if isinstance(v, (DBRef, Document)) + else field.to_python(v)) + for k, v in items.iteritems() + } self.reference_map = self._find_references(items) self.object_map = self._fetch_objects(doc_type=doc_type) @@ -90,14 +88,14 @@ def _find_references(self, items, depth=0): return reference_map # Determine the iterator to use - if not hasattr(items, 'items'): - iterator = enumerate(items) + if isinstance(items, dict): + iterator = items.values() else: - iterator = items.iteritems() + iterator = items # Recursively find dbreferences depth += 1 - for k, item in iterator: + for item in iterator: if isinstance(item, (Document, EmbeddedDocument)): for field_name, field in item._fields.iteritems(): v = item._data.get(field_name, None) @@ -151,7 +149,7 @@ def _fetch_objects(self, doc_type=None): references = get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: if '_cls' in ref: - doc = get_document(ref["_cls"])._from_son(ref) + doc = get_document(ref['_cls'])._from_son(ref) elif doc_type is None: doc = get_document( ''.join(x.capitalize() @@ -218,7 +216,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): if k in self.object_map and not is_list: data[k] = self.object_map[k] elif isinstance(v, (Document, EmbeddedDocument)): - for field_name, field in v._fields.iteritems(): + for field_name in v._fields: v = data[k]._data.get(field_name, None) if isinstance(v, DBRef): data[k]._data[field_name] = self.object_map.get( @@ -227,7 +225,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): data[k]._data[field_name] = self.object_map.get( (v['_ref'].collection, v['_ref'].id), v) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) + item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: item_name = '%s.%s' % (name, k) if name else name diff --git a/mongoengine/document.py b/mongoengine/document.py index 91dcafc47..e86a45d9b 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -4,18 +4,12 @@ from bson.dbref import DBRef import pymongo from pymongo.read_preferences import ReadPreference +import six from mongoengine import signals -from mongoengine.base import ( - ALLOW_INHERITANCE, - BaseDict, - BaseDocument, - BaseList, - DocumentMetaclass, - EmbeddedDocumentList, - TopLevelDocumentMetaclass, - get_document -) +from mongoengine.base import (BaseDict, BaseDocument, BaseList, + DocumentMetaclass, EmbeddedDocumentList, + TopLevelDocumentMetaclass, get_document) from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.context_managers import switch_collection, switch_db @@ -31,12 +25,10 @@ def includes_cls(fields): - """ Helper function used for ensuring and comparing indexes - """ - + """Helper function used for ensuring and comparing indexes.""" first_field = None if len(fields): - if isinstance(fields[0], basestring): + if isinstance(fields[0], six.string_types): first_field = fields[0] elif isinstance(fields[0], (list, tuple)) and len(fields[0]): first_field = fields[0][0] @@ -57,9 +49,8 @@ class EmbeddedDocument(BaseDocument): to create a specialised version of the embedded document that will be stored in the same collection. To facilitate this behaviour a `_cls` field is added to documents (hidden though the MongoEngine interface). - To disable this behaviour and remove the dependence on the presence of - `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` - dictionary. + To enable this behaviour set :attr:`allow_inheritance` to ``True`` in the + :attr:`meta` dictionary. """ __slots__ = ('_instance', ) @@ -82,6 +73,15 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def to_mongo(self, *args, **kwargs): + data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs) + + # remove _id from the SON if it's in it and it's None + if '_id' in data and data['_id'] is None: + del data['_id'] + + return data + def save(self, *args, **kwargs): self._instance.save(*args, **kwargs) @@ -106,9 +106,8 @@ class Document(BaseDocument): create a specialised version of the document that will be stored in the same collection. To facilitate this behaviour a `_cls` field is added to documents (hidden though the MongoEngine interface). - To disable this behaviour and remove the dependence on the presence of - `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` - dictionary. + To enable this behaviourset :attr:`allow_inheritance` to ``True`` in the + :attr:`meta` dictionary. A :class:`~mongoengine.Document` may use a **Capped Collection** by specifying :attr:`max_documents` and :attr:`max_size` in the :attr:`meta` @@ -149,26 +148,22 @@ class Document(BaseDocument): __slots__ = ('__objects',) - def pk(): - """Primary key alias - """ - - def fget(self): - if 'id_field' not in self._meta: - return None - return getattr(self, self._meta['id_field']) - - def fset(self, value): - return setattr(self, self._meta['id_field'], value) - - return property(fget, fset) + @property + def pk(self): + """Get the primary key.""" + if 'id_field' not in self._meta: + return None + return getattr(self, self._meta['id_field']) - pk = pk() + @pk.setter + def pk(self, value): + """Set the primary key.""" + return setattr(self, self._meta['id_field'], value) @classmethod def _get_db(cls): """Some Model using other db_alias""" - return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)) + return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) @classmethod def _get_collection(cls): @@ -211,7 +206,20 @@ def _get_collection(cls): cls.ensure_indexes() return cls._collection - def modify(self, query={}, **update): + def to_mongo(self, *args, **kwargs): + data = super(Document, self).to_mongo(*args, **kwargs) + + # If '_id' is None, try and set it from self._data. If that + # doesn't exist either, remote '_id' from the SON completely. + if data['_id'] is None: + if self._data.get('id') is None: + del data['_id'] + else: + data['_id'] = self._data['id'] + + return data + + def modify(self, query=None, **update): """Perform an atomic update of the document in the database and reload the document object using updated version. @@ -225,17 +233,19 @@ def modify(self, query={}, **update): database matches the query :param update: Django-style update keyword arguments """ + if query is None: + query = {} if self.pk is None: - raise InvalidDocumentError("The document does not have a primary key.") + raise InvalidDocumentError('The document does not have a primary key.') - id_field = self._meta["id_field"] + id_field = self._meta['id_field'] query = query.copy() if isinstance(query, dict) else query.to_query(self) if id_field not in query: query[id_field] = self.pk elif query[id_field] != self.pk: - raise InvalidQueryError("Invalid document modify query: it must modify only this document.") + raise InvalidQueryError('Invalid document modify query: it must modify only this document.') updated = self._qs(**query).modify(new=True, **update) if updated is None: @@ -310,7 +320,7 @@ def save(self, force_insert=False, validate=True, clean=True, self.validate(clean=clean) if write_concern is None: - write_concern = {"w": 1} + write_concern = {'w': 1} doc = self.to_mongo() @@ -347,7 +357,7 @@ def save(self, force_insert=False, validate=True, clean=True, else: select_dict = {} select_dict['_id'] = object_id - shard_key = self.__class__._meta.get('shard_key', tuple()) + shard_key = self._meta.get('shard_key', tuple()) for k in shard_key: path = self._lookup_field(k.split('.')) actual_key = [p.db_field for p in path] @@ -358,7 +368,7 @@ def save(self, force_insert=False, validate=True, clean=True, def is_new_object(last_error): if last_error is not None: - updated = last_error.get("updatedExisting") + updated = last_error.get('updatedExisting') if updated is not None: return not updated return created @@ -366,14 +376,14 @@ def is_new_object(last_error): update_query = {} if updates: - update_query["$set"] = updates + update_query['$set'] = updates if removals: - update_query["$unset"] = removals + update_query['$unset'] = removals if updates or removals: upsert = save_condition is None last_error = collection.update(select_dict, update_query, upsert=upsert, **write_concern) - if not upsert and last_error["n"] == 0: + if not upsert and last_error['n'] == 0: raise SaveConditionError('Race condition preventing' ' document update detected') created = is_new_object(last_error) @@ -384,26 +394,27 @@ def is_new_object(last_error): if cascade: kwargs = { - "force_insert": force_insert, - "validate": validate, - "write_concern": write_concern, - "cascade": cascade + 'force_insert': force_insert, + 'validate': validate, + 'write_concern': write_concern, + 'cascade': cascade } if cascade_kwargs: # Allow granular control over cascades kwargs.update(cascade_kwargs) kwargs['_refs'] = _refs self.cascade_save(**kwargs) - except pymongo.errors.DuplicateKeyError, err: + except pymongo.errors.DuplicateKeyError as err: message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - except pymongo.errors.OperationFailure, err: + raise NotUniqueError(message % six.text_type(err)) + except pymongo.errors.OperationFailure as err: message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): + if re.match('^E1100[01] duplicate key', six.text_type(err)): # E11000 - duplicate key error index # E11001 - duplicate key on update message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) + raise NotUniqueError(message % six.text_type(err)) + raise OperationError(message % six.text_type(err)) + id_field = self._meta['id_field'] if created or id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) @@ -414,10 +425,11 @@ def is_new_object(last_error): self._created = False return self - def cascade_save(self, *args, **kwargs): - """Recursively saves any references / - generic references on the document""" - _refs = kwargs.get('_refs', []) or [] + def cascade_save(self, **kwargs): + """Recursively save any references and generic references on the + document. + """ + _refs = kwargs.get('_refs') or [] ReferenceField = _import_class('ReferenceField') GenericReferenceField = _import_class('GenericReferenceField') @@ -443,16 +455,17 @@ def cascade_save(self, *args, **kwargs): @property def _qs(self): - """ - Returns the queryset to use for updating / reloading / deletions - """ + """Return the queryset to use for updating / reloading / deletions.""" if not hasattr(self, '__objects'): self.__objects = QuerySet(self, self._get_collection()) return self.__objects @property def _object_key(self): - """Dict to identify object in collection + """Get the query dict that can be used to fetch this object from + the database. Most of the time it's a simple PK lookup, but in + case of a sharded collection with a compound shard key, it can + contain a more complex query. """ select_dict = {'pk': self.pk} shard_key = self.__class__._meta.get('shard_key', tuple()) @@ -475,8 +488,8 @@ def update(self, **kwargs): if self.pk is None: if kwargs.get('upsert', False): query = self.to_mongo() - if "_cls" in query: - del query["_cls"] + if '_cls' in query: + del query['_cls'] return self._qs.filter(**query).update_one(**kwargs) else: raise OperationError( @@ -513,7 +526,7 @@ def delete(self, signal_kwargs=None, **write_concern): try: self._qs.filter( **self._object_key).delete(write_concern=write_concern, _from_doc_delete=True) - except pymongo.errors.OperationFailure, err: + except pymongo.errors.OperationFailure as err: message = u'Could not delete document (%s)' % err.message raise OperationError(message) signals.post_delete.send(self.__class__, document=self, **signal_kwargs) @@ -601,11 +614,12 @@ def reload(self, *fields, **kwargs): if fields and isinstance(fields[0], int): max_depth = fields[0] fields = fields[1:] - elif "max_depth" in kwargs: - max_depth = kwargs["max_depth"] + elif 'max_depth' in kwargs: + max_depth = kwargs['max_depth'] if self.pk is None: - raise self.DoesNotExist("Document does not exist") + raise self.DoesNotExist('Document does not exist') + obj = self._qs.read_preference(ReadPreference.PRIMARY).filter( **self._object_key).only(*fields).limit( 1).select_related(max_depth=max_depth) @@ -613,7 +627,7 @@ def reload(self, *fields, **kwargs): if obj: obj = obj[0] else: - raise self.DoesNotExist("Document does not exist") + raise self.DoesNotExist('Document does not exist') for field in obj._data: if not fields or field in fields: @@ -656,7 +670,7 @@ def to_dbref(self): """Returns an instance of :class:`~bson.dbref.DBRef` useful in `__raw__` queries.""" if self.pk is None: - msg = "Only saved documents can have a valid dbref" + msg = 'Only saved documents can have a valid dbref' raise OperationError(msg) return DBRef(self.__class__._get_collection_name(), self.pk) @@ -711,7 +725,7 @@ def create_index(cls, keys, background=False, **kwargs): fields = index_spec.pop('fields') drop_dups = kwargs.get('drop_dups', False) if IS_PYMONGO_3 and drop_dups: - msg = "drop_dups is deprecated and is removed when using PyMongo 3+." + msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) elif not IS_PYMONGO_3: index_spec['drop_dups'] = drop_dups @@ -737,7 +751,7 @@ def ensure_index(cls, key_or_list, drop_dups=False, background=False, will be removed if PyMongo3+ is used """ if IS_PYMONGO_3 and drop_dups: - msg = "drop_dups is deprecated and is removed when using PyMongo 3+." + msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) elif not IS_PYMONGO_3: kwargs.update({'drop_dups': drop_dups}) @@ -757,7 +771,7 @@ def ensure_indexes(cls): index_opts = cls._meta.get('index_opts') or {} index_cls = cls._meta.get('index_cls', True) if IS_PYMONGO_3 and drop_dups: - msg = "drop_dups is deprecated and is removed when using PyMongo 3+." + msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) collection = cls._get_collection() @@ -795,8 +809,7 @@ def ensure_indexes(cls): # If _cls is being used (for polymorphism), it needs an index, # only if another index doesn't begin with _cls - if (index_cls and not cls_indexed and - cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): + if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'): # we shouldn't pass 'cls' to the collection.ensureIndex options # because of https://jira.mongodb.org/browse/SERVER-769 @@ -866,16 +879,15 @@ def get_indexes_spec(cls): # finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed if [(u'_id', 1)] not in indexes: indexes.append([(u'_id', 1)]) - if (cls._meta.get('index_cls', True) and - cls._meta.get('allow_inheritance', ALLOW_INHERITANCE) is True): + if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'): indexes.append([(u'_cls', 1)]) return indexes @classmethod def compare_indexes(cls): - """ Compares the indexes defined in MongoEngine with the ones existing - in the database. Returns any missing/extra indexes. + """ Compares the indexes defined in MongoEngine with the ones + existing in the database. Returns any missing/extra indexes. """ required = cls.list_indexes() @@ -919,8 +931,9 @@ class DynamicDocument(Document): _dynamic = True def __delattr__(self, *args, **kwargs): - """Deletes the attribute by setting to None and allowing _delta to unset - it""" + """Delete the attribute by setting to None and allowing _delta + to unset it. + """ field_name = args[0] if field_name in self._dynamic_fields: setattr(self, field_name, None) @@ -942,8 +955,9 @@ class DynamicEmbeddedDocument(EmbeddedDocument): _dynamic = True def __delattr__(self, *args, **kwargs): - """Deletes the attribute by setting to None and allowing _delta to unset - it""" + """Delete the attribute by setting to None and allowing _delta + to unset it. + """ field_name = args[0] if field_name in self._fields: default = self._fields[field_name].default @@ -985,10 +999,10 @@ def object(self): try: self.key = id_field_type(self.key) except Exception: - raise Exception("Could not cast key as %s" % + raise Exception('Could not cast key as %s' % id_field_type.__name__) - if not hasattr(self, "_key_object"): + if not hasattr(self, '_key_object'): self._key_object = self._document.objects.with_id(self.key) return self._key_object return self._key_object diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 15830b5c3..2549e8228 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -1,7 +1,6 @@ from collections import defaultdict -from mongoengine.python_support import txt_type - +import six __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', @@ -71,13 +70,13 @@ class ValidationError(AssertionError): field_name = None _message = None - def __init__(self, message="", **kwargs): + def __init__(self, message='', **kwargs): self.errors = kwargs.get('errors', {}) self.field_name = kwargs.get('field_name') self.message = message def __str__(self): - return txt_type(self.message) + return six.text_type(self.message) def __repr__(self): return '%s(%s,)' % (self.__class__.__name__, self.message) @@ -111,17 +110,20 @@ def build_dict(source): errors_dict = {} if not source: return errors_dict + if isinstance(source, dict): for field_name, error in source.iteritems(): errors_dict[field_name] = build_dict(error) elif isinstance(source, ValidationError) and source.errors: return build_dict(source.errors) else: - return unicode(source) + return six.text_type(source) + return errors_dict if not self.errors: return {} + return build_dict(self.errors) def _format_errors(self): @@ -134,10 +136,10 @@ def generate_key(value, prefix=''): value = ' '.join( [generate_key(v, k) for k, v in value.iteritems()]) - results = "%s.%s" % (prefix, value) if prefix else value + results = '%s.%s' % (prefix, value) if prefix else value return results error_dict = defaultdict(list) for k, v in self.to_dict().iteritems(): error_dict[generate_key(v)].append(k) - return ' '.join(["%s: %s" % (k, v) for k, v in error_dict.iteritems()]) + return ' '.join(['%s: %s' % (k, v) for k, v in error_dict.iteritems()]) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index e9cc974cd..d812a762a 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -3,7 +3,6 @@ import itertools import re import time -import urllib2 import uuid import warnings from operator import itemgetter @@ -25,13 +24,13 @@ except ImportError: Int64 = long -from .base import (BaseDocument, BaseField, ComplexBaseField, GeoJsonBaseField, - ObjectIdField, get_document) -from .connection import DEFAULT_CONNECTION_NAME, get_db -from .document import Document, EmbeddedDocument -from .errors import DoesNotExist, ValidationError -from .python_support import PY3, StringIO, bin_type, str_types, txt_type -from .queryset import DO_NOTHING, QuerySet +from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField, + GeoJsonBaseField, ObjectIdField, get_document) +from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db +from mongoengine.document import Document, EmbeddedDocument +from mongoengine.errors import DoesNotExist, ValidationError +from mongoengine.python_support import StringIO +from mongoengine.queryset import DO_NOTHING, QuerySet try: from PIL import Image, ImageOps @@ -39,7 +38,7 @@ Image = None ImageOps = None -__all__ = [ +__all__ = ( 'StringField', 'URLField', 'EmailField', 'IntField', 'LongField', 'FloatField', 'DecimalField', 'BooleanField', 'DateTimeField', 'ComplexDateTimeField', 'EmbeddedDocumentField', 'ObjectIdField', @@ -50,14 +49,14 @@ 'FileField', 'ImageGridFsProxy', 'ImproperlyConfigured', 'ImageField', 'GeoPointField', 'PointField', 'LineStringField', 'PolygonField', 'SequenceField', 'UUIDField', 'MultiPointField', 'MultiLineStringField', - 'MultiPolygonField', 'GeoJsonBaseField'] + 'MultiPolygonField', 'GeoJsonBaseField' +) RECURSIVE_REFERENCE_CONSTANT = 'self' class StringField(BaseField): - """A unicode string field. - """ + """A unicode string field.""" def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): self.regex = re.compile(regex) if regex else None @@ -66,7 +65,7 @@ def __init__(self, regex=None, max_length=None, min_length=None, **kwargs): super(StringField, self).__init__(**kwargs) def to_python(self, value): - if isinstance(value, unicode): + if isinstance(value, six.text_type): return value try: value = value.decode('utf-8') @@ -75,7 +74,7 @@ def to_python(self, value): return value def validate(self, value): - if not isinstance(value, basestring): + if not isinstance(value, six.string_types): self.error('StringField only accepts string values') if self.max_length is not None and len(value) > self.max_length: @@ -91,7 +90,7 @@ def lookup_member(self, member_name): return None def prepare_query_value(self, op, value): - if not isinstance(op, basestring): + if not isinstance(op, six.string_types): return value if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): @@ -148,17 +147,6 @@ def validate(self, value): self.error('Invalid URL: {}'.format(value)) return - if self.verify_exists: - warnings.warn( - "The URLField verify_exists argument has intractable security " - "and performance issues. Accordingly, it has been deprecated.", - DeprecationWarning) - try: - request = urllib2.Request(value) - urllib2.urlopen(request) - except Exception, e: - self.error('This URL appears to be a broken link: %s' % e) - class EmailField(StringField): """A field that validates input as an email address. @@ -182,8 +170,7 @@ def validate(self, value): class IntField(BaseField): - """An 32-bit integer field. - """ + """32-bit integer field.""" def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value @@ -216,8 +203,7 @@ def prepare_query_value(self, op, value): class LongField(BaseField): - """An 64-bit integer field. - """ + """64-bit integer field.""" def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value @@ -253,8 +239,7 @@ def prepare_query_value(self, op, value): class FloatField(BaseField): - """An floating point number field. - """ + """Floating point number field.""" def __init__(self, min_value=None, max_value=None, **kwargs): self.min_value, self.max_value = min_value, max_value @@ -291,7 +276,7 @@ def prepare_query_value(self, op, value): class DecimalField(BaseField): - """A fixed-point decimal number field. + """Fixed-point decimal number field. .. versionchanged:: 0.8 .. versionadded:: 0.3 @@ -332,25 +317,25 @@ def to_python(self, value): # Convert to string for python 2.6 before casting to Decimal try: - value = decimal.Decimal("%s" % value) + value = decimal.Decimal('%s' % value) except decimal.InvalidOperation: return value - return value.quantize(decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding) + return value.quantize(decimal.Decimal('.%s' % ('0' * self.precision)), rounding=self.rounding) def to_mongo(self, value): if value is None: return value if self.force_string: - return unicode(self.to_python(value)) + return six.text_type(self.to_python(value)) return float(self.to_python(value)) def validate(self, value): if not isinstance(value, decimal.Decimal): - if not isinstance(value, basestring): - value = unicode(value) + if not isinstance(value, six.string_types): + value = six.text_type(value) try: value = decimal.Decimal(value) - except Exception, exc: + except Exception as exc: self.error('Could not convert value to decimal: %s' % exc) if self.min_value is not None and value < self.min_value: @@ -364,7 +349,7 @@ def prepare_query_value(self, op, value): class BooleanField(BaseField): - """A boolean field type. + """Boolean field type. .. versionadded:: 0.1.2 """ @@ -382,7 +367,7 @@ def validate(self, value): class DateTimeField(BaseField): - """A datetime field. + """Datetime field. Uses the python-dateutil library if available alternatively use time.strptime to parse the dates. Note: python-dateutil's parser is fully featured and when @@ -410,7 +395,7 @@ def to_mongo(self, value): if callable(value): return value() - if not isinstance(value, basestring): + if not isinstance(value, six.string_types): return None # Attempt to parse a datetime: @@ -537,16 +522,19 @@ class EmbeddedDocumentField(BaseField): """ def __init__(self, document_type, **kwargs): - if not isinstance(document_type, basestring): - if not issubclass(document_type, EmbeddedDocument): - self.error('Invalid embedded document class provided to an ' - 'EmbeddedDocumentField') + if ( + not isinstance(document_type, six.string_types) and + not issubclass(document_type, EmbeddedDocument) + ): + self.error('Invalid embedded document class provided to an ' + 'EmbeddedDocumentField') + self.document_type_obj = document_type super(EmbeddedDocumentField, self).__init__(**kwargs) @property def document_type(self): - if isinstance(self.document_type_obj, basestring): + if isinstance(self.document_type_obj, six.string_types): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: self.document_type_obj = self.owner_document else: @@ -631,7 +619,7 @@ def to_mongo(self, value, use_db_field=True, fields=None): """Convert a Python type to a MongoDB compatible type. """ - if isinstance(value, basestring): + if isinstance(value, six.string_types): return value if hasattr(value, 'to_mongo'): @@ -639,7 +627,7 @@ def to_mongo(self, value, use_db_field=True, fields=None): val = value.to_mongo(use_db_field, fields) # If we its a document thats not inherited add _cls if isinstance(value, Document): - val = {"_ref": value.to_dbref(), "_cls": cls.__name__} + val = {'_ref': value.to_dbref(), '_cls': cls.__name__} if isinstance(value, EmbeddedDocument): val['_cls'] = cls.__name__ return val @@ -650,7 +638,7 @@ def to_mongo(self, value, use_db_field=True, fields=None): is_list = False if not hasattr(value, 'items'): is_list = True - value = dict([(k, v) for k, v in enumerate(value)]) + value = {k: v for k, v in enumerate(value)} data = {} for k, v in value.iteritems(): @@ -674,12 +662,12 @@ def lookup_member(self, member_name): return member_name def prepare_query_value(self, op, value): - if isinstance(value, basestring): + if isinstance(value, six.string_types): return StringField().prepare_query_value(op, value) return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) def validate(self, value, clean=True): - if hasattr(value, "validate"): + if hasattr(value, 'validate'): value.validate(clean=clean) @@ -699,21 +687,27 @@ def __init__(self, field=None, **kwargs): super(ListField, self).__init__(**kwargs) def validate(self, value): - """Make sure that a list of valid fields is being used. - """ + """Make sure that a list of valid fields is being used.""" if (not isinstance(value, (list, tuple, QuerySet)) or - isinstance(value, basestring)): + isinstance(value, six.string_types)): self.error('Only lists and tuples may be used in a list field') super(ListField, self).validate(value) def prepare_query_value(self, op, value): if self.field: - if op in ('set', 'unset', None) and ( - not isinstance(value, basestring) and - not isinstance(value, BaseDocument) and - hasattr(value, '__iter__')): + + # If the value is iterable and it's not a string nor a + # BaseDocument, call prepare_query_value for each of its items. + if ( + op in ('set', 'unset', None) and + hasattr(value, '__iter__') and + not isinstance(value, six.string_types) and + not isinstance(value, BaseDocument) + ): return [self.field.prepare_query_value(op, v) for v in value] + return self.field.prepare_query_value(op, value) + return super(ListField, self).prepare_query_value(op, value) @@ -726,7 +720,6 @@ class EmbeddedDocumentListField(ListField): :class:`~mongoengine.EmbeddedDocument`. .. versionadded:: 0.9 - """ def __init__(self, document_type, **kwargs): @@ -775,17 +768,17 @@ def to_mongo(self, value, use_db_field=True, fields=None): def key_not_string(d): - """ Helper function to recursively determine if any key in a dictionary is - not a string. + """Helper function to recursively determine if any key in a + dictionary is not a string. """ for k, v in d.items(): - if not isinstance(k, basestring) or (isinstance(v, dict) and key_not_string(v)): + if not isinstance(k, six.string_types) or (isinstance(v, dict) and key_not_string(v)): return True def key_has_dot_or_dollar(d): - """ Helper function to recursively determine if any key in a dictionary - contains a dot or a dollar sign. + """Helper function to recursively determine if any key in a + dictionary contains a dot or a dollar sign. """ for k, v in d.items(): if ('.' in k or '$' in k) or (isinstance(v, dict) and key_has_dot_or_dollar(v)): @@ -813,14 +806,13 @@ def __init__(self, basecls=None, field=None, *args, **kwargs): super(DictField, self).__init__(*args, **kwargs) def validate(self, value): - """Make sure that a list of valid fields is being used. - """ + """Make sure that a list of valid fields is being used.""" if not isinstance(value, dict): self.error('Only dictionaries may be used in a DictField') if key_not_string(value): - msg = ("Invalid dictionary key - documents must " - "have only string keys") + msg = ('Invalid dictionary key - documents must ' + 'have only string keys') self.error(msg) if key_has_dot_or_dollar(value): self.error('Invalid dictionary key name - keys may not contain "."' @@ -835,14 +827,15 @@ def prepare_query_value(self, op, value): 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] - if op in match_operators and isinstance(value, basestring): + if op in match_operators and isinstance(value, six.string_types): return StringField().prepare_query_value(op, value) if hasattr(self.field, 'field'): if op in ('set', 'unset') and isinstance(value, dict): - return dict( - (k, self.field.prepare_query_value(op, v)) - for k, v in value.items()) + return { + k: self.field.prepare_query_value(op, v) + for k, v in value.items() + } return self.field.prepare_query_value(op, value) return super(DictField, self).prepare_query_value(op, value) @@ -911,10 +904,12 @@ def __init__(self, document_type, dbref=False, A reference to an abstract document type is always stored as a :class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`. """ - if not isinstance(document_type, basestring): - if not issubclass(document_type, (Document, basestring)): - self.error('Argument to ReferenceField constructor must be a ' - 'document class or a string') + if ( + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) + ): + self.error('Argument to ReferenceField constructor must be a ' + 'document class or a string') self.dbref = dbref self.document_type_obj = document_type @@ -923,7 +918,7 @@ def __init__(self, document_type, dbref=False, @property def document_type(self): - if isinstance(self.document_type_obj, basestring): + if isinstance(self.document_type_obj, six.string_types): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: self.document_type_obj = self.owner_document else: @@ -931,8 +926,7 @@ def document_type(self): return self.document_type_obj def __get__(self, instance, owner): - """Descriptor to allow lazy dereferencing. - """ + """Descriptor to allow lazy dereferencing.""" if instance is None: # Document class being used rather than a document object return self @@ -989,8 +983,7 @@ def to_mongo(self, document): return id_ def to_python(self, value): - """Convert a MongoDB-compatible type to a Python type. - """ + """Convert a MongoDB-compatible type to a Python type.""" if (not self.dbref and not isinstance(value, (DBRef, Document, EmbeddedDocument))): collection = self.document_type._get_collection_name() @@ -1006,7 +999,7 @@ def prepare_query_value(self, op, value): def validate(self, value): if not isinstance(value, (self.document_type, DBRef)): - self.error("A ReferenceField only accepts DBRef or documents") + self.error('A ReferenceField only accepts DBRef or documents') if isinstance(value, Document) and value.id is None: self.error('You can only reference documents once they have been ' @@ -1030,14 +1023,19 @@ class CachedReferenceField(BaseField): .. versionadded:: 0.9 """ - def __init__(self, document_type, fields=[], auto_sync=True, **kwargs): + def __init__(self, document_type, fields=None, auto_sync=True, **kwargs): """Initialises the Cached Reference Field. :param fields: A list of fields to be cached in document :param auto_sync: if True documents are auto updated. """ - if not isinstance(document_type, basestring) and \ - not issubclass(document_type, (Document, basestring)): + if fields is None: + fields = [] + + if ( + not isinstance(document_type, six.string_types) and + not issubclass(document_type, Document) + ): self.error('Argument to CachedReferenceField constructor must be a' ' document class or a string') @@ -1053,18 +1051,20 @@ def start_listener(self): sender=self.document_type) def on_document_pre_save(self, sender, document, created, **kwargs): - if not created: - update_kwargs = dict( - ('set__%s__%s' % (self.name, k), v) - for k, v in document._delta()[0].items() - if k in self.fields) + if created: + return None - if update_kwargs: - filter_kwargs = {} - filter_kwargs[self.name] = document + update_kwargs = { + 'set__%s__%s' % (self.name, key): val + for key, val in document._delta()[0].items() + if key in self.fields + } + if update_kwargs: + filter_kwargs = {} + filter_kwargs[self.name] = document - self.owner_document.objects( - **filter_kwargs).update(**update_kwargs) + self.owner_document.objects( + **filter_kwargs).update(**update_kwargs) def to_python(self, value): if isinstance(value, dict): @@ -1077,7 +1077,7 @@ def to_python(self, value): @property def document_type(self): - if isinstance(self.document_type_obj, basestring): + if isinstance(self.document_type_obj, six.string_types): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: self.document_type_obj = self.owner_document else: @@ -1117,7 +1117,7 @@ def to_mongo(self, document, use_db_field=True, fields=None): # TODO: should raise here or will fail next statement value = SON(( - ("_id", id_field.to_mongo(id_)), + ('_id', id_field.to_mongo(id_)), )) if fields: @@ -1143,7 +1143,7 @@ def prepare_query_value(self, op, value): def validate(self, value): if not isinstance(value, self.document_type): - self.error("A CachedReferenceField only accepts documents") + self.error('A CachedReferenceField only accepts documents') if isinstance(value, Document) and value.id is None: self.error('You can only reference documents once they have been ' @@ -1191,13 +1191,13 @@ def __init__(self, *args, **kwargs): # Keep the choices as a list of allowed Document class names if choices: for choice in choices: - if isinstance(choice, basestring): + if isinstance(choice, six.string_types): self.choices.append(choice) elif isinstance(choice, type) and issubclass(choice, Document): self.choices.append(choice._class_name) else: self.error('Invalid choices provided: must be a list of' - 'Document subclasses and/or basestrings') + 'Document subclasses and/or six.string_typess') def _validate_choices(self, value): if isinstance(value, dict): @@ -1280,8 +1280,7 @@ def prepare_query_value(self, op, value): class BinaryField(BaseField): - """A binary data field. - """ + """A binary data field.""" def __init__(self, max_bytes=None, **kwargs): self.max_bytes = max_bytes @@ -1289,18 +1288,18 @@ def __init__(self, max_bytes=None, **kwargs): def __set__(self, instance, value): """Handle bytearrays in python 3.1""" - if PY3 and isinstance(value, bytearray): - value = bin_type(value) + if six.PY3 and isinstance(value, bytearray): + value = six.binary_type(value) return super(BinaryField, self).__set__(instance, value) def to_mongo(self, value): return Binary(value) def validate(self, value): - if not isinstance(value, (bin_type, txt_type, Binary)): - self.error("BinaryField only accepts instances of " - "(%s, %s, Binary)" % ( - bin_type.__name__, txt_type.__name__)) + if not isinstance(value, (six.binary_type, six.text_type, Binary)): + self.error('BinaryField only accepts instances of ' + '(%s, %s, Binary)' % ( + six.binary_type.__name__, six.text_type.__name__)) if self.max_bytes is not None and len(value) > self.max_bytes: self.error('Binary value is too long') @@ -1384,11 +1383,13 @@ def fs(self): get_db(self.db_alias), self.collection_name) return self._fs - def get(self, id=None): - if id: - self.grid_id = id + def get(self, grid_id=None): + if grid_id: + self.grid_id = grid_id + if self.grid_id is None: return None + try: if self.gridout is None: self.gridout = self.fs.get(self.grid_id) @@ -1432,7 +1433,7 @@ def read(self, size=-1): try: return gridout.read(size) except Exception: - return "" + return '' def delete(self): # Delete file from GridFS, FileField still remains @@ -1464,9 +1465,8 @@ class FileField(BaseField): """ proxy_class = GridFSProxy - def __init__(self, - db_alias=DEFAULT_CONNECTION_NAME, - collection_name="fs", **kwargs): + def __init__(self, db_alias=DEFAULT_CONNECTION_NAME, collection_name='fs', + **kwargs): super(FileField, self).__init__(**kwargs) self.collection_name = collection_name self.db_alias = db_alias @@ -1488,8 +1488,10 @@ def __get__(self, instance, owner): def __set__(self, instance, value): key = self.name - if ((hasattr(value, 'read') and not - isinstance(value, GridFSProxy)) or isinstance(value, str_types)): + if ( + (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or + isinstance(value, (six.binary_type, six.string_types)) + ): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) # If a file already exists, delete it @@ -1558,7 +1560,7 @@ def put(self, file_obj, **kwargs): try: img = Image.open(file_obj) img_format = img.format - except Exception, e: + except Exception as e: raise ValidationError('Invalid image: %s' % e) # Progressive JPEG @@ -1667,10 +1669,10 @@ def thumbnail(self): return self.fs.get(out.thumbnail_id) def write(self, *args, **kwargs): - raise RuntimeError("Please use \"put\" method instead") + raise RuntimeError('Please use "put" method instead') def writelines(self, *args, **kwargs): - raise RuntimeError("Please use \"put\" method instead") + raise RuntimeError('Please use "put" method instead') class ImproperlyConfigured(Exception): @@ -1695,14 +1697,17 @@ class ImageField(FileField): def __init__(self, size=None, thumbnail_size=None, collection_name='images', **kwargs): if not Image: - raise ImproperlyConfigured("PIL library was not found") + raise ImproperlyConfigured('PIL library was not found') params_size = ('width', 'height', 'force') - extra_args = dict(size=size, thumbnail_size=thumbnail_size) + extra_args = { + 'size': size, + 'thumbnail_size': thumbnail_size + } for att_name, att in extra_args.items(): value = None if isinstance(att, (tuple, list)): - if PY3: + if six.PY3: value = dict(itertools.zip_longest(params_size, att, fillvalue=None)) else: @@ -1763,10 +1768,10 @@ def generate(self): Generate and Increment the counter """ sequence_name = self.get_sequence_name() - sequence_id = "%s.%s" % (sequence_name, self.name) + sequence_id = '%s.%s' % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] - counter = collection.find_and_modify(query={"_id": sequence_id}, - update={"$inc": {"next": 1}}, + counter = collection.find_and_modify(query={'_id': sequence_id}, + update={'$inc': {'next': 1}}, new=True, upsert=True) return self.value_decorator(counter['next']) @@ -1789,9 +1794,9 @@ def get_next_value(self): as it is only fixed on set. """ sequence_name = self.get_sequence_name() - sequence_id = "%s.%s" % (sequence_name, self.name) + sequence_id = '%s.%s' % (sequence_name, self.name) collection = get_db(alias=self.db_alias)[self.collection_name] - data = collection.find_one({"_id": sequence_id}) + data = collection.find_one({'_id': sequence_id}) if data: return self.value_decorator(data['next'] + 1) @@ -1861,8 +1866,8 @@ def to_python(self, value): if not self._binary: original_value = value try: - if not isinstance(value, basestring): - value = unicode(value) + if not isinstance(value, six.string_types): + value = six.text_type(value) return uuid.UUID(value) except Exception: return original_value @@ -1870,8 +1875,8 @@ def to_python(self, value): def to_mongo(self, value): if not self._binary: - return unicode(value) - elif isinstance(value, basestring): + return six.text_type(value) + elif isinstance(value, six.string_types): return uuid.UUID(value) return value @@ -1882,11 +1887,11 @@ def prepare_query_value(self, op, value): def validate(self, value): if not isinstance(value, uuid.UUID): - if not isinstance(value, basestring): + if not isinstance(value, six.string_types): value = str(value) try: uuid.UUID(value) - except Exception, exc: + except Exception as exc: self.error('Could not convert to UUID: %s' % exc) @@ -1904,19 +1909,18 @@ class GeoPointField(BaseField): _geo_index = pymongo.GEO2D def validate(self, value): - """Make sure that a geo-value is of type (x, y) - """ + """Make sure that a geo-value is of type (x, y)""" if not isinstance(value, (list, tuple)): self.error('GeoPointField can only accept tuples or lists ' 'of (x, y)') if not len(value) == 2: - self.error("Value (%s) must be a two-dimensional point" % + self.error('Value (%s) must be a two-dimensional point' % repr(value)) elif (not isinstance(value[0], (float, int)) or not isinstance(value[1], (float, int))): self.error( - "Both values (%s) in point must be float or int" % repr(value)) + 'Both values (%s) in point must be float or int' % repr(value)) class PointField(GeoJsonBaseField): @@ -1926,8 +1930,8 @@ class PointField(GeoJsonBaseField): .. code-block:: js - { "type" : "Point" , - "coordinates" : [x, y]} + {'type' : 'Point' , + 'coordinates' : [x, y]} You can either pass a dict with the full information or a list to set the value. @@ -1936,7 +1940,7 @@ class PointField(GeoJsonBaseField): .. versionadded:: 0.8 """ - _type = "Point" + _type = 'Point' class LineStringField(GeoJsonBaseField): @@ -1946,8 +1950,8 @@ class LineStringField(GeoJsonBaseField): .. code-block:: js - { "type" : "LineString" , - "coordinates" : [[x1, y1], [x1, y1] ... [xn, yn]]} + {'type' : 'LineString' , + 'coordinates' : [[x1, y1], [x1, y1] ... [xn, yn]]} You can either pass a dict with the full information or a list of points. @@ -1955,7 +1959,7 @@ class LineStringField(GeoJsonBaseField): .. versionadded:: 0.8 """ - _type = "LineString" + _type = 'LineString' class PolygonField(GeoJsonBaseField): @@ -1965,9 +1969,9 @@ class PolygonField(GeoJsonBaseField): .. code-block:: js - { "type" : "Polygon" , - "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], - [[x1, y1], [x1, y1] ... [xn, yn]]} + {'type' : 'Polygon' , + 'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]], + [[x1, y1], [x1, y1] ... [xn, yn]]} You can either pass a dict with the full information or a list of LineStrings. The first LineString being the outside and the rest being @@ -1977,7 +1981,7 @@ class PolygonField(GeoJsonBaseField): .. versionadded:: 0.8 """ - _type = "Polygon" + _type = 'Polygon' class MultiPointField(GeoJsonBaseField): @@ -1987,8 +1991,8 @@ class MultiPointField(GeoJsonBaseField): .. code-block:: js - { "type" : "MultiPoint" , - "coordinates" : [[x1, y1], [x2, y2]]} + {'type' : 'MultiPoint' , + 'coordinates' : [[x1, y1], [x2, y2]]} You can either pass a dict with the full information or a list to set the value. @@ -1997,7 +2001,7 @@ class MultiPointField(GeoJsonBaseField): .. versionadded:: 0.9 """ - _type = "MultiPoint" + _type = 'MultiPoint' class MultiLineStringField(GeoJsonBaseField): @@ -2007,9 +2011,9 @@ class MultiLineStringField(GeoJsonBaseField): .. code-block:: js - { "type" : "MultiLineString" , - "coordinates" : [[[x1, y1], [x1, y1] ... [xn, yn]], - [[x1, y1], [x1, y1] ... [xn, yn]]]} + {'type' : 'MultiLineString' , + 'coordinates' : [[[x1, y1], [x1, y1] ... [xn, yn]], + [[x1, y1], [x1, y1] ... [xn, yn]]]} You can either pass a dict with the full information or a list of points. @@ -2017,7 +2021,7 @@ class MultiLineStringField(GeoJsonBaseField): .. versionadded:: 0.9 """ - _type = "MultiLineString" + _type = 'MultiLineString' class MultiPolygonField(GeoJsonBaseField): @@ -2027,14 +2031,14 @@ class MultiPolygonField(GeoJsonBaseField): .. code-block:: js - { "type" : "MultiPolygon" , - "coordinates" : [[ - [[x1, y1], [x1, y1] ... [xn, yn]], - [[x1, y1], [x1, y1] ... [xn, yn]] - ], [ - [[x1, y1], [x1, y1] ... [xn, yn]], - [[x1, y1], [x1, y1] ... [xn, yn]] - ] + {'type' : 'MultiPolygon' , + 'coordinates' : [[ + [[x1, y1], [x1, y1] ... [xn, yn]], + [[x1, y1], [x1, y1] ... [xn, yn]] + ], [ + [[x1, y1], [x1, y1] ... [xn, yn]], + [[x1, y1], [x1, y1] ... [xn, yn]] + ] } You can either pass a dict with the full information or a list @@ -2044,4 +2048,4 @@ class MultiPolygonField(GeoJsonBaseField): .. versionadded:: 0.9 """ - _type = "MultiPolygon" + _type = 'MultiPolygon' diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index f4b6f20fc..e51e1bc98 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -1,50 +1,25 @@ -"""Helper functions and types to aid with Python 2.6 - 3 support.""" - -import sys -import warnings - +""" +Helper functions, constants, and types to aid with Python v2.7 - v3.x and +PyMongo v2.7 - v3.x support. +""" import pymongo +import six -# Show a deprecation warning for people using Python v2.6 -# TODO remove in mongoengine v0.11.0 -if sys.version_info[0] == 2 and sys.version_info[1] == 6: - warnings.warn( - 'Python v2.6 support is deprecated and is going to be dropped ' - 'entirely in the upcoming v0.11.0 release. Update your Python ' - 'version if you want to have access to the latest features and ' - 'bug fixes in MongoEngine.', - DeprecationWarning - ) - if pymongo.version_tuple[0] < 3: IS_PYMONGO_3 = False else: IS_PYMONGO_3 = True -PY3 = sys.version_info[0] == 3 - -if PY3: - import codecs - from io import BytesIO as StringIO - # return s converted to binary. b('test') should be equivalent to b'test' - def b(s): - return codecs.latin_1_encode(s)[0] +# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. +StringIO = six.BytesIO - bin_type = bytes - txt_type = str -else: +# Additionally for Py2, try to use the faster cStringIO, if available +if not six.PY3: try: - from cStringIO import StringIO + import cStringIO except ImportError: - from StringIO import StringIO - - # Conversion to binary only necessary in Python 3 - def b(s): - return s - - bin_type = str - txt_type = unicode - -str_types = (bin_type, txt_type) + pass + else: + StringIO = cStringIO.StringIO diff --git a/mongoengine/queryset/__init__.py b/mongoengine/queryset/__init__.py index c8fa09dba..5219c39e1 100644 --- a/mongoengine/queryset/__init__.py +++ b/mongoengine/queryset/__init__.py @@ -1,11 +1,17 @@ -from mongoengine.errors import (DoesNotExist, InvalidQueryError, - MultipleObjectsReturned, NotUniqueError, - OperationError) +from mongoengine.errors import * from mongoengine.queryset.field_list import * from mongoengine.queryset.manager import * from mongoengine.queryset.queryset import * from mongoengine.queryset.transform import * from mongoengine.queryset.visitor import * -__all__ = (field_list.__all__ + manager.__all__ + queryset.__all__ + - transform.__all__ + visitor.__all__) +# Expose just the public subset of all imported objects and constants. +__all__ = ( + 'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager', + 'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL', + + # Errors that might be related to a queryset, mostly here for backward + # compatibility + 'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned', + 'NotUniqueError', 'OperationError', +) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ec48b4f36..3ee978b8b 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -12,9 +12,10 @@ import pymongo import pymongo.errors from pymongo.common import validate_read_preference +import six from mongoengine import signals -from mongoengine.base.common import get_document +from mongoengine.base import get_document from mongoengine.common import _import_class from mongoengine.connection import get_db from mongoengine.context_managers import switch_db @@ -73,10 +74,10 @@ def __init__(self, document, collection): # subclasses of the class being used if document._meta.get('allow_inheritance') is True: if len(self._document._subclasses) == 1: - self._initial_query = {"_cls": self._document._subclasses[0]} + self._initial_query = {'_cls': self._document._subclasses[0]} else: self._initial_query = { - "_cls": {"$in": self._document._subclasses}} + '_cls': {'$in': self._document._subclasses}} self._loaded_fields = QueryFieldList(always_include=['_cls']) self._cursor_obj = None self._limit = None @@ -105,8 +106,8 @@ def __call__(self, q_obj=None, class_check=True, read_preference=None, if q_obj: # make sure proper query object is passed if not isinstance(q_obj, QNode): - msg = ("Not a query object: %s. " - "Did you intend to use key=value?" % q_obj) + msg = ('Not a query object: %s. ' + 'Did you intend to use key=value?' % q_obj) raise InvalidQueryError(msg) query &= q_obj @@ -133,10 +134,10 @@ def __getstate__(self): obj_dict = self.__dict__.copy() # don't picke collection, instead pickle collection params - obj_dict.pop("_collection_obj") + obj_dict.pop('_collection_obj') # don't pickle cursor - obj_dict["_cursor_obj"] = None + obj_dict['_cursor_obj'] = None return obj_dict @@ -147,7 +148,7 @@ def __setstate__(self, obj_dict): See https://github.com/MongoEngine/mongoengine/issues/442 """ - obj_dict["_collection_obj"] = obj_dict["_document"]._get_collection() + obj_dict['_collection_obj'] = obj_dict['_document']._get_collection() # update attributes self.__dict__.update(obj_dict) @@ -166,7 +167,7 @@ def __getitem__(self, key): queryset._skip, queryset._limit = key.start, key.stop if key.start and key.stop: queryset._limit = key.stop - key.start - except IndexError, err: + except IndexError as err: # PyMongo raises an error if key.start == key.stop, catch it, # bin it, kill it. start = key.start or 0 @@ -199,19 +200,16 @@ def __iter__(self): raise NotImplementedError def _has_data(self): - """ Retrieves whether cursor has any data. """ - + """Return True if cursor has any data.""" queryset = self.order_by() return False if queryset.first() is None else True def __nonzero__(self): - """ Avoid to open all records in an if stmt in Py2. """ - + """Avoid to open all records in an if stmt in Py2.""" return self._has_data() def __bool__(self): - """ Avoid to open all records in an if stmt in Py3. """ - + """Avoid to open all records in an if stmt in Py3.""" return self._has_data() # Core functions @@ -239,7 +237,7 @@ def search_text(self, text, language=None): queryset = self.clone() if queryset._search_text: raise OperationError( - "It is not possible to use search_text two times.") + 'It is not possible to use search_text two times.') query_kwargs = SON({'$search': text}) if language: @@ -268,7 +266,7 @@ def get(self, *q_objs, **query): try: result = queryset.next() except StopIteration: - msg = ("%s matching query does not exist." + msg = ('%s matching query does not exist.' % queryset._document._class_name) raise queryset._document.DoesNotExist(msg) try: @@ -290,8 +288,7 @@ def create(self, **kwargs): return self._document(**kwargs).save() def first(self): - """Retrieve the first object matching the query. - """ + """Retrieve the first object matching the query.""" queryset = self.clone() try: result = queryset[0] @@ -340,7 +337,7 @@ def insert(self, doc_or_docs, load_bulk=True, % str(self._document)) raise OperationError(msg) if doc.pk and not doc._created: - msg = "Some documents have ObjectIds use doc.update() instead" + msg = 'Some documents have ObjectIds use doc.update() instead' raise OperationError(msg) signal_kwargs = signal_kwargs or {} @@ -350,17 +347,17 @@ def insert(self, doc_or_docs, load_bulk=True, raw = [doc.to_mongo() for doc in docs] try: ids = self._collection.insert(raw, **write_concern) - except pymongo.errors.DuplicateKeyError, err: + except pymongo.errors.DuplicateKeyError as err: message = 'Could not save document (%s)' - raise NotUniqueError(message % unicode(err)) - except pymongo.errors.OperationFailure, err: + raise NotUniqueError(message % six.text_type(err)) + except pymongo.errors.OperationFailure as err: message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): + if re.match('^E1100[01] duplicate key', six.text_type(err)): # E11000 - duplicate key error index # E11001 - duplicate key on update message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) + raise NotUniqueError(message % six.text_type(err)) + raise OperationError(message % six.text_type(err)) if not load_bulk: signals.post_bulk_insert.send( @@ -386,7 +383,8 @@ def count(self, with_limit_and_skip=False): return 0 return self._cursor.count(with_limit_and_skip=with_limit_and_skip) - def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): + def delete(self, write_concern=None, _from_doc_delete=False, + cascade_refs=None): """Delete the documents matched by the query. :param write_concern: Extra keyword arguments are passed down which @@ -409,8 +407,9 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): # Handle deletes where skips or limits have been applied or # there is an untriggered delete signal has_delete_signal = signals.signals_available and ( - signals.pre_delete.has_receivers_for(self._document) or - signals.post_delete.has_receivers_for(self._document)) + signals.pre_delete.has_receivers_for(doc) or + signals.post_delete.has_receivers_for(doc) + ) call_document_delete = (queryset._skip or queryset._limit or has_delete_signal) and not _from_doc_delete @@ -423,37 +422,44 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): return cnt delete_rules = doc._meta.get('delete_rules') or {} + delete_rules = list(delete_rules.items()) + # Check for DENY rules before actually deleting/nullifying any other # references - for rule_entry in delete_rules: + for rule_entry, rule in delete_rules: document_cls, field_name = rule_entry if document_cls._meta.get('abstract'): continue - rule = doc._meta['delete_rules'][rule_entry] - if rule == DENY and document_cls.objects( - **{field_name + '__in': self}).count() > 0: - msg = ("Could not delete document (%s.%s refers to it)" - % (document_cls.__name__, field_name)) - raise OperationError(msg) - for rule_entry in delete_rules: + if rule == DENY: + refs = document_cls.objects(**{field_name + '__in': self}) + if refs.limit(1).count() > 0: + raise OperationError( + 'Could not delete document (%s.%s refers to it)' + % (document_cls.__name__, field_name) + ) + + # Check all the other rules + for rule_entry, rule in delete_rules: document_cls, field_name = rule_entry if document_cls._meta.get('abstract'): continue - rule = doc._meta['delete_rules'][rule_entry] + if rule == CASCADE: cascade_refs = set() if cascade_refs is None else cascade_refs # Handle recursive reference if doc._collection == document_cls._collection: for ref in queryset: cascade_refs.add(ref.id) - ref_q = document_cls.objects(**{field_name + '__in': self, 'pk__nin': cascade_refs}) - ref_q_count = ref_q.count() - if ref_q_count > 0: - ref_q.delete(write_concern=write_concern, cascade_refs=cascade_refs) + refs = document_cls.objects(**{field_name + '__in': self, + 'pk__nin': cascade_refs}) + if refs.count() > 0: + refs.delete(write_concern=write_concern, + cascade_refs=cascade_refs) elif rule == NULLIFY: document_cls.objects(**{field_name + '__in': self}).update( - write_concern=write_concern, **{'unset__%s' % field_name: 1}) + write_concern=write_concern, + **{'unset__%s' % field_name: 1}) elif rule == PULL: document_cls.objects(**{field_name + '__in': self}).update( write_concern=write_concern, @@ -461,7 +467,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, cascade_refs=None): result = queryset._collection.remove(queryset._query, **write_concern) if result: - return result.get("n") + return result.get('n') def update(self, upsert=False, multi=True, write_concern=None, full_result=False, **update): @@ -482,7 +488,7 @@ def update(self, upsert=False, multi=True, write_concern=None, .. versionadded:: 0.2 """ if not update and not upsert: - raise OperationError("No update parameters, would remove data") + raise OperationError('No update parameters, would remove data') if write_concern is None: write_concern = {} @@ -495,9 +501,9 @@ def update(self, upsert=False, multi=True, write_concern=None, # then ensure we add _cls to the update operation if upsert and '_cls' in query: if '$set' in update: - update["$set"]["_cls"] = queryset._document._class_name + update['$set']['_cls'] = queryset._document._class_name else: - update["$set"] = {"_cls": queryset._document._class_name} + update['$set'] = {'_cls': queryset._document._class_name} try: result = queryset._collection.update(query, update, multi=multi, upsert=upsert, **write_concern) @@ -505,13 +511,13 @@ def update(self, upsert=False, multi=True, write_concern=None, return result elif result: return result['n'] - except pymongo.errors.DuplicateKeyError, err: - raise NotUniqueError(u'Update failed (%s)' % unicode(err)) - except pymongo.errors.OperationFailure, err: - if unicode(err) == u'multi not coded yet': + except pymongo.errors.DuplicateKeyError as err: + raise NotUniqueError(u'Update failed (%s)' % six.text_type(err)) + except pymongo.errors.OperationFailure as err: + if six.text_type(err) == u'multi not coded yet': message = u'update() method requires MongoDB 1.1.3+' raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) + raise OperationError(u'Update failed (%s)' % six.text_type(err)) def upsert_one(self, write_concern=None, **update): """Overwrite or add the first document matched by the query. @@ -582,11 +588,11 @@ def modify(self, upsert=False, full_response=False, remove=False, new=False, **u """ if remove and new: - raise OperationError("Conflicting parameters: remove and new") + raise OperationError('Conflicting parameters: remove and new') if not update and not upsert and not remove: raise OperationError( - "No update parameters, must either update or remove") + 'No update parameters, must either update or remove') queryset = self.clone() query = queryset._query @@ -597,7 +603,7 @@ def modify(self, upsert=False, full_response=False, remove=False, new=False, **u try: if IS_PYMONGO_3: if full_response: - msg = "With PyMongo 3+, it is not possible anymore to get the full response." + msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' warnings.warn(msg, DeprecationWarning) if remove: result = queryset._collection.find_one_and_delete( @@ -615,14 +621,14 @@ def modify(self, upsert=False, full_response=False, remove=False, new=False, **u result = queryset._collection.find_and_modify( query, update, upsert=upsert, sort=sort, remove=remove, new=new, full_response=full_response, **self._cursor_args) - except pymongo.errors.DuplicateKeyError, err: - raise NotUniqueError(u"Update failed (%s)" % err) - except pymongo.errors.OperationFailure, err: - raise OperationError(u"Update failed (%s)" % err) + except pymongo.errors.DuplicateKeyError as err: + raise NotUniqueError(u'Update failed (%s)' % err) + except pymongo.errors.OperationFailure as err: + raise OperationError(u'Update failed (%s)' % err) if full_response: - if result["value"] is not None: - result["value"] = self._document._from_son(result["value"], only_fields=self.only_fields) + if result['value'] is not None: + result['value'] = self._document._from_son(result['value'], only_fields=self.only_fields) else: if result is not None: result = self._document._from_son(result, only_fields=self.only_fields) @@ -640,7 +646,7 @@ def with_id(self, object_id): """ queryset = self.clone() if not queryset._query_obj.empty: - msg = "Cannot use a filter whilst using `with_id`" + msg = 'Cannot use a filter whilst using `with_id`' raise InvalidQueryError(msg) return queryset.filter(pk=object_id).first() @@ -684,7 +690,7 @@ def no_sub_classes(self): Only return instances of this document and not any inherited documents """ if self._document._meta.get('allow_inheritance') is True: - self._initial_query = {"_cls": self._document._class_name} + self._initial_query = {'_cls': self._document._class_name} return self @@ -810,49 +816,56 @@ def distinct(self, field): .. versionchanged:: 0.6 - Improved db_field refrence handling """ queryset = self.clone() + try: field = self._fields_to_dbfields([field]).pop() - finally: - distinct = self._dereference(queryset._cursor.distinct(field), 1, - name=field, instance=self._document) - - doc_field = self._document._fields.get(field.split('.', 1)[0]) - instance = False - # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) - EmbeddedDocumentField = _import_class('EmbeddedDocumentField') - ListField = _import_class('ListField') - GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') - if isinstance(doc_field, ListField): - doc_field = getattr(doc_field, "field", doc_field) - if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): - instance = getattr(doc_field, "document_type", False) - # handle distinct on subdocuments - if '.' in field: - for field_part in field.split('.')[1:]: - # if looping on embedded document, get the document type instance - if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): - doc_field = instance - # now get the subdocument - doc_field = getattr(doc_field, field_part, doc_field) - # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) - if isinstance(doc_field, ListField): - doc_field = getattr(doc_field, "field", doc_field) - if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): - instance = getattr(doc_field, "document_type", False) - if instance and isinstance(doc_field, (EmbeddedDocumentField, - GenericEmbeddedDocumentField)): - distinct = [instance(**doc) for doc in distinct] - return distinct + except LookUpError: + pass + + distinct = self._dereference(queryset._cursor.distinct(field), 1, + name=field, instance=self._document) + + doc_field = self._document._fields.get(field.split('.', 1)[0]) + instance = None + + # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) + EmbeddedDocumentField = _import_class('EmbeddedDocumentField') + ListField = _import_class('ListField') + GenericEmbeddedDocumentField = _import_class('GenericEmbeddedDocumentField') + if isinstance(doc_field, ListField): + doc_field = getattr(doc_field, 'field', doc_field) + if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + instance = getattr(doc_field, 'document_type', None) + + # handle distinct on subdocuments + if '.' in field: + for field_part in field.split('.')[1:]: + # if looping on embedded document, get the document type instance + if instance and isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + doc_field = instance + # now get the subdocument + doc_field = getattr(doc_field, field_part, doc_field) + # We may need to cast to the correct type eg. ListField(EmbeddedDocumentField) + if isinstance(doc_field, ListField): + doc_field = getattr(doc_field, 'field', doc_field) + if isinstance(doc_field, (EmbeddedDocumentField, GenericEmbeddedDocumentField)): + instance = getattr(doc_field, 'document_type', None) + + if instance and isinstance(doc_field, (EmbeddedDocumentField, + GenericEmbeddedDocumentField)): + distinct = [instance(**doc) for doc in distinct] + + return distinct def only(self, *fields): """Load only a subset of this document's fields. :: - post = BlogPost.objects(...).only("title", "author.name") + post = BlogPost.objects(...).only('title', 'author.name') .. note :: `only()` is chainable and will perform a union :: So with the following it will fetch both: `title` and `author.name`:: - post = BlogPost.objects.only("title").only("author.name") + post = BlogPost.objects.only('title').only('author.name') :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any field filters. @@ -862,19 +875,19 @@ def only(self, *fields): .. versionadded:: 0.3 .. versionchanged:: 0.5 - Added subfield support """ - fields = dict([(f, QueryFieldList.ONLY) for f in fields]) + fields = {f: QueryFieldList.ONLY for f in fields} self.only_fields = fields.keys() return self.fields(True, **fields) def exclude(self, *fields): """Opposite to .only(), exclude some document's fields. :: - post = BlogPost.objects(...).exclude("comments") + post = BlogPost.objects(...).exclude('comments') .. note :: `exclude()` is chainable and will perform a union :: So with the following it will exclude both: `title` and `author.name`:: - post = BlogPost.objects.exclude("title").exclude("author.name") + post = BlogPost.objects.exclude('title').exclude('author.name') :func:`~mongoengine.queryset.QuerySet.all_fields` will reset any field filters. @@ -883,7 +896,7 @@ def exclude(self, *fields): .. versionadded:: 0.5 """ - fields = dict([(f, QueryFieldList.EXCLUDE) for f in fields]) + fields = {f: QueryFieldList.EXCLUDE for f in fields} return self.fields(**fields) def fields(self, _only_called=False, **kwargs): @@ -904,7 +917,7 @@ def fields(self, _only_called=False, **kwargs): """ # Check for an operator and transform to mongo-style if there is - operators = ["slice"] + operators = ['slice'] cleaned_fields = [] for key, value in kwargs.items(): parts = key.split('__') @@ -928,7 +941,7 @@ def all_fields(self): """Include all fields. Reset all previously calls of .only() or .exclude(). :: - post = BlogPost.objects.exclude("comments").all_fields() + post = BlogPost.objects.exclude('comments').all_fields() .. versionadded:: 0.5 """ @@ -955,7 +968,7 @@ def comment(self, text): See https://docs.mongodb.com/manual/reference/method/cursor.comment/#cursor.comment for details. """ - return self._chainable_method("comment", text) + return self._chainable_method('comment', text) def explain(self, format=False): """Return an explain plan record for the @@ -964,8 +977,15 @@ def explain(self, format=False): :param format: format the plan before returning it """ plan = self._cursor.explain() + + # TODO remove this option completely - it's useless. If somebody + # wants to pretty-print the output, they easily can. if format: + msg = ('"format" param of BaseQuerySet.explain has been ' + 'deprecated and will be removed in future versions.') + warnings.warn(msg, DeprecationWarning) plan = pprint.pformat(plan) + return plan # DEPRECATED. Has no more impact on PyMongo 3+ @@ -978,7 +998,7 @@ def snapshot(self, enabled): .. deprecated:: Ignored with PyMongo 3+ """ if IS_PYMONGO_3: - msg = "snapshot is deprecated as it has no impact when using PyMongo 3+." + msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._snapshot = enabled @@ -1004,7 +1024,7 @@ def slave_okay(self, enabled): .. deprecated:: Ignored with PyMongo 3+ """ if IS_PYMONGO_3: - msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+." + msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) queryset = self.clone() queryset._slave_okay = enabled @@ -1066,7 +1086,7 @@ def max_time_ms(self, ms): :param ms: the number of milliseconds before killing the query on the server """ - return self._chainable_method("max_time_ms", ms) + return self._chainable_method('max_time_ms', ms) # JSON Helpers @@ -1149,19 +1169,19 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, MapReduceDocument = _import_class('MapReduceDocument') - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") + if not hasattr(self._collection, 'map_reduce'): + raise NotImplementedError('Requires MongoDB >= 1.7.1') map_f_scope = {} if isinstance(map_f, Code): map_f_scope = map_f.scope - map_f = unicode(map_f) + map_f = six.text_type(map_f) map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) reduce_f_scope = {} if isinstance(reduce_f, Code): reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) + reduce_f = six.text_type(reduce_f) reduce_f_code = queryset._sub_js_fields(reduce_f) reduce_f = Code(reduce_f_code, reduce_f_scope) @@ -1171,7 +1191,7 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, finalize_f_scope = {} if isinstance(finalize_f, Code): finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) + finalize_f = six.text_type(finalize_f) finalize_f_code = queryset._sub_js_fields(finalize_f) finalize_f = Code(finalize_f_code, finalize_f_scope) mr_args['finalize'] = finalize_f @@ -1187,7 +1207,7 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, else: map_reduce_function = 'map_reduce' - if isinstance(output, basestring): + if isinstance(output, six.string_types): mr_args['out'] = output elif isinstance(output, dict): @@ -1200,7 +1220,7 @@ def map_reduce(self, map_f, reduce_f, output, finalize_f=None, limit=None, break else: - raise OperationError("actionData not specified for output") + raise OperationError('actionData not specified for output') db_alias = output.get('db_alias') remaing_args = ['db', 'sharded', 'nonAtomic'] @@ -1430,7 +1450,7 @@ def _cursor_args(self): # snapshot is not handled at all by PyMongo 3+ # TODO: evaluate similar possibilities using modifiers if self._snapshot: - msg = "The snapshot option is not anymore available with PyMongo 3+" + msg = 'The snapshot option is not anymore available with PyMongo 3+' warnings.warn(msg, DeprecationWarning) cursor_args = { 'no_cursor_timeout': not self._timeout @@ -1442,7 +1462,7 @@ def _cursor_args(self): if fields_name not in cursor_args: cursor_args[fields_name] = {} - cursor_args[fields_name]['_text_score'] = {'$meta': "textScore"} + cursor_args[fields_name]['_text_score'] = {'$meta': 'textScore'} return cursor_args @@ -1497,8 +1517,8 @@ def _query(self): if self._mongo_query is None: self._mongo_query = self._query_obj.to_query(self._document) if self._class_check and self._initial_query: - if "_cls" in self._mongo_query: - self._mongo_query = {"$and": [self._initial_query, self._mongo_query]} + if '_cls' in self._mongo_query: + self._mongo_query = {'$and': [self._initial_query, self._mongo_query]} else: self._mongo_query.update(self._initial_query) return self._mongo_query @@ -1510,8 +1530,7 @@ def _dereference(self): return self.__dereference def no_dereference(self): - """Turn off any dereferencing for the results of this queryset. - """ + """Turn off any dereferencing for the results of this queryset.""" queryset = self.clone() queryset._auto_dereference = False return queryset @@ -1540,7 +1559,7 @@ def _item_frequencies_map_reduce(self, field, normalize=False): emit(null, 1); } } - """ % dict(field=field) + """ % {'field': field} reduce_func = """ function(key, values) { var total = 0; @@ -1562,8 +1581,8 @@ def _item_frequencies_map_reduce(self, field, normalize=False): if normalize: count = sum(frequencies.values()) - frequencies = dict([(k, float(v) / count) - for k, v in frequencies.items()]) + frequencies = {k: float(v) / count + for k, v in frequencies.items()} return frequencies @@ -1615,10 +1634,10 @@ def _item_frequencies_exec_js(self, field, normalize=False): } """ total, data, types = self.exec_js(freq_func, field) - values = dict([(types.get(k), int(v)) for k, v in data.iteritems()]) + values = {types.get(k): int(v) for k, v in data.iteritems()} if normalize: - values = dict([(k, float(v) / total) for k, v in values.items()]) + values = {k: float(v) / total for k, v in values.items()} frequencies = {} for k, v in values.iteritems(): @@ -1640,14 +1659,14 @@ def _fields_to_dbfields(self, fields): for x in document._subclasses][1:] for field in fields: try: - field = ".".join(f.db_field for f in + field = '.'.join(f.db_field for f in document._lookup_field(field.split('.'))) ret.append(field) - except LookUpError, err: + except LookUpError as err: found = False for subdoc in subclasses: try: - subfield = ".".join(f.db_field for f in + subfield = '.'.join(f.db_field for f in subdoc._lookup_field(field.split('.'))) ret.append(subfield) found = True @@ -1660,15 +1679,14 @@ def _fields_to_dbfields(self, fields): return ret def _get_order_by(self, keys): - """Creates a list of order by fields - """ + """Creates a list of order by fields""" key_list = [] for key in keys: if not key: continue if key == '$text_score': - key_list.append(('_text_score', {'$meta': "textScore"})) + key_list.append(('_text_score', {'$meta': 'textScore'})) continue direction = pymongo.ASCENDING @@ -1740,7 +1758,7 @@ def clean(data, path=None): # If we need to coerce types, we need to determine the # type of this field and use the corresponding # .to_python(...) - from mongoengine.fields import EmbeddedDocumentField + EmbeddedDocumentField = _import_class('EmbeddedDocumentField') obj = self._document for chunk in path.split('.'): @@ -1774,7 +1792,7 @@ def field_path_sub(match): field_name = match.group(1).split('.') fields = self._document._lookup_field(field_name) # Substitute the correct name for the field into the javascript - return ".".join([f.db_field for f in fields]) + return '.'.join([f.db_field for f in fields]) code = re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) code = re.sub(u'\{\{\s*~([A-z_][A-z_0-9.]+?)\s*\}\}', field_path_sub, @@ -1785,21 +1803,21 @@ def _chainable_method(self, method_name, val): queryset = self.clone() method = getattr(queryset._cursor, method_name) method(val) - setattr(queryset, "_" + method_name, val) + setattr(queryset, '_' + method_name, val) return queryset # Deprecated def ensure_index(self, **kwargs): """Deprecated use :func:`Document.ensure_index`""" - msg = ("Doc.objects()._ensure_index() is deprecated. " - "Use Doc.ensure_index() instead.") + msg = ('Doc.objects()._ensure_index() is deprecated. ' + 'Use Doc.ensure_index() instead.') warnings.warn(msg, DeprecationWarning) self._document.__class__.ensure_index(**kwargs) return self def _ensure_indexes(self): """Deprecated use :func:`~Document.ensure_indexes`""" - msg = ("Doc.objects()._ensure_indexes() is deprecated. " - "Use Doc.ensure_indexes() instead.") + msg = ('Doc.objects()._ensure_indexes() is deprecated. ' + 'Use Doc.ensure_indexes() instead.') warnings.warn(msg, DeprecationWarning) self._document.__class__.ensure_indexes() diff --git a/mongoengine/queryset/field_list.py b/mongoengine/queryset/field_list.py index c10ad5525..0524c3bbb 100644 --- a/mongoengine/queryset/field_list.py +++ b/mongoengine/queryset/field_list.py @@ -67,7 +67,7 @@ def __nonzero__(self): return bool(self.fields) def as_dict(self): - field_list = dict((field, self.value) for field in self.fields) + field_list = {field: self.value for field in self.fields} if self.slice: field_list.update(self.slice) if self._id is not None: diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index b185b340a..9c1f24e12 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -53,15 +53,14 @@ def __len__(self): return self._len def __repr__(self): - """Provides the string representation of the QuerySet - """ + """Provide a string representation of the QuerySet""" if self._iter: return '.. queryset mid-iteration ..' self._populate_cache() data = self._result_cache[:REPR_OUTPUT_SIZE + 1] if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." + data[-1] = '...(remaining elements truncated)...' return repr(data) def _iter_results(self): @@ -113,7 +112,7 @@ def _populate_cache(self): # Pull in ITER_CHUNK_SIZE docs from the database and store them in # the result cache. try: - for i in xrange(ITER_CHUNK_SIZE): + for _ in xrange(ITER_CHUNK_SIZE): self._result_cache.append(self.next()) except StopIteration: # Getting this exception means there are no more docs in the @@ -142,7 +141,7 @@ def no_cache(self): .. versionadded:: 0.8.3 Convert to non caching queryset """ if self._result_cache is not None: - raise OperationError("QuerySet already cached") + raise OperationError('QuerySet already cached') return self.clone_into(QuerySetNoCache(self._document, self._collection)) @@ -165,13 +164,14 @@ def __repr__(self): return '.. queryset mid-iteration ..' data = [] - for i in xrange(REPR_OUTPUT_SIZE + 1): + for _ in xrange(REPR_OUTPUT_SIZE + 1): try: data.append(self.next()) except StopIteration: break + if len(data) > REPR_OUTPUT_SIZE: - data[-1] = "...(remaining elements truncated)..." + data[-1] = '...(remaining elements truncated)...' self.rewind() return repr(data) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index b3acca40f..af59917cf 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -3,8 +3,9 @@ from bson import ObjectId, SON from bson.dbref import DBRef import pymongo +import six -from mongoengine.base.fields import UPDATE_OPERATORS +from mongoengine.base import UPDATE_OPERATORS from mongoengine.common import _import_class from mongoengine.connection import get_connection from mongoengine.errors import InvalidQueryError @@ -29,12 +30,11 @@ # TODO make this less complex def query(_doc_cls=None, **kwargs): - """Transform a query from Django-style format to Mongo format. - """ + """Transform a query from Django-style format to Mongo format.""" mongo_query = {} merge_query = defaultdict(list) for key, value in sorted(kwargs.items()): - if key == "__raw__": + if key == '__raw__': mongo_query.update(value) continue @@ -47,7 +47,7 @@ def query(_doc_cls=None, **kwargs): op = parts.pop() # Allow to escape operator-like field name by __ - if len(parts) > 1 and parts[-1] == "": + if len(parts) > 1 and parts[-1] == '': parts.pop() negate = False @@ -59,7 +59,7 @@ def query(_doc_cls=None, **kwargs): # Switch field names to proper names [set in Field(name='foo')] try: fields = _doc_cls._lookup_field(parts) - except Exception, e: + except Exception as e: raise InvalidQueryError(e) parts = [] @@ -69,7 +69,7 @@ def query(_doc_cls=None, **kwargs): cleaned_fields = [] for field in fields: append_field = True - if isinstance(field, basestring): + if isinstance(field, six.string_types): parts.append(field) append_field = False # is last and CachedReferenceField @@ -87,9 +87,9 @@ def query(_doc_cls=None, **kwargs): singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not'] singular_ops += STRING_OPERATORS if op in singular_ops: - if isinstance(field, basestring): + if isinstance(field, six.string_types): if (op in STRING_OPERATORS and - isinstance(value, basestring)): + isinstance(value, six.string_types)): StringField = _import_class('StringField') value = StringField.prepare_query_value(op, value) else: @@ -129,10 +129,10 @@ def query(_doc_cls=None, **kwargs): value = query(field.field.document_type, **value) else: value = field.prepare_query_value(op, value) - value = {"$elemMatch": value} + value = {'$elemMatch': value} elif op in CUSTOM_OPERATORS: - NotImplementedError("Custom method '%s' has not " - "been implemented" % op) + NotImplementedError('Custom method "%s" has not ' + 'been implemented' % op) elif op not in STRING_OPERATORS: value = {'$' + op: value} @@ -197,15 +197,16 @@ def query(_doc_cls=None, **kwargs): def update(_doc_cls=None, **update): - """Transform an update spec from Django-style format to Mongo format. + """Transform an update spec from Django-style format to Mongo + format. """ mongo_update = {} for key, value in update.items(): - if key == "__raw__": + if key == '__raw__': mongo_update.update(value) continue parts = key.split('__') - # if there is no operator, default to "set" + # if there is no operator, default to 'set' if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS: parts.insert(0, 'set') # Check for an operator and transform to mongo-style if there is @@ -224,21 +225,21 @@ def update(_doc_cls=None, **update): elif op == 'add_to_set': op = 'addToSet' elif op == 'set_on_insert': - op = "setOnInsert" + op = 'setOnInsert' match = None if parts[-1] in COMPARISON_OPERATORS: match = parts.pop() # Allow to escape operator-like field name by __ - if len(parts) > 1 and parts[-1] == "": + if len(parts) > 1 and parts[-1] == '': parts.pop() if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] try: fields = _doc_cls._lookup_field(parts) - except Exception, e: + except Exception as e: raise InvalidQueryError(e) parts = [] @@ -246,7 +247,7 @@ def update(_doc_cls=None, **update): appended_sub_field = False for field in fields: append_field = True - if isinstance(field, basestring): + if isinstance(field, six.string_types): # Convert the S operator to $ if field == 'S': field = '$' @@ -267,7 +268,7 @@ def update(_doc_cls=None, **update): else: field = cleaned_fields[-1] - GeoJsonBaseField = _import_class("GeoJsonBaseField") + GeoJsonBaseField = _import_class('GeoJsonBaseField') if isinstance(field, GeoJsonBaseField): value = field.to_mongo(value) @@ -281,7 +282,7 @@ def update(_doc_cls=None, **update): value = [field.prepare_query_value(op, v) for v in value] elif field.required or value is not None: value = field.prepare_query_value(op, value) - elif op == "unset": + elif op == 'unset': value = 1 if match: @@ -291,16 +292,16 @@ def update(_doc_cls=None, **update): key = '.'.join(parts) if not op: - raise InvalidQueryError("Updates must supply an operation " - "eg: set__FIELD=value") + raise InvalidQueryError('Updates must supply an operation ' + 'eg: set__FIELD=value') if 'pull' in op and '.' in key: # Dot operators don't work on pull operations # unless they point to a list field # Otherwise it uses nested dict syntax if op == 'pullAll': - raise InvalidQueryError("pullAll operations only support " - "a single field depth") + raise InvalidQueryError('pullAll operations only support ' + 'a single field depth') # Look for the last list field and use dot notation until there field_classes = [c.__class__ for c in cleaned_fields] @@ -311,7 +312,7 @@ def update(_doc_cls=None, **update): # Then process as normal last_listField = len( cleaned_fields) - field_classes.index(ListField) - key = ".".join(parts[:last_listField]) + key = '.'.join(parts[:last_listField]) parts = parts[last_listField:] parts.insert(0, key) @@ -319,7 +320,7 @@ def update(_doc_cls=None, **update): for key in parts: value = {key: value} elif op == 'addToSet' and isinstance(value, list): - value = {key: {"$each": value}} + value = {key: {'$each': value}} else: value = {key: value} key = '$' + op @@ -333,78 +334,82 @@ def update(_doc_cls=None, **update): def _geo_operator(field, op, value): - """Helper to return the query for a given geo query""" - if op == "max_distance": + """Helper to return the query for a given geo query.""" + if op == 'max_distance': value = {'$maxDistance': value} - elif op == "min_distance": + elif op == 'min_distance': value = {'$minDistance': value} elif field._geo_index == pymongo.GEO2D: - if op == "within_distance": + if op == 'within_distance': value = {'$within': {'$center': value}} - elif op == "within_spherical_distance": + elif op == 'within_spherical_distance': value = {'$within': {'$centerSphere': value}} - elif op == "within_polygon": + elif op == 'within_polygon': value = {'$within': {'$polygon': value}} - elif op == "near": + elif op == 'near': value = {'$near': value} - elif op == "near_sphere": + elif op == 'near_sphere': value = {'$nearSphere': value} elif op == 'within_box': value = {'$within': {'$box': value}} else: - raise NotImplementedError("Geo method '%s' has not " - "been implemented for a GeoPointField" % op) + raise NotImplementedError('Geo method "%s" has not been ' + 'implemented for a GeoPointField' % op) else: - if op == "geo_within": - value = {"$geoWithin": _infer_geometry(value)} - elif op == "geo_within_box": - value = {"$geoWithin": {"$box": value}} - elif op == "geo_within_polygon": - value = {"$geoWithin": {"$polygon": value}} - elif op == "geo_within_center": - value = {"$geoWithin": {"$center": value}} - elif op == "geo_within_sphere": - value = {"$geoWithin": {"$centerSphere": value}} - elif op == "geo_intersects": - value = {"$geoIntersects": _infer_geometry(value)} - elif op == "near": + if op == 'geo_within': + value = {'$geoWithin': _infer_geometry(value)} + elif op == 'geo_within_box': + value = {'$geoWithin': {'$box': value}} + elif op == 'geo_within_polygon': + value = {'$geoWithin': {'$polygon': value}} + elif op == 'geo_within_center': + value = {'$geoWithin': {'$center': value}} + elif op == 'geo_within_sphere': + value = {'$geoWithin': {'$centerSphere': value}} + elif op == 'geo_intersects': + value = {'$geoIntersects': _infer_geometry(value)} + elif op == 'near': value = {'$near': _infer_geometry(value)} else: - raise NotImplementedError("Geo method '%s' has not " - "been implemented for a %s " % (op, field._name)) + raise NotImplementedError( + 'Geo method "%s" has not been implemented for a %s ' + % (op, field._name) + ) return value def _infer_geometry(value): - """Helper method that tries to infer the $geometry shape for a given value""" + """Helper method that tries to infer the $geometry shape for a + given value. + """ if isinstance(value, dict): - if "$geometry" in value: + if '$geometry' in value: return value elif 'coordinates' in value and 'type' in value: - return {"$geometry": value} - raise InvalidQueryError("Invalid $geometry dictionary should have " - "type and coordinates keys") + return {'$geometry': value} + raise InvalidQueryError('Invalid $geometry dictionary should have ' + 'type and coordinates keys') elif isinstance(value, (list, set)): # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon? # TODO: should both TypeError and IndexError be alike interpreted? try: value[0][0][0] - return {"$geometry": {"type": "Polygon", "coordinates": value}} + return {'$geometry': {'type': 'Polygon', 'coordinates': value}} except (TypeError, IndexError): pass try: value[0][0] - return {"$geometry": {"type": "LineString", "coordinates": value}} + return {'$geometry': {'type': 'LineString', 'coordinates': value}} except (TypeError, IndexError): pass try: value[0] - return {"$geometry": {"type": "Point", "coordinates": value}} + return {'$geometry': {'type': 'Point', 'coordinates': value}} except (TypeError, IndexError): pass - raise InvalidQueryError("Invalid $geometry data. Can be either a dictionary " - "or (nested) lists of coordinate(s)") + raise InvalidQueryError('Invalid $geometry data. Can be either a ' + 'dictionary or (nested) lists of coordinate(s)') diff --git a/mongoengine/queryset/visitor.py b/mongoengine/queryset/visitor.py index 84365f56b..bcf93a13e 100644 --- a/mongoengine/queryset/visitor.py +++ b/mongoengine/queryset/visitor.py @@ -69,9 +69,9 @@ def __init__(self, document): self.document = document def visit_combination(self, combination): - operator = "$and" + operator = '$and' if combination.operation == combination.OR: - operator = "$or" + operator = '$or' return {operator: combination.children} def visit_query(self, query): @@ -79,8 +79,7 @@ def visit_query(self, query): class QNode(object): - """Base class for nodes in query trees. - """ + """Base class for nodes in query trees.""" AND = 0 OR = 1 @@ -94,7 +93,8 @@ def accept(self, visitor): raise NotImplementedError def _combine(self, other, operation): - """Combine this node with another node into a QCombination object. + """Combine this node with another node into a QCombination + object. """ if getattr(other, 'empty', True): return self @@ -116,8 +116,8 @@ def __and__(self, other): class QCombination(QNode): - """Represents the combination of several conditions by a given logical - operator. + """Represents the combination of several conditions by a given + logical operator. """ def __init__(self, operation, children): diff --git a/mongoengine/signals.py b/mongoengine/signals.py index 64828448b..a892dec0d 100644 --- a/mongoengine/signals.py +++ b/mongoengine/signals.py @@ -1,7 +1,5 @@ -# -*- coding: utf-8 -*- - -__all__ = ['pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', - 'post_save', 'pre_delete', 'post_delete'] +__all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation', + 'post_save', 'pre_delete', 'post_delete') signals_available = False try: @@ -34,6 +32,7 @@ def _fail(self, *args, **kwargs): temporarily_connected_to = _fail del _fail + # the namespace for code signals. If you are not mongoengine code, do # not put signals in here. Create your own namespace instead. _signals = Namespace() diff --git a/setup.cfg b/setup.cfg index 09a593b42..1887c4768 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,13 +1,11 @@ [nosetests] -verbosity = 2 -detailed-errors = 1 -cover-erase = 1 -cover-branches = 1 -cover-package = mongoengine -tests = tests +verbosity=2 +detailed-errors=1 +tests=tests +cover-package=mongoengine [flake8] ignore=E501,F401,F403,F405,I201 -exclude=build,dist,docs,venv,.tox,.eggs,tests +exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests max-complexity=45 application-import-names=mongoengine,tests diff --git a/setup.py b/setup.py index 816a6de98..fa682d208 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,9 @@ def get_version(version_tuple): - if not isinstance(version_tuple[-1], int): - return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] + """Return the version tuple as a string, e.g. for (0, 10, 7), + return '0.10.7'. + """ return '.'.join(map(str, version_tuple)) @@ -41,31 +42,29 @@ def get_version(version_tuple): 'Operating System :: OS Independent', 'Programming Language :: Python', "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.6", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.2", "Programming Language :: Python :: 3.3", "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", 'Topic :: Database', 'Topic :: Software Development :: Libraries :: Python Modules', ] -extra_opts = {"packages": find_packages(exclude=["tests", "tests.*"])} +extra_opts = { + 'packages': find_packages(exclude=['tests', 'tests.*']), + 'tests_require': ['nose', 'coverage==4.2', 'blinker', 'Pillow>=2.0.0'] +} if sys.version_info[0] == 3: extra_opts['use_2to3'] = True - extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0'] - if "test" in sys.argv or "nosetests" in sys.argv: + if 'test' in sys.argv or 'nosetests' in sys.argv: extra_opts['packages'] = find_packages() - extra_opts['package_data'] = {"tests": ["fields/mongoengine.png", "fields/mongodb_leaf.png"]} + extra_opts['package_data'] = { + 'tests': ['fields/mongoengine.png', 'fields/mongodb_leaf.png']} else: - # coverage 4 does not support Python 3.2 anymore - extra_opts['tests_require'] = ['nose', 'coverage==3.7.1', 'blinker', 'Pillow>=2.0.0', 'python-dateutil'] - - if sys.version_info[0] == 2 and sys.version_info[1] == 6: - extra_opts['tests_require'].append('unittest2') + extra_opts['tests_require'] += ['python-dateutil'] setup( name='mongoengine', diff --git a/tests/__init__.py b/tests/__init__.py index b24df5d22..eab0ddc7c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,4 +2,3 @@ from document import * from queryset import * from fields import * -from migration import * diff --git a/tests/all_warnings/__init__.py b/tests/all_warnings/__init__.py index 53ce638c1..3aebe4baf 100644 --- a/tests/all_warnings/__init__.py +++ b/tests/all_warnings/__init__.py @@ -3,8 +3,6 @@ only get triggered on first hit. This way we can ensure its imported into the top level and called first by the test suite. """ -import sys -sys.path[0:0] = [""] import unittest import warnings diff --git a/tests/document/__init__.py b/tests/document/__init__.py index 1acc9f4b8..f71376ea6 100644 --- a/tests/document/__init__.py +++ b/tests/document/__init__.py @@ -1,5 +1,3 @@ -import sys -sys.path[0:0] = [""] import unittest from class_methods import * diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index 5da474aca..dd3addb76 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] import unittest from mongoengine import * diff --git a/tests/document/delta.py b/tests/document/delta.py index cd37f415f..add4fe8d3 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] import unittest from bson import SON diff --git a/tests/document/dynamic.py b/tests/document/dynamic.py index 85325b06b..a478df421 100644 --- a/tests/document/dynamic.py +++ b/tests/document/dynamic.py @@ -1,6 +1,4 @@ import unittest -import sys -sys.path[0:0] = [""] from mongoengine import * from mongoengine.connection import get_db @@ -143,11 +141,9 @@ def test_complex_data_lookups(self): def test_three_level_complex_data_lookups(self): """Ensure you can query three level document dynamic fields""" - p = self.Person() - p.misc = {'hello': {'hello2': 'world'}} - p.save() - # from pprint import pprint as pp; import pdb; pdb.set_trace(); - print self.Person.objects(misc__hello__hello2='world') + p = self.Person.objects.create( + misc={'hello': {'hello2': 'world'}} + ) self.assertEqual(1, self.Person.objects(misc__hello__hello2='world').count()) def test_complex_embedded_document_validation(self): diff --git a/tests/document/indexes.py b/tests/document/indexes.py index 75314bb06..af93e7db2 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -556,8 +556,8 @@ class BlogPost(Document): BlogPost.drop_collection() - for i in xrange(0, 10): - tags = [("tag %i" % n) for n in xrange(0, i % 2)] + for i in range(0, 10): + tags = [("tag %i" % n) for n in range(0, i % 2)] BlogPost(tags=tags).save() self.assertEqual(BlogPost.objects.count(), 10) diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 957938be0..2897e1d15 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] import unittest import warnings @@ -253,19 +251,17 @@ class Human(Mammal): pass self.assertEqual(classes, [Human]) def test_allow_inheritance(self): - """Ensure that inheritance may be disabled on simple classes and that - _cls and _subclasses will not be used. + """Ensure that inheritance is disabled by default on simple + classes and that _cls will not be used. """ - class Animal(Document): name = StringField() - def create_dog_class(): + # can't inherit because Animal didn't explicitly allow inheritance + with self.assertRaises(ValueError): class Dog(Animal): pass - self.assertRaises(ValueError, create_dog_class) - # Check that _cls etc aren't present on simple documents dog = Animal(name='dog').save() self.assertEqual(dog.to_mongo().keys(), ['_id', 'name']) @@ -275,17 +271,15 @@ class Dog(Animal): self.assertFalse('_cls' in obj) def test_cant_turn_off_inheritance_on_subclass(self): - """Ensure if inheritance is on in a subclass you cant turn it off + """Ensure if inheritance is on in a subclass you cant turn it off. """ - class Animal(Document): name = StringField() meta = {'allow_inheritance': True} - def create_mammal_class(): + with self.assertRaises(ValueError): class Mammal(Animal): meta = {'allow_inheritance': False} - self.assertRaises(ValueError, create_mammal_class) def test_allow_inheritance_abstract_document(self): """Ensure that abstract documents can set inheritance rules and that @@ -298,10 +292,9 @@ class FinalDocument(Document): class Animal(FinalDocument): name = StringField() - def create_mammal_class(): + with self.assertRaises(ValueError): class Mammal(Animal): pass - self.assertRaises(ValueError, create_mammal_class) # Check that _cls isn't present in simple documents doc = Animal(name='dog') @@ -360,29 +353,26 @@ class EuropeanCity(City): self.assertEqual(berlin.pk, berlin.auto_id_0) def test_abstract_document_creation_does_not_fail(self): - class City(Document): continent = StringField() meta = {'abstract': True, 'allow_inheritance': False} + bkk = City(continent='asia') self.assertEqual(None, bkk.pk) # TODO: expected error? Shouldn't we create a new error type? - self.assertRaises(KeyError, lambda: setattr(bkk, 'pk', 1)) + with self.assertRaises(KeyError): + setattr(bkk, 'pk', 1) def test_allow_inheritance_embedded_document(self): - """Ensure embedded documents respect inheritance - """ - + """Ensure embedded documents respect inheritance.""" class Comment(EmbeddedDocument): content = StringField() - def create_special_comment(): + with self.assertRaises(ValueError): class SpecialComment(Comment): pass - self.assertRaises(ValueError, create_special_comment) - doc = Comment(content='test') self.assertFalse('_cls' in doc.to_mongo()) @@ -454,11 +444,11 @@ class Human(Mammal): pass self.assertEqual(Guppy._get_collection_name(), 'fish') self.assertEqual(Human._get_collection_name(), 'human') - def create_bad_abstract(): + # ensure that a subclass of a non-abstract class can't be abstract + with self.assertRaises(ValueError): class EvilHuman(Human): evil = BooleanField(default=True) meta = {'abstract': True} - self.assertRaises(ValueError, create_bad_abstract) def test_abstract_embedded_documents(self): # 789: EmbeddedDocument shouldn't inherit abstract diff --git a/tests/document/instance.py b/tests/document/instance.py index 1342f981c..b92bafa9a 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -1,7 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] - import bson import os import pickle @@ -16,12 +13,12 @@ PickleDynamicEmbedded, PickleDynamicTest) from mongoengine import * +from mongoengine.base import get_document, _document_registry +from mongoengine.connection import get_db from mongoengine.errors import (NotRegistered, InvalidDocumentError, InvalidQueryError, NotUniqueError, FieldDoesNotExist, SaveConditionError) from mongoengine.queryset import NULLIFY, Q -from mongoengine.connection import get_db -from mongoengine.base import get_document from mongoengine.context_managers import switch_db, query_counter from mongoengine import signals @@ -102,21 +99,18 @@ class Log(Document): self.assertEqual(options['size'], 4096) # Check that the document cannot be redefined with different options - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 11, - } - # Create the collection by accessing Document.objects - Log.objects - self.assertRaises(InvalidCollectionError, recreate_log_document) + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 11, + } - Log.drop_collection() + # Accessing Document.objects creates the collection + with self.assertRaises(InvalidCollectionError): + Log.objects def test_capped_collection_default(self): - """Ensure that capped collections defaults work properly. - """ + """Ensure that capped collections defaults work properly.""" class Log(Document): date = DateTimeField(default=datetime.now) meta = { @@ -134,16 +128,14 @@ class Log(Document): self.assertEqual(options['size'], 10 * 2**20) # Check that the document with default value can be recreated - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_documents': 10, - } - # Create the collection by accessing Document.objects - Log.objects - recreate_log_document() - Log.drop_collection() + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_documents': 10, + } + + # Create the collection by accessing Document.objects + Log.objects def test_capped_collection_no_max_size_problems(self): """Ensure that capped collections with odd max_size work properly. @@ -166,16 +158,14 @@ class Log(Document): self.assertTrue(options['size'] >= 10000) # Check that the document with odd max_size value can be recreated - def recreate_log_document(): - class Log(Document): - date = DateTimeField(default=datetime.now) - meta = { - 'max_size': 10000, - } - # Create the collection by accessing Document.objects - Log.objects - recreate_log_document() - Log.drop_collection() + class Log(Document): + date = DateTimeField(default=datetime.now) + meta = { + 'max_size': 10000, + } + + # Create the collection by accessing Document.objects + Log.objects def test_repr(self): """Ensure that unicode representation works @@ -286,7 +276,7 @@ class CompareStats(Document): list_stats = [] - for i in xrange(10): + for i in range(10): s = Stats() s.save() list_stats.append(s) @@ -356,14 +346,14 @@ class User(Document): self.assertEqual(User._fields['username'].db_field, '_id') self.assertEqual(User._meta['id_field'], 'username') - def create_invalid_user(): - User(name='test').save() # no primary key field - self.assertRaises(ValidationError, create_invalid_user) + # test no primary key field + self.assertRaises(ValidationError, User(name='test').save) - def define_invalid_user(): + # define a subclass with a different primary key field than the + # parent + with self.assertRaises(ValueError): class EmailUser(User): email = StringField(primary_key=True) - self.assertRaises(ValueError, define_invalid_user) class EmailUser(User): email = StringField() @@ -411,12 +401,10 @@ class NicePlace(Place): # Mimic Place and NicePlace definitions being in a different file # and the NicePlace model not being imported in at query time. - from mongoengine.base import _document_registry del(_document_registry['Place.NicePlace']) - def query_without_importing_nice_place(): - print Place.objects.all() - self.assertRaises(NotRegistered, query_without_importing_nice_place) + with self.assertRaises(NotRegistered): + list(Place.objects.all()) def test_document_registry_regressions(self): @@ -745,7 +733,7 @@ def clean(self): try: t.save() - except ValidationError, e: + except ValidationError as e: expect_msg = "Draft entries may not have a publication date." self.assertTrue(expect_msg in e.message) self.assertEqual(e.to_dict(), {'__all__': expect_msg}) @@ -784,7 +772,7 @@ class TestDocument(Document): t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) try: t.save() - except ValidationError, e: + except ValidationError as e: expect_msg = "Value of z != x + y" self.assertTrue(expect_msg in e.message) self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) @@ -798,8 +786,10 @@ class TestDocument(Document): def test_modify_empty(self): doc = self.Person(name="bob", age=10).save() - self.assertRaises( - InvalidDocumentError, lambda: self.Person().modify(set__age=10)) + + with self.assertRaises(InvalidDocumentError): + self.Person().modify(set__age=10) + self.assertDbEqual([dict(doc.to_mongo())]) def test_modify_invalid_query(self): @@ -807,9 +797,8 @@ def test_modify_invalid_query(self): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - self.assertRaises( - InvalidQueryError, - lambda: doc1.modify(dict(id=doc2.id), set__value=20)) + with self.assertRaises(InvalidQueryError): + doc1.modify({'id': doc2.id}, set__value=20) self.assertDbEqual(docs) @@ -818,7 +807,7 @@ def test_modify_match_another_document(self): doc2 = self.Person(name="jim", age=20).save() docs = [dict(doc1.to_mongo()), dict(doc2.to_mongo())] - assert not doc1.modify(dict(name=doc2.name), set__age=100) + assert not doc1.modify({'name': doc2.name}, set__age=100) self.assertDbEqual(docs) @@ -827,7 +816,7 @@ def test_modify_not_exists(self): doc2 = self.Person(id=ObjectId(), name="jim", age=20) docs = [dict(doc1.to_mongo())] - assert not doc2.modify(dict(name=doc2.name), set__age=100) + assert not doc2.modify({'name': doc2.name}, set__age=100) self.assertDbEqual(docs) @@ -1293,12 +1282,11 @@ class Doc(Document): def test_document_update(self): - def update_not_saved_raises(): + # try updating a non-saved document + with self.assertRaises(OperationError): person = self.Person(name='dcrosta') person.update(set__name='Dan Crosta') - self.assertRaises(OperationError, update_not_saved_raises) - author = self.Person(name='dcrosta') author.save() @@ -1308,19 +1296,17 @@ def update_not_saved_raises(): p1 = self.Person.objects.first() self.assertEqual(p1.name, author.name) - def update_no_value_raises(): + # try sending an empty update + with self.assertRaises(OperationError): person = self.Person.objects.first() person.update() - self.assertRaises(OperationError, update_no_value_raises) - - def update_no_op_should_default_to_set(): - person = self.Person.objects.first() - person.update(name="Dan") - person.reload() - return person.name - - self.assertEqual("Dan", update_no_op_should_default_to_set()) + # update that doesn't explicitly specify an operator should default + # to 'set__' + person = self.Person.objects.first() + person.update(name="Dan") + person.reload() + self.assertEqual("Dan", person.name) def test_update_unique_field(self): class Doc(Document): @@ -1329,8 +1315,8 @@ class Doc(Document): doc1 = Doc(name="first").save() doc2 = Doc(name="second").save() - self.assertRaises(NotUniqueError, lambda: - doc2.update(set__name=doc1.name)) + with self.assertRaises(NotUniqueError): + doc2.update(set__name=doc1.name) def test_embedded_update(self): """ @@ -1848,15 +1834,13 @@ class BlogPost(Document): def test_duplicate_db_fields_raise_invalid_document_error(self): """Ensure a InvalidDocumentError is thrown if duplicate fields - declare the same db_field""" - - def throw_invalid_document_error(): + declare the same db_field. + """ + with self.assertRaises(InvalidDocumentError): class Foo(Document): name = StringField() name2 = StringField(db_field='name') - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - def test_invalid_son(self): """Raise an error if loading invalid data""" class Occurrence(EmbeddedDocument): @@ -1868,11 +1852,13 @@ class Word(Document): forms = ListField(StringField(), default=list) occurs = ListField(EmbeddedDocumentField(Occurrence), default=list) - def raise_invalid_document(): - Word._from_son({'stem': [1, 2, 3], 'forms': 1, 'count': 'one', - 'occurs': {"hello": None}}) - - self.assertRaises(InvalidDocumentError, raise_invalid_document) + with self.assertRaises(InvalidDocumentError): + Word._from_son({ + 'stem': [1, 2, 3], + 'forms': 1, + 'count': 'one', + 'occurs': {"hello": None} + }) def test_reverse_delete_rule_cascade_and_nullify(self): """Ensure that a referenced document is also deleted upon deletion. @@ -2103,8 +2089,7 @@ class Foo(Document): self.assertEqual(Bar.objects.get().foo, None) def test_invalid_reverse_delete_rule_raise_errors(self): - - def throw_invalid_document_error(): + with self.assertRaises(InvalidDocumentError): class Blog(Document): content = StringField() authors = MapField(ReferenceField( @@ -2114,21 +2099,15 @@ class Blog(Document): self.Person, reverse_delete_rule=NULLIFY)) - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - - def throw_invalid_document_error_embedded(): + with self.assertRaises(InvalidDocumentError): class Parents(EmbeddedDocument): father = ReferenceField('Person', reverse_delete_rule=DENY) mother = ReferenceField('Person', reverse_delete_rule=DENY) - self.assertRaises( - InvalidDocumentError, throw_invalid_document_error_embedded) - def test_reverse_delete_rule_cascade_recurs(self): """Ensure that a chain of documents is also deleted upon cascaded deletion. """ - class BlogPost(Document): content = StringField() author = ReferenceField(self.Person, reverse_delete_rule=CASCADE) @@ -2344,15 +2323,14 @@ def test_picklable_on_signals(self): pickle_doc.save() pickle_doc.delete() - def test_throw_invalid_document_error(self): - - # test handles people trying to upsert - def throw_invalid_document_error(): + def test_override_method_with_field(self): + """Test creating a field with a field name that would override + the "validate" method. + """ + with self.assertRaises(InvalidDocumentError): class Blog(Document): validate = DictField() - self.assertRaises(InvalidDocumentError, throw_invalid_document_error) - def test_mutating_documents(self): class B(EmbeddedDocument): @@ -2815,11 +2793,10 @@ class LogEntry(Document): log.log = "Saving" log.save() - def change_shard_key(): + # try to change the shard key + with self.assertRaises(OperationError): log.machine = "127.0.0.1" - self.assertRaises(OperationError, change_shard_key) - def test_shard_key_in_embedded_document(self): class Foo(EmbeddedDocument): foo = StringField() @@ -2840,12 +2817,11 @@ class Bar(Document): bar_doc.bar = 'baz' bar_doc.save() - def change_shard_key(): + # try to change the shard key + with self.assertRaises(OperationError): bar_doc.foo.foo = 'something' bar_doc.save() - self.assertRaises(OperationError, change_shard_key) - def test_shard_key_primary(self): class LogEntry(Document): machine = StringField(primary_key=True) @@ -2866,11 +2842,10 @@ class LogEntry(Document): log.log = "Saving" log.save() - def change_shard_key(): + # try to change the shard key + with self.assertRaises(OperationError): log.machine = "127.0.0.1" - self.assertRaises(OperationError, change_shard_key) - def test_kwargs_simple(self): class Embedded(EmbeddedDocument): @@ -2955,11 +2930,9 @@ class Person(DynamicDocument): def test_bad_mixed_creation(self): """Ensure that document gives correct error when duplicating arguments """ - def construct_bad_instance(): + with self.assertRaises(TypeError): return self.Person("Test User", 42, name="Bad User") - self.assertRaises(TypeError, construct_bad_instance) - def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id' """ @@ -3118,17 +3091,17 @@ class Person(Document): p4 = Person.objects()[0] p4.save() self.assertEquals(p4.height, 189) - + # However the default will not be fixed in DB self.assertEquals(Person.objects(height=189).count(), 0) - + # alter DB for the new default coll = Person._get_collection() for person in Person.objects.as_pymongo(): if 'height' not in person: person['height'] = 189 coll.save(person) - + self.assertEquals(Person.objects(height=189).count(), 1) def test_from_son(self): diff --git a/tests/document/json_serialisation.py b/tests/document/json_serialisation.py index f47b5de5b..110f1e14d 100644 --- a/tests/document/json_serialisation.py +++ b/tests/document/json_serialisation.py @@ -1,6 +1,3 @@ -import sys -sys.path[0:0] = [""] - import unittest import uuid diff --git a/tests/document/validation.py b/tests/document/validation.py index ba03366e5..105bc8b0b 100644 --- a/tests/document/validation.py +++ b/tests/document/validation.py @@ -1,7 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] - import unittest from datetime import datetime @@ -60,7 +57,7 @@ class User(Document): try: User().validate() - except ValidationError, e: + except ValidationError as e: self.assertTrue("User:None" in e.message) self.assertEqual(e.to_dict(), { 'username': 'Field is required', @@ -70,7 +67,7 @@ class User(Document): user.name = None try: user.save() - except ValidationError, e: + except ValidationError as e: self.assertTrue("User:RossC0" in e.message) self.assertEqual(e.to_dict(), { 'name': 'Field is required'}) @@ -118,7 +115,7 @@ class Doc(Document): try: Doc(id="bad").validate() - except ValidationError, e: + except ValidationError as e: self.assertTrue("SubDoc:None" in e.message) self.assertEqual(e.to_dict(), { "e": {'val': 'OK could not be converted to int'}}) @@ -136,7 +133,7 @@ class Doc(Document): doc.e.val = "OK" try: doc.save() - except ValidationError, e: + except ValidationError as e: self.assertTrue("Doc:test" in e.message) self.assertEqual(e.to_dict(), { "e": {'val': 'OK could not be converted to int'}}) @@ -156,14 +153,14 @@ class Doc(Document): s = SubDoc() - self.assertRaises(ValidationError, lambda: s.validate()) + self.assertRaises(ValidationError, s.validate) d1.e = s d2.e = s del d1 - self.assertRaises(ValidationError, lambda: d2.validate()) + self.assertRaises(ValidationError, d2.validate) def test_parent_reference_in_child_document(self): """ diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 6cf4f1286..678786fa1 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1,11 +1,7 @@ # -*- coding: utf-8 -*- -import sys - import six from nose.plugins.skip import SkipTest -sys.path[0:0] = [""] - import datetime import unittest import uuid @@ -29,10 +25,9 @@ from mongoengine import * from mongoengine.connection import get_db -from mongoengine.base import _document_registry -from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList +from mongoengine.base import (BaseDict, BaseField, EmbeddedDocumentList, + _document_registry) from mongoengine.errors import NotRegistered, DoesNotExist -from mongoengine.python_support import PY3, b, bin_type __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") @@ -653,8 +648,8 @@ class LogEntry(Document): # Post UTC - microseconds are rounded (down) nearest millisecond and # dropped - d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) - d2 = datetime.datetime(1970, 01, 01, 00, 00, 01) + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1) log = LogEntry() log.date = d1 log.save() @@ -663,15 +658,15 @@ class LogEntry(Document): self.assertEqual(log.date, d2) # Post UTC - microseconds are rounded (down) nearest millisecond - d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) - d2 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9000) + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) + d2 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9000) log.date = d1 log.save() log.reload() self.assertNotEqual(log.date, d1) self.assertEqual(log.date, d2) - if not PY3: + if not six.PY3: # Pre UTC dates microseconds below 1000 are dropped # This does not seem to be true in PY3 d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, 999) @@ -691,7 +686,7 @@ class LogEntry(Document): LogEntry.drop_collection() - d1 = datetime.datetime(1970, 01, 01, 00, 00, 01) + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1) log = LogEntry() log.date = d1 log.validate() @@ -708,8 +703,8 @@ class LogEntry(Document): LogEntry.drop_collection() # create 60 log entries - for i in xrange(1950, 2010): - d = datetime.datetime(i, 01, 01, 00, 00, 01) + for i in range(1950, 2010): + d = datetime.datetime(i, 1, 1, 0, 0, 1) LogEntry(date=d).save() self.assertEqual(LogEntry.objects.count(), 60) @@ -756,7 +751,7 @@ class LogEntry(Document): # Post UTC - microseconds are rounded (down) nearest millisecond and # dropped - with default datetimefields - d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) log = LogEntry() log.date = d1 log.save() @@ -765,7 +760,7 @@ class LogEntry(Document): # Post UTC - microseconds are rounded (down) nearest millisecond - with # default datetimefields - d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 9999) + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 9999) log.date = d1 log.save() log.reload() @@ -782,7 +777,7 @@ class LogEntry(Document): # Pre UTC microseconds above 1000 is wonky - with default datetimefields # log.date has an invalid microsecond value so I can't construct # a date to compare. - for i in xrange(1001, 3113, 33): + for i in range(1001, 3113, 33): d1 = datetime.datetime(1969, 12, 31, 23, 59, 59, i) log.date = d1 log.save() @@ -792,7 +787,7 @@ class LogEntry(Document): self.assertEqual(log, log1) # Test string padding - microsecond = map(int, [math.pow(10, x) for x in xrange(6)]) + microsecond = map(int, [math.pow(10, x) for x in range(6)]) mm = dd = hh = ii = ss = [1, 10] for values in itertools.product([2014], mm, dd, hh, ii, ss, microsecond): @@ -814,7 +809,7 @@ class LogEntry(Document): LogEntry.drop_collection() - d1 = datetime.datetime(1970, 01, 01, 00, 00, 01, 999) + d1 = datetime.datetime(1970, 1, 1, 0, 0, 1, 999) log = LogEntry() log.date = d1 log.save() @@ -825,8 +820,8 @@ class LogEntry(Document): LogEntry.drop_collection() # create 60 log entries - for i in xrange(1950, 2010): - d = datetime.datetime(i, 01, 01, 00, 00, 01, 999) + for i in range(1950, 2010): + d = datetime.datetime(i, 1, 1, 0, 0, 1, 999) LogEntry(date=d).save() self.assertEqual(LogEntry.objects.count(), 60) @@ -1134,12 +1129,11 @@ class Simple(Document): e.mapping = [1] e.save() - def create_invalid_mapping(): + # try creating an invalid mapping + with self.assertRaises(ValidationError): e.mapping = ["abc"] e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - Simple.drop_collection() def test_list_field_rejects_strings(self): @@ -1406,12 +1400,11 @@ class Simple(Document): e.mapping['someint'] = 1 e.save() - def create_invalid_mapping(): + # try creating an invalid mapping + with self.assertRaises(ValidationError): e.mapping['somestring'] = "abc" e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - Simple.drop_collection() def test_dictfield_complex(self): @@ -1484,11 +1477,10 @@ class Simple(Document): self.assertEqual(BaseDict, type(e.mapping)) self.assertEqual({"ints": [3, 4]}, e.mapping) - def create_invalid_mapping(): + # try creating an invalid mapping + with self.assertRaises(ValueError): e.update(set__mapping={"somestrings": ["foo", "bar", ]}) - self.assertRaises(ValueError, create_invalid_mapping) - Simple.drop_collection() def test_mapfield(self): @@ -1503,18 +1495,14 @@ class Simple(Document): e.mapping['someint'] = 1 e.save() - def create_invalid_mapping(): + with self.assertRaises(ValidationError): e.mapping['somestring'] = "abc" e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - - def create_invalid_class(): + with self.assertRaises(ValidationError): class NoDeclaredType(Document): mapping = MapField() - self.assertRaises(ValidationError, create_invalid_class) - Simple.drop_collection() def test_complex_mapfield(self): @@ -1543,14 +1531,10 @@ class Extensible(Document): self.assertTrue(isinstance(e2.mapping['somestring'], StringSetting)) self.assertTrue(isinstance(e2.mapping['someint'], IntegerSetting)) - def create_invalid_mapping(): + with self.assertRaises(ValidationError): e.mapping['someint'] = 123 e.save() - self.assertRaises(ValidationError, create_invalid_mapping) - - Extensible.drop_collection() - def test_embedded_mapfield_db_field(self): class Embedded(EmbeddedDocument): @@ -1760,8 +1744,8 @@ class Bar(Document): # Reference is no longer valid foo.delete() bar = Bar.objects.get() - self.assertRaises(DoesNotExist, lambda: getattr(bar, 'ref')) - self.assertRaises(DoesNotExist, lambda: getattr(bar, 'generic_ref')) + self.assertRaises(DoesNotExist, getattr, bar, 'ref') + self.assertRaises(DoesNotExist, getattr, bar, 'generic_ref') # When auto_dereference is disabled, there is no trouble returning DBRef bar = Bar.objects.get() @@ -2036,7 +2020,7 @@ class Person(Document): }) def test_cached_reference_fields_on_embedded_documents(self): - def build(): + with self.assertRaises(InvalidDocumentError): class Test(Document): name = StringField() @@ -2045,8 +2029,6 @@ class Test(Document): 'test': CachedReferenceField(Test) }) - self.assertRaises(InvalidDocumentError, build) - def test_cached_reference_auto_sync(self): class Person(Document): TYPES = ( @@ -2863,7 +2845,7 @@ class Attachment(Document): content_type = StringField() blob = BinaryField() - BLOB = b('\xe6\x00\xc4\xff\x07') + BLOB = six.b('\xe6\x00\xc4\xff\x07') MIME_TYPE = 'application/octet-stream' Attachment.drop_collection() @@ -2873,7 +2855,7 @@ class Attachment(Document): attachment_1 = Attachment.objects().first() self.assertEqual(MIME_TYPE, attachment_1.content_type) - self.assertEqual(BLOB, bin_type(attachment_1.blob)) + self.assertEqual(BLOB, six.binary_type(attachment_1.blob)) Attachment.drop_collection() @@ -2900,13 +2882,13 @@ class AttachmentSizeLimit(Document): attachment_required = AttachmentRequired() self.assertRaises(ValidationError, attachment_required.validate) - attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07')) + attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07')) attachment_required.validate() attachment_size_limit = AttachmentSizeLimit( - blob=b('\xe6\x00\xc4\xff\x07')) + blob=six.b('\xe6\x00\xc4\xff\x07')) self.assertRaises(ValidationError, attachment_size_limit.validate) - attachment_size_limit.blob = b('\xe6\x00\xc4\xff') + attachment_size_limit.blob = six.b('\xe6\x00\xc4\xff') attachment_size_limit.validate() Attachment.drop_collection() @@ -3152,7 +3134,7 @@ class Shirt(Document): try: shirt.validate() - except ValidationError, error: + except ValidationError as error: # get the validation rules error_dict = error.to_dict() self.assertEqual(error_dict['size'], SIZE_MESSAGE) @@ -3181,7 +3163,7 @@ class Person(Document): self.db['mongoengine.counters'].drop() Person.drop_collection() - for x in xrange(10): + for x in range(10): Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) @@ -3205,7 +3187,7 @@ class Person(Document): self.db['mongoengine.counters'].drop() Person.drop_collection() - for x in xrange(10): + for x in range(10): Person(name="Person %s" % x).save() self.assertEqual(Person.id.get_next_value(), 11) @@ -3220,7 +3202,7 @@ class Person(Document): self.db['mongoengine.counters'].drop() Person.drop_collection() - for x in xrange(10): + for x in range(10): Person(name="Person %s" % x).save() self.assertEqual(Person.id.get_next_value(), '11') @@ -3236,7 +3218,7 @@ class Person(Document): self.db['mongoengine.counters'].drop() Person.drop_collection() - for x in xrange(10): + for x in range(10): Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'jelly.id'}) @@ -3261,7 +3243,7 @@ class Person(Document): self.db['mongoengine.counters'].drop() Person.drop_collection() - for x in xrange(10): + for x in range(10): Person(name="Person %s" % x).save() c = self.db['mongoengine.counters'].find_one({'_id': 'person.id'}) @@ -3323,7 +3305,7 @@ class Person(Document): Animal.drop_collection() Person.drop_collection() - for x in xrange(10): + for x in range(10): Animal(name="Animal %s" % x).save() Person(name="Person %s" % x).save() @@ -3353,7 +3335,7 @@ class Person(Document): self.db['mongoengine.counters'].drop() Person.drop_collection() - for x in xrange(10): + for x in range(10): p = Person(name="Person %s" % x) p.save() @@ -3540,7 +3522,7 @@ class Post(Document): self.assertRaises(ValidationError, post.validate) try: post.validate() - except ValidationError, error: + except ValidationError as error: # ValidationError.errors property self.assertTrue(hasattr(error, 'errors')) self.assertTrue(isinstance(error.errors, dict)) @@ -3601,8 +3583,6 @@ def test_tuples_as_tuples(self): Ensure that tuples remain tuples when they are inside a ComplexBaseField """ - from mongoengine.base import BaseField - class EnumField(BaseField): def __init__(self, **kwargs): @@ -3836,9 +3816,7 @@ def test_no_keyword_filter(self): filtered = self.post1.comments.filter() # Ensure nothing was changed - # < 2.6 Incompatible > - # self.assertListEqual(filtered, self.post1.comments) - self.assertEqual(filtered, self.post1.comments) + self.assertListEqual(filtered, self.post1.comments) def test_single_keyword_filter(self): """ @@ -3889,10 +3867,8 @@ def test_unknown_keyword_filter(self): Tests the filter method of a List of Embedded Documents when the keyword is not a known keyword. """ - # < 2.6 Incompatible > - # with self.assertRaises(AttributeError): - # self.post2.comments.filter(year=2) - self.assertRaises(AttributeError, self.post2.comments.filter, year=2) + with self.assertRaises(AttributeError): + self.post2.comments.filter(year=2) def test_no_keyword_exclude(self): """ @@ -3902,9 +3878,7 @@ def test_no_keyword_exclude(self): filtered = self.post1.comments.exclude() # Ensure everything was removed - # < 2.6 Incompatible > - # self.assertListEqual(filtered, []) - self.assertEqual(filtered, []) + self.assertListEqual(filtered, []) def test_single_keyword_exclude(self): """ @@ -3950,10 +3924,8 @@ def test_unknown_keyword_exclude(self): Tests the exclude method of a List of Embedded Documents when the keyword is not a known keyword. """ - # < 2.6 Incompatible > - # with self.assertRaises(AttributeError): - # self.post2.comments.exclude(year=2) - self.assertRaises(AttributeError, self.post2.comments.exclude, year=2) + with self.assertRaises(AttributeError): + self.post2.comments.exclude(year=2) def test_chained_filter_exclude(self): """ @@ -3991,10 +3963,7 @@ def test_single_keyword_get(self): single keyword. """ comment = self.post1.comments.get(author='user1') - - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user1') def test_multi_keyword_get(self): @@ -4003,10 +3972,7 @@ def test_multi_keyword_get(self): multiple keywords. """ comment = self.post2.comments.get(author='user2', message='message2') - - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user2') self.assertEqual(comment.message, 'message2') @@ -4015,44 +3981,32 @@ def test_no_keyword_multiple_return_get(self): Tests the get method of a List of Embedded Documents without a keyword to return multiple documents. """ - # < 2.6 Incompatible > - # with self.assertRaises(MultipleObjectsReturned): - # self.post1.comments.get() - self.assertRaises(MultipleObjectsReturned, self.post1.comments.get) + with self.assertRaises(MultipleObjectsReturned): + self.post1.comments.get() def test_keyword_multiple_return_get(self): """ Tests the get method of a List of Embedded Documents with a keyword to return multiple documents. """ - # < 2.6 Incompatible > - # with self.assertRaises(MultipleObjectsReturned): - # self.post2.comments.get(author='user2') - self.assertRaises( - MultipleObjectsReturned, self.post2.comments.get, author='user2' - ) + with self.assertRaises(MultipleObjectsReturned): + self.post2.comments.get(author='user2') def test_unknown_keyword_get(self): """ Tests the get method of a List of Embedded Documents with an unknown keyword. """ - # < 2.6 Incompatible > - # with self.assertRaises(AttributeError): - # self.post2.comments.get(year=2020) - self.assertRaises(AttributeError, self.post2.comments.get, year=2020) + with self.assertRaises(AttributeError): + self.post2.comments.get(year=2020) def test_no_result_get(self): """ Tests the get method of a List of Embedded Documents where get returns no results. """ - # < 2.6 Incompatible > - # with self.assertRaises(DoesNotExist): - # self.post1.comments.get(author='user3') - self.assertRaises( - DoesNotExist, self.post1.comments.get, author='user3' - ) + with self.assertRaises(DoesNotExist): + self.post1.comments.get(author='user3') def test_first(self): """ @@ -4062,9 +4016,7 @@ def test_first(self): comment = self.post1.comments.first() # Ensure a Comment object was returned. - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment, self.post1.comments[0]) def test_create(self): @@ -4077,22 +4029,14 @@ def test_create(self): self.post1.save() # Ensure the returned value is the comment object. - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user4') self.assertEqual(comment.message, 'message1') # Ensure the new comment was actually saved to the database. - # < 2.6 Incompatible > - # self.assertIn( - # comment, - # self.BlogPost.objects(comments__author='user4')[0].comments - # ) - self.assertTrue( - comment in self.BlogPost.objects( - comments__author='user4' - )[0].comments + self.assertIn( + comment, + self.BlogPost.objects(comments__author='user4')[0].comments ) def test_filtered_create(self): @@ -4107,22 +4051,14 @@ def test_filtered_create(self): self.post1.save() # Ensure the returned value is the comment object. - # < 2.6 Incompatible > - # self.assertIsInstance(comment, self.Comments) - self.assertTrue(isinstance(comment, self.Comments)) + self.assertIsInstance(comment, self.Comments) self.assertEqual(comment.author, 'user4') self.assertEqual(comment.message, 'message1') # Ensure the new comment was actually saved to the database. - # < 2.6 Incompatible > - # self.assertIn( - # comment, - # self.BlogPost.objects(comments__author='user4')[0].comments - # ) - self.assertTrue( - comment in self.BlogPost.objects( - comments__author='user4' - )[0].comments + self.assertIn( + comment, + self.BlogPost.objects(comments__author='user4')[0].comments ) def test_no_keyword_update(self): @@ -4135,22 +4071,14 @@ def test_no_keyword_update(self): self.post1.save() # Ensure that nothing was altered. - # < 2.6 Incompatible > - # self.assertIn( - # original[0], - # self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - original[0] in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertIn( + original[0], + self.BlogPost.objects(id=self.post1.id)[0].comments ) - # < 2.6 Incompatible > - # self.assertIn( - # original[1], - # self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - original[1] in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertIn( + original[1], + self.BlogPost.objects(id=self.post1.id)[0].comments ) # Ensure the method returned 0 as the number of entries @@ -4196,13 +4124,9 @@ def test_save(self): comments.save() # Ensure that the new comment has been added to the database. - # < 2.6 Incompatible > - # self.assertIn( - # new_comment, - # self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - new_comment in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertIn( + new_comment, + self.BlogPost.objects(id=self.post1.id)[0].comments ) def test_delete(self): @@ -4214,23 +4138,15 @@ def test_delete(self): # Ensure that all the comments under post1 were deleted in the # database. - # < 2.6 Incompatible > - # self.assertListEqual( - # self.BlogPost.objects(id=self.post1.id)[0].comments, [] - # ) - self.assertEqual( + self.assertListEqual( self.BlogPost.objects(id=self.post1.id)[0].comments, [] ) # Ensure that post1 comments were deleted from the list. - # < 2.6 Incompatible > - # self.assertListEqual(self.post1.comments, []) - self.assertEqual(self.post1.comments, []) + self.assertListEqual(self.post1.comments, []) # Ensure that comments still returned a EmbeddedDocumentList object. - # < 2.6 Incompatible > - # self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) - self.assertTrue(isinstance(self.post1.comments, EmbeddedDocumentList)) + self.assertIsInstance(self.post1.comments, EmbeddedDocumentList) # Ensure that the delete method returned 2 as the number of entries # deleted from the database @@ -4270,21 +4186,15 @@ def test_filtered_delete(self): self.post1.save() # Ensure that only the user2 comment was deleted. - # < 2.6 Incompatible > - # self.assertNotIn( - # comment, self.BlogPost.objects(id=self.post1.id)[0].comments - # ) - self.assertTrue( - comment not in self.BlogPost.objects(id=self.post1.id)[0].comments + self.assertNotIn( + comment, self.BlogPost.objects(id=self.post1.id)[0].comments ) self.assertEqual( len(self.BlogPost.objects(id=self.post1.id)[0].comments), 1 ) # Ensure that the user2 comment no longer exists in the list. - # < 2.6 Incompatible > - # self.assertNotIn(comment, self.post1.comments) - self.assertTrue(comment not in self.post1.comments) + self.assertNotIn(comment, self.post1.comments) self.assertEqual(len(self.post1.comments), 1) # Ensure that the delete method returned 1 as the number of entries diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index 7c5abeac8..b266a5e58 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -1,18 +1,16 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] - import copy import os import unittest import tempfile import gridfs +import six from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db -from mongoengine.python_support import b, StringIO +from mongoengine.python_support import StringIO try: from PIL import Image @@ -49,7 +47,7 @@ class PutFile(Document): PutFile.drop_collection() - text = b('Hello, World!') + text = six.b('Hello, World!') content_type = 'text/plain' putfile = PutFile() @@ -88,8 +86,8 @@ class StreamFile(Document): StreamFile.drop_collection() - text = b('Hello, World!') - more_text = b('Foo Bar') + text = six.b('Hello, World!') + more_text = six.b('Foo Bar') content_type = 'text/plain' streamfile = StreamFile() @@ -123,8 +121,8 @@ class StreamFile(Document): StreamFile.drop_collection() - text = b('Hello, World!') - more_text = b('Foo Bar') + text = six.b('Hello, World!') + more_text = six.b('Foo Bar') content_type = 'text/plain' streamfile = StreamFile() @@ -155,8 +153,8 @@ def test_file_fields_set(self): class SetFile(Document): the_file = FileField() - text = b('Hello, World!') - more_text = b('Foo Bar') + text = six.b('Hello, World!') + more_text = six.b('Foo Bar') SetFile.drop_collection() @@ -185,7 +183,7 @@ class GridDocument(Document): GridDocument.drop_collection() with tempfile.TemporaryFile() as f: - f.write(b("Hello World!")) + f.write(six.b("Hello World!")) f.flush() # Test without default @@ -202,7 +200,7 @@ class GridDocument(Document): self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) # Test with default - doc_d = GridDocument(the_file=b('')) + doc_d = GridDocument(the_file=six.b('')) doc_d.save() doc_e = GridDocument.objects.with_id(doc_d.id) @@ -228,7 +226,7 @@ class TestFile(Document): # First instance test_file = TestFile() test_file.name = "Hello, World!" - test_file.the_file.put(b('Hello, World!')) + test_file.the_file.put(six.b('Hello, World!')) test_file.save() # Second instance @@ -282,7 +280,7 @@ class TestFile(Document): test_file = TestFile() self.assertFalse(bool(test_file.the_file)) - test_file.the_file.put(b('Hello, World!'), content_type='text/plain') + test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain') test_file.save() self.assertTrue(bool(test_file.the_file)) @@ -297,66 +295,66 @@ class TestFile(Document): test_file = TestFile() self.assertFalse(test_file.the_file in [{"test": 1}]) - def test_file_disk_space(self): - """ Test disk space usage when we delete/replace a file """ + def test_file_disk_space(self): + """ Test disk space usage when we delete/replace a file """ class TestFile(Document): the_file = FileField() - - text = b('Hello, World!') + + text = six.b('Hello, World!') content_type = 'text/plain' testfile = TestFile() testfile.the_file.put(text, content_type=content_type, filename="hello") testfile.save() - - # Now check fs.files and fs.chunks + + # Now check fs.files and fs.chunks db = TestFile._get_db() - + files = db.fs.files.find() chunks = db.fs.chunks.find() self.assertEquals(len(list(files)), 1) self.assertEquals(len(list(chunks)), 1) - # Deleting the docoument should delete the files + # Deleting the docoument should delete the files testfile.delete() - + files = db.fs.files.find() chunks = db.fs.chunks.find() self.assertEquals(len(list(files)), 0) self.assertEquals(len(list(chunks)), 0) - - # Test case where we don't store a file in the first place + + # Test case where we don't store a file in the first place testfile = TestFile() testfile.save() - + files = db.fs.files.find() chunks = db.fs.chunks.find() self.assertEquals(len(list(files)), 0) self.assertEquals(len(list(chunks)), 0) - + testfile.delete() - + files = db.fs.files.find() chunks = db.fs.chunks.find() self.assertEquals(len(list(files)), 0) self.assertEquals(len(list(chunks)), 0) - - # Test case where we overwrite the file + + # Test case where we overwrite the file testfile = TestFile() testfile.the_file.put(text, content_type=content_type, filename="hello") testfile.save() - - text = b('Bonjour, World!') + + text = six.b('Bonjour, World!') testfile.the_file.replace(text, content_type=content_type, filename="hello") testfile.save() - + files = db.fs.files.find() chunks = db.fs.chunks.find() self.assertEquals(len(list(files)), 1) self.assertEquals(len(list(chunks)), 1) - + testfile.delete() - + files = db.fs.files.find() chunks = db.fs.chunks.find() self.assertEquals(len(list(files)), 0) @@ -372,14 +370,14 @@ class TestImage(Document): TestImage.drop_collection() with tempfile.TemporaryFile() as f: - f.write(b("Hello World!")) + f.write(six.b("Hello World!")) f.flush() t = TestImage() try: t.image.put(f) self.fail("Should have raised an invalidation error") - except ValidationError, e: + except ValidationError as e: self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f) t = TestImage() @@ -496,7 +494,7 @@ class TestFile(Document): # First instance test_file = TestFile() test_file.name = "Hello, World!" - test_file.the_file.put(b('Hello, World!'), + test_file.the_file.put(six.b('Hello, World!'), name="hello.txt") test_file.save() @@ -504,16 +502,15 @@ class TestFile(Document): self.assertEqual(data.get('name'), 'hello.txt') test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), - b('Hello, World!')) + self.assertEqual(test_file.the_file.read(), six.b('Hello, World!')) test_file = TestFile.objects.first() - test_file.the_file = b('HELLO, WORLD!') + test_file.the_file = six.b('HELLO, WORLD!') test_file.save() test_file = TestFile.objects.first() self.assertEqual(test_file.the_file.read(), - b('HELLO, WORLD!')) + six.b('HELLO, WORLD!')) def test_copyable(self): class PutFile(Document): @@ -521,7 +518,7 @@ class PutFile(Document): PutFile.drop_collection() - text = b('Hello, World!') + text = six.b('Hello, World!') content_type = 'text/plain' putfile = PutFile() diff --git a/tests/fields/geo.py b/tests/fields/geo.py index c3f414812..1c5bccc0b 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -1,7 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] - import unittest from mongoengine import * diff --git a/tests/migration/__init__.py b/tests/migration/__init__.py deleted file mode 100644 index ef62d8760..000000000 --- a/tests/migration/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest - -from convert_to_new_inheritance_model import * -from decimalfield_as_float import * -from referencefield_dbref_to_object_id import * -from turn_off_inheritance import * -from uuidfield_to_binary import * - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/migration/convert_to_new_inheritance_model.py b/tests/migration/convert_to_new_inheritance_model.py deleted file mode 100644 index 89ee9e9d2..000000000 --- a/tests/migration/convert_to_new_inheritance_model.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -import unittest - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField - -__all__ = ('ConvertToNewInheritanceModel', ) - - -class ConvertToNewInheritanceModel(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue - self.db.drop_collection(collection) - - def test_how_to_convert_to_the_new_inheritance_model(self): - """Demonstrates migrating from 0.7 to 0.8 - """ - - # 1. Declaration of the class - class Animal(Document): - name = StringField() - meta = { - 'allow_inheritance': True, - 'indexes': ['name'] - } - - # 2. Remove _types - collection = Animal._get_collection() - collection.update({}, {"$unset": {"_types": 1}}, multi=True) - - # 3. Confirm extra data is removed - count = collection.find({'_types': {"$exists": True}}).count() - self.assertEqual(0, count) - - # 4. Remove indexes - info = collection.index_information() - indexes_to_drop = [key for key, value in info.iteritems() - if '_types' in dict(value['key'])] - for index in indexes_to_drop: - collection.drop_index(index) - - # 5. Recreate indexes - Animal.ensure_indexes() diff --git a/tests/migration/decimalfield_as_float.py b/tests/migration/decimalfield_as_float.py deleted file mode 100644 index 3903c913f..000000000 --- a/tests/migration/decimalfield_as_float.py +++ /dev/null @@ -1,50 +0,0 @@ - # -*- coding: utf-8 -*- -import unittest -import decimal -from decimal import Decimal - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField, DecimalField, ListField - -__all__ = ('ConvertDecimalField', ) - - -class ConvertDecimalField(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def test_how_to_convert_decimal_fields(self): - """Demonstrates migrating from 0.7 to 0.8 - """ - - # 1. Old definition - using dbrefs - class Person(Document): - name = StringField() - money = DecimalField(force_string=True) - monies = ListField(DecimalField(force_string=True)) - - Person.drop_collection() - Person(name="Wilson Jr", money=Decimal("2.50"), - monies=[Decimal("2.10"), Decimal("5.00")]).save() - - # 2. Start the migration by changing the schema - # Change DecimalField - add precision and rounding settings - class Person(Document): - name = StringField() - money = DecimalField(precision=2, rounding=decimal.ROUND_HALF_UP) - monies = ListField(DecimalField(precision=2, - rounding=decimal.ROUND_HALF_UP)) - - # 3. Loop all the objects and mark parent as changed - for p in Person.objects: - p._mark_as_changed('money') - p._mark_as_changed('monies') - p.save() - - # 4. Confirmation of the fix! - wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] - self.assertTrue(isinstance(wilson['money'], float)) - self.assertTrue(all([isinstance(m, float) for m in wilson['monies']])) diff --git a/tests/migration/referencefield_dbref_to_object_id.py b/tests/migration/referencefield_dbref_to_object_id.py deleted file mode 100644 index d3acbe923..000000000 --- a/tests/migration/referencefield_dbref_to_object_id.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -import unittest - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField, ReferenceField, ListField - -__all__ = ('ConvertToObjectIdsModel', ) - - -class ConvertToObjectIdsModel(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def test_how_to_convert_to_object_id_reference_fields(self): - """Demonstrates migrating from 0.7 to 0.8 - """ - - # 1. Old definition - using dbrefs - class Person(Document): - name = StringField() - parent = ReferenceField('self', dbref=True) - friends = ListField(ReferenceField('self', dbref=True)) - - Person.drop_collection() - - p1 = Person(name="Wilson", parent=None).save() - f1 = Person(name="John", parent=None).save() - f2 = Person(name="Paul", parent=None).save() - f3 = Person(name="George", parent=None).save() - f4 = Person(name="Ringo", parent=None).save() - Person(name="Wilson Jr", parent=p1, friends=[f1, f2, f3, f4]).save() - - # 2. Start the migration by changing the schema - # Change ReferenceField as now dbref defaults to False - class Person(Document): - name = StringField() - parent = ReferenceField('self') - friends = ListField(ReferenceField('self')) - - # 3. Loop all the objects and mark parent as changed - for p in Person.objects: - p._mark_as_changed('parent') - p._mark_as_changed('friends') - p.save() - - # 4. Confirmation of the fix! - wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] - self.assertEqual(p1.id, wilson['parent']) - self.assertEqual([f1.id, f2.id, f3.id, f4.id], wilson['friends']) diff --git a/tests/migration/turn_off_inheritance.py b/tests/migration/turn_off_inheritance.py deleted file mode 100644 index ee461a84b..000000000 --- a/tests/migration/turn_off_inheritance.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- -import unittest - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField - -__all__ = ('TurnOffInheritanceTest', ) - - -class TurnOffInheritanceTest(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def tearDown(self): - for collection in self.db.collection_names(): - if 'system.' in collection: - continue - self.db.drop_collection(collection) - - def test_how_to_turn_off_inheritance(self): - """Demonstrates migrating from allow_inheritance = True to False. - """ - - # 1. Old declaration of the class - - class Animal(Document): - name = StringField() - meta = { - 'allow_inheritance': True, - 'indexes': ['name'] - } - - # 2. Turn off inheritance - class Animal(Document): - name = StringField() - meta = { - 'allow_inheritance': False, - 'indexes': ['name'] - } - - # 3. Remove _types and _cls - collection = Animal._get_collection() - collection.update({}, {"$unset": {"_types": 1, "_cls": 1}}, multi=True) - - # 3. Confirm extra data is removed - count = collection.find({"$or": [{'_types': {"$exists": True}}, - {'_cls': {"$exists": True}}]}).count() - assert count == 0 - - # 4. Remove indexes - info = collection.index_information() - indexes_to_drop = [key for key, value in info.iteritems() - if '_types' in dict(value['key']) - or '_cls' in dict(value['key'])] - for index in indexes_to_drop: - collection.drop_index(index) - - # 5. Recreate indexes - Animal.ensure_indexes() diff --git a/tests/migration/uuidfield_to_binary.py b/tests/migration/uuidfield_to_binary.py deleted file mode 100644 index a535e91fa..000000000 --- a/tests/migration/uuidfield_to_binary.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -import unittest -import uuid - -from mongoengine import Document, connect -from mongoengine.connection import get_db -from mongoengine.fields import StringField, UUIDField, ListField - -__all__ = ('ConvertToBinaryUUID', ) - - -class ConvertToBinaryUUID(unittest.TestCase): - - def setUp(self): - connect(db='mongoenginetest') - self.db = get_db() - - def test_how_to_convert_to_binary_uuid_fields(self): - """Demonstrates migrating from 0.7 to 0.8 - """ - - # 1. Old definition - using dbrefs - class Person(Document): - name = StringField() - uuid = UUIDField(binary=False) - uuids = ListField(UUIDField(binary=False)) - - Person.drop_collection() - Person(name="Wilson Jr", uuid=uuid.uuid4(), - uuids=[uuid.uuid4(), uuid.uuid4()]).save() - - # 2. Start the migration by changing the schema - # Change UUIDFIeld as now binary defaults to True - class Person(Document): - name = StringField() - uuid = UUIDField() - uuids = ListField(UUIDField()) - - # 3. Loop all the objects and mark parent as changed - for p in Person.objects: - p._mark_as_changed('uuid') - p._mark_as_changed('uuids') - p.save() - - # 4. Confirmation of the fix! - wilson = Person.objects(name="Wilson Jr").as_pymongo()[0] - self.assertTrue(isinstance(wilson['uuid'], uuid.UUID)) - self.assertTrue(all([isinstance(u, uuid.UUID) for u in wilson['uuids']])) diff --git a/tests/queryset/field_list.py b/tests/queryset/field_list.py index 7d66d2639..76d5f779b 100644 --- a/tests/queryset/field_list.py +++ b/tests/queryset/field_list.py @@ -1,6 +1,3 @@ -import sys -sys.path[0:0] = [""] - import unittest from mongoengine import * @@ -95,7 +92,7 @@ class MyDoc(Document): exclude = ['d', 'e'] only = ['b', 'c'] - qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) + qs = MyDoc.objects.fields(**{i: 1 for i in include}) self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1, 'd': 1, 'e': 1}) qs = qs.only(*only) @@ -103,14 +100,14 @@ class MyDoc(Document): qs = qs.exclude(*exclude) self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) - qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) + qs = MyDoc.objects.fields(**{i: 1 for i in include}) qs = qs.exclude(*exclude) self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) qs = qs.only(*only) self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) qs = MyDoc.objects.exclude(*exclude) - qs = qs.fields(**dict(((i, 1) for i in include))) + qs = qs.fields(**{i: 1 for i in include}) self.assertEqual(qs._loaded_fields.as_dict(), {'a': 1, 'b': 1, 'c': 1}) qs = qs.only(*only) self.assertEqual(qs._loaded_fields.as_dict(), {'b': 1, 'c': 1}) @@ -129,7 +126,7 @@ class MyDoc(Document): exclude = ['d', 'e'] only = ['b', 'c'] - qs = MyDoc.objects.fields(**dict(((i, 1) for i in include))) + qs = MyDoc.objects.fields(**{i: 1 for i in include}) qs = qs.exclude(*exclude) qs = qs.only(*only) qs = qs.fields(slice__b=5) diff --git a/tests/queryset/geo.py b/tests/queryset/geo.py index 9aac44f53..d10c51cd0 100644 --- a/tests/queryset/geo.py +++ b/tests/queryset/geo.py @@ -1,9 +1,5 @@ -import sys - -sys.path[0:0] = [""] - -import unittest from datetime import datetime, timedelta +import unittest from pymongo.errors import OperationFailure from mongoengine import * diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py index e0c7d1fe6..607937f68 100644 --- a/tests/queryset/modify.py +++ b/tests/queryset/modify.py @@ -1,6 +1,3 @@ -import sys -sys.path[0:0] = [""] - import unittest from mongoengine import connect, Document, IntField @@ -99,4 +96,4 @@ def test_modify_with_fields(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 2c00838a6..e4c71de76 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -9,13 +9,13 @@ import pymongo from pymongo.errors import ConfigurationError from pymongo.read_preferences import ReadPreference - +import six from mongoengine import * from mongoengine.connection import get_connection, get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import InvalidQueryError -from mongoengine.python_support import IS_PYMONGO_3, PY3 +from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, QuerySet, QuerySetManager, queryset_manager) @@ -25,7 +25,10 @@ class db_ops_tracker(query_counter): def get_ops(self): - ignore_query = {"ns": {"$ne": "%s.system.indexes" % self.db.name}} + ignore_query = { + 'ns': {'$ne': '%s.system.indexes' % self.db.name}, + 'command.count': {'$ne': 'system.profile'} + } return list(self.db.system.profile.find(ignore_query)) @@ -94,12 +97,12 @@ class BlogPost(Document): author = ReferenceField(self.Person) author2 = GenericReferenceField() - def test_reference(): + # test addressing a field from a reference + with self.assertRaises(InvalidQueryError): list(BlogPost.objects(author__name="test")) - self.assertRaises(InvalidQueryError, test_reference) - - def test_generic_reference(): + # should fail for a generic reference as well + with self.assertRaises(InvalidQueryError): list(BlogPost.objects(author2__name="test")) def test_find(self): @@ -174,7 +177,7 @@ def test_find(self): # Test larger slice __repr__ self.Person.objects.delete() - for i in xrange(55): + for i in range(55): self.Person(name='A%s' % i, age=i).save() self.assertEqual(self.Person.objects.count(), 55) @@ -218,14 +221,15 @@ def test_find_one(self): person = self.Person.objects[1] self.assertEqual(person.name, "User B") - self.assertRaises(IndexError, self.Person.objects.__getitem__, 2) + with self.assertRaises(IndexError): + self.Person.objects[2] # Find a document using just the object id person = self.Person.objects.with_id(person1.id) self.assertEqual(person.name, "User A") - self.assertRaises( - InvalidQueryError, self.Person.objects(name="User A").with_id, person1.id) + with self.assertRaises(InvalidQueryError): + self.Person.objects(name="User A").with_id(person1.id) def test_find_only_one(self): """Ensure that a query using ``get`` returns at most one result. @@ -363,7 +367,8 @@ class A(Document): # test invalid batch size qs = A.objects.batch_size(-1) - self.assertRaises(ValueError, lambda: list(qs)) + with self.assertRaises(ValueError): + list(qs) def test_update_write_concern(self): """Test that passing write_concern works""" @@ -392,18 +397,14 @@ def test_update_update_has_a_value(self): """Test to ensure that update is passed a value to update to""" self.Person.drop_collection() - author = self.Person(name='Test User') - author.save() + author = self.Person.objects.create(name='Test User') - def update_raises(): + with self.assertRaises(OperationError): self.Person.objects(pk=author.pk).update({}) - def update_one_raises(): + with self.assertRaises(OperationError): self.Person.objects(pk=author.pk).update_one({}) - self.assertRaises(OperationError, update_raises) - self.assertRaises(OperationError, update_one_raises) - def test_update_array_position(self): """Ensure that updating by array position works. @@ -431,8 +432,8 @@ class Blog(Document): Blog.objects.create(posts=[post2, post1]) # Update all of the first comments of second posts of all blogs - Blog.objects().update(set__posts__1__comments__0__name="testc") - testc_blogs = Blog.objects(posts__1__comments__0__name="testc") + Blog.objects().update(set__posts__1__comments__0__name='testc') + testc_blogs = Blog.objects(posts__1__comments__0__name='testc') self.assertEqual(testc_blogs.count(), 2) Blog.drop_collection() @@ -441,14 +442,13 @@ class Blog(Document): # Update only the first blog returned by the query Blog.objects().update_one( - set__posts__1__comments__1__name="testc") - testc_blogs = Blog.objects(posts__1__comments__1__name="testc") + set__posts__1__comments__1__name='testc') + testc_blogs = Blog.objects(posts__1__comments__1__name='testc') self.assertEqual(testc_blogs.count(), 1) # Check that using this indexing syntax on a non-list fails - def non_list_indexing(): - Blog.objects().update(set__posts__1__comments__0__name__1="asdf") - self.assertRaises(InvalidQueryError, non_list_indexing) + with self.assertRaises(InvalidQueryError): + Blog.objects().update(set__posts__1__comments__0__name__1='asdf') Blog.drop_collection() @@ -516,15 +516,12 @@ class Simple(Document): self.assertEqual(simple.x, [1, 2, None, 4, 3, 2, 3, 4]) # Nested updates arent supported yet.. - def update_nested(): + with self.assertRaises(OperationError): Simple.drop_collection() Simple(x=[{'test': [1, 2, 3, 4]}]).save() Simple.objects(x__test=2).update(set__x__S__test__S=3) self.assertEqual(simple.x, [1, 2, 3, 4]) - self.assertRaises(OperationError, update_nested) - Simple.drop_collection() - def test_update_using_positional_operator_embedded_document(self): """Ensure that the embedded documents can be updated using the positional operator.""" @@ -617,11 +614,11 @@ class Club(Document): members = DictField() club = Club() - club.members['John'] = dict(gender="M", age=13) + club.members['John'] = {'gender': 'M', 'age': 13} club.save() Club.objects().update( - set__members={"John": dict(gender="F", age=14)}) + set__members={"John": {'gender': 'F', 'age': 14}}) club = Club.objects().first() self.assertEqual(club.members['John']['gender'], "F") @@ -802,7 +799,7 @@ class Blog(Document): post2 = Post(comments=[comment2, comment2]) blogs = [] - for i in xrange(1, 100): + for i in range(1, 100): blogs.append(Blog(title="post %s" % i, posts=[post1, post2])) Blog.objects.insert(blogs, load_bulk=False) @@ -839,30 +836,31 @@ class Blog(Document): self.assertEqual(Blog.objects.count(), 2) - # test handles people trying to upsert - def throw_operation_error(): + # test inserting an existing document (shouldn't be allowed) + with self.assertRaises(OperationError): + blog = Blog.objects.first() + Blog.objects.insert(blog) + + # test inserting a query set + with self.assertRaises(OperationError): blogs = Blog.objects Blog.objects.insert(blogs) - self.assertRaises(OperationError, throw_operation_error) - - # Test can insert new doc + # insert a new doc new_post = Blog(title="code123", id=ObjectId()) Blog.objects.insert(new_post) - # test handles other classes being inserted - def throw_operation_error_wrong_doc(): - class Author(Document): - pass - Blog.objects.insert(Author()) + class Author(Document): + pass - self.assertRaises(OperationError, throw_operation_error_wrong_doc) + # try inserting a different document class + with self.assertRaises(OperationError): + Blog.objects.insert(Author()) - def throw_operation_error_not_a_document(): + # try inserting a non-document + with self.assertRaises(OperationError): Blog.objects.insert("HELLO WORLD") - self.assertRaises(OperationError, throw_operation_error_not_a_document) - Blog.drop_collection() blog1 = Blog(title="code", posts=[post1, post2]) @@ -882,14 +880,13 @@ def throw_operation_error_not_a_document(): blog3 = Blog(title="baz", posts=[post1, post2]) Blog.objects.insert([blog1, blog2]) - def throw_operation_error_not_unique(): + with self.assertRaises(NotUniqueError): Blog.objects.insert([blog2, blog3]) - self.assertRaises(NotUniqueError, throw_operation_error_not_unique) self.assertEqual(Blog.objects.count(), 2) - Blog.objects.insert([blog2, blog3], write_concern={"w": 0, - 'continue_on_error': True}) + Blog.objects.insert([blog2, blog3], + write_concern={"w": 0, 'continue_on_error': True}) self.assertEqual(Blog.objects.count(), 3) def test_get_changed_fields_query_count(self): @@ -1022,7 +1019,7 @@ def __repr__(self): Doc.drop_collection() - for i in xrange(1000): + for i in range(1000): Doc(number=i).save() docs = Doc.objects.order_by('number') @@ -1176,7 +1173,7 @@ def assertSequence(self, qs, expected): qs = list(qs) expected = list(expected) self.assertEqual(len(qs), len(expected)) - for i in xrange(len(qs)): + for i in range(len(qs)): self.assertEqual(qs[i], expected[i]) def test_ordering(self): @@ -1216,7 +1213,8 @@ class BlogPost(Document): self.assertSequence(qs, expected) def test_clear_ordering(self): - """ Ensure that the default ordering can be cleared by calling order_by(). + """Ensure that the default ordering can be cleared by calling + order_by() w/o any arguments. """ class BlogPost(Document): title = StringField() @@ -1232,12 +1230,13 @@ class BlogPost(Document): BlogPost.objects.filter(title='whatever').first() self.assertEqual(len(q.get_ops()), 1) self.assertEqual( - q.get_ops()[0]['query']['$orderby'], {u'published_date': -1}) + q.get_ops()[0]['query']['$orderby'], + {'published_date': -1} + ) with db_ops_tracker() as q: BlogPost.objects.filter(title='whatever').order_by().first() self.assertEqual(len(q.get_ops()), 1) - print q.get_ops()[0]['query'] self.assertFalse('$orderby' in q.get_ops()[0]['query']) def test_no_ordering_for_get(self): @@ -1710,7 +1709,7 @@ class Log(Document): Log.drop_collection() - for i in xrange(10): + for i in range(10): Log().save() Log.objects()[3:5].delete() @@ -1910,12 +1909,10 @@ class Site(Document): Site.objects(id=s.id).update_one(pull__collaborators__user='Esteban') self.assertEqual(Site.objects.first().collaborators, []) - def pull_all(): + with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__user=['Ross']) - self.assertRaises(InvalidQueryError, pull_all) - def test_pull_from_nested_embedded(self): class User(EmbeddedDocument): @@ -1946,12 +1943,10 @@ class Site(Document): pull__collaborators__unhelpful={'name': 'Frank'}) self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) - def pull_all(): + with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=['Ross']) - self.assertRaises(InvalidQueryError, pull_all) - def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): @@ -1980,12 +1975,10 @@ class Site(Document): pull__collaborators__unhelpful={'user': 'Frank'}) self.assertEqual(Site.objects.first().collaborators['unhelpful'], []) - def pull_all(): + with self.assertRaises(InvalidQueryError): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__user=['Ross']) - self.assertRaises(InvalidQueryError, pull_all) - def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2610,7 +2603,7 @@ class BlogPost(Document): BlogPost(hits=2, tags=['music', 'actors']).save() def test_assertions(f): - f = dict((key, int(val)) for key, val in f.items()) + f = {key: int(val) for key, val in f.items()} self.assertEqual( set(['music', 'film', 'actors', 'watch']), set(f.keys())) self.assertEqual(f['music'], 3) @@ -2625,7 +2618,7 @@ def test_assertions(f): # Ensure query is taken into account def test_assertions(f): - f = dict((key, int(val)) for key, val in f.items()) + f = {key: int(val) for key, val in f.items()} self.assertEqual(set(['music', 'actors', 'watch']), set(f.keys())) self.assertEqual(f['music'], 2) self.assertEqual(f['actors'], 1) @@ -2689,7 +2682,7 @@ class Person(Document): doc.save() def test_assertions(f): - f = dict((key, int(val)) for key, val in f.items()) + f = {key: int(val) for key, val in f.items()} self.assertEqual( set(['62-3331-1656', '62-3332-1656']), set(f.keys())) self.assertEqual(f['62-3331-1656'], 2) @@ -2703,7 +2696,7 @@ def test_assertions(f): # Ensure query is taken into account def test_assertions(f): - f = dict((key, int(val)) for key, val in f.items()) + f = {key: int(val) for key, val in f.items()} self.assertEqual(set(['62-3331-1656']), set(f.keys())) self.assertEqual(f['62-3331-1656'], 2) @@ -2810,10 +2803,10 @@ class Test(Document): Test.drop_collection() - for i in xrange(50): + for i in range(50): Test(val=1).save() - for i in xrange(20): + for i in range(20): Test(val=2).save() freqs = Test.objects.item_frequencies( @@ -3603,7 +3596,7 @@ class Post(Document): Post.drop_collection() - for i in xrange(10): + for i in range(10): Post(title="Post %s" % i).save() self.assertEqual(5, Post.objects.limit(5).skip(5).count(with_limit_and_skip=True)) @@ -3618,7 +3611,7 @@ class MyDoc(Document): pass MyDoc.drop_collection() - for i in xrange(0, 10): + for i in range(0, 10): MyDoc().save() self.assertEqual(MyDoc.objects.count(), 10) @@ -3674,7 +3667,7 @@ class Number(Document): Number.drop_collection() - for i in xrange(1, 101): + for i in range(1, 101): t = Number(n=i) t.save() @@ -3821,11 +3814,9 @@ class IntPair(Document): self.assertTrue(a in results) self.assertTrue(c in results) - def invalid_where(): + with self.assertRaises(TypeError): list(IntPair.objects.where(fielda__gte=3)) - self.assertRaises(TypeError, invalid_where) - def test_scalar(self): class Organization(Document): @@ -4081,7 +4072,7 @@ def test_scalar_cursor_behaviour(self): # Test larger slice __repr__ self.Person.objects.delete() - for i in xrange(55): + for i in range(55): self.Person(name='A%s' % i, age=i).save() self.assertEqual(self.Person.objects.scalar('name').count(), 55) @@ -4089,7 +4080,7 @@ def test_scalar_cursor_behaviour(self): "A0", "%s" % self.Person.objects.order_by('name').scalar('name').first()) self.assertEqual( "A0", "%s" % self.Person.objects.scalar('name').order_by('name')[0]) - if PY3: + if six.PY3: self.assertEqual("['A1', 'A2']", "%s" % self.Person.objects.order_by( 'age').scalar('name')[1:3]) self.assertEqual("['A51', 'A52']", "%s" % self.Person.objects.order_by( @@ -4107,7 +4098,7 @@ def test_scalar_cursor_behaviour(self): pks = self.Person.objects.order_by('age').scalar('pk')[1:3] names = self.Person.objects.scalar('name').in_bulk(list(pks)).values() - if PY3: + if six.PY3: expected = "['A1', 'A2']" else: expected = "[u'A1', u'A2']" @@ -4463,7 +4454,7 @@ class Person(Document): name = StringField() Person.drop_collection() - for i in xrange(100): + for i in range(100): Person(name="No: %s" % i).save() with query_counter() as q: @@ -4494,7 +4485,7 @@ class Person(Document): name = StringField() Person.drop_collection() - for i in xrange(100): + for i in range(100): Person(name="No: %s" % i).save() with query_counter() as q: @@ -4538,7 +4529,7 @@ class Noddy(Document): fields = DictField() Noddy.drop_collection() - for i in xrange(100): + for i in range(100): noddy = Noddy() for j in range(20): noddy.fields["key" + str(j)] = "value " + str(j) @@ -4550,7 +4541,9 @@ class Noddy(Document): self.assertEqual(counter, 100) self.assertEqual(len(list(docs)), 100) - self.assertRaises(TypeError, lambda: len(docs)) + + with self.assertRaises(TypeError): + len(docs) with query_counter() as q: self.assertEqual(q, 0) @@ -4739,7 +4732,7 @@ class Person(Document): name = StringField() Person.drop_collection() - for i in xrange(100): + for i in range(100): Person(name="No: %s" % i).save() with query_counter() as q: @@ -4863,10 +4856,10 @@ class Person(Document): ]) def test_delete_count(self): - [self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)] + [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] self.assertEqual(self.Person.objects().delete(), 3) # test ordinary QuerySey delete count - [self.Person(name="User {0}".format(i), age=i * 10).save() for i in xrange(1, 4)] + [self.Person(name="User {0}".format(i), age=i * 10).save() for i in range(1, 4)] self.assertEqual(self.Person.objects().skip(1).delete(), 2) # test Document delete with existing documents @@ -4875,12 +4868,14 @@ def test_delete_count(self): def test_max_time_ms(self): # 778: max_time_ms can get only int or None as input - self.assertRaises(TypeError, self.Person.objects(name="name").max_time_ms, "not a number") + self.assertRaises(TypeError, + self.Person.objects(name="name").max_time_ms, + 'not a number') def test_subclass_field_query(self): class Animal(Document): is_mamal = BooleanField() - meta = dict(allow_inheritance=True) + meta = {'allow_inheritance': True} class Cat(Animal): whiskers_length = FloatField() @@ -4925,7 +4920,7 @@ def test_len_during_iteration(self): class Data(Document): pass - for i in xrange(300): + for i in range(300): Data().save() records = Data.objects.limit(250) @@ -4957,7 +4952,7 @@ def test_iteration_within_iteration(self): class Data(Document): pass - for i in xrange(300): + for i in range(300): Data().save() qs = Data.objects.limit(250) diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 06fe4ea5d..20ab0b3fc 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -238,7 +238,8 @@ class Event(Document): box = [(35.0, -125.0), (40.0, -100.0)] # I *meant* to execute location__within_box=box events = Event.objects(location__within=box) - self.assertRaises(InvalidQueryError, lambda: events.count()) + with self.assertRaises(InvalidQueryError): + events.count() if __name__ == '__main__': diff --git a/tests/queryset/visitor.py b/tests/queryset/visitor.py index 7e5fcf195..6f020e881 100644 --- a/tests/queryset/visitor.py +++ b/tests/queryset/visitor.py @@ -185,7 +185,7 @@ class TestDoc(Document): x = IntField() TestDoc.drop_collection() - for i in xrange(1, 101): + for i in range(1, 101): t = TestDoc(x=i) t.save() @@ -268,14 +268,13 @@ class BlogPost(Document): self.assertEqual(self.Person.objects(Q(age__in=[20, 30])).count(), 3) # Test invalid query objs - def wrong_query_objs(): + with self.assertRaises(InvalidQueryError): self.Person.objects('user1') - def wrong_query_objs_filter(): - self.Person.objects('user1') + # filter should fail, too + with self.assertRaises(InvalidQueryError): + self.Person.objects.filter('user1') - self.assertRaises(InvalidQueryError, wrong_query_objs) - self.assertRaises(InvalidQueryError, wrong_query_objs_filter) def test_q_regex(self): """Ensure that Q objects can be queried using regexes. diff --git a/tests/test_connection.py b/tests/test_connection.py index e64318918..d8f1a79e4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,9 +1,6 @@ -import sys import datetime from pymongo.errors import OperationFailure -sys.path[0:0] = [""] - try: import unittest2 as unittest except ImportError: @@ -19,7 +16,8 @@ ) from mongoengine.python_support import IS_PYMONGO_3 import mongoengine.connection -from mongoengine.connection import get_db, get_connection, ConnectionError +from mongoengine.connection import (MongoEngineConnectionError, get_db, + get_connection) def get_tz_awareness(connection): @@ -159,7 +157,10 @@ def test_connect_uri(self): c.mongoenginetest.add_user("username", "password") if not IS_PYMONGO_3: - self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') + self.assertRaises( + MongoEngineConnectionError, connect, 'testdb_uri_bad', + host='mongodb://test:password@localhost' + ) connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') @@ -229,10 +230,11 @@ def test_connect_uri_with_authsource(self): self.assertRaises(OperationFailure, test_conn.server_info) else: self.assertRaises( - ConnectionError, connect, 'mongoenginetest', alias='test1', + MongoEngineConnectionError, connect, 'mongoenginetest', + alias='test1', host='mongodb://username2:password@localhost/mongoenginetest' ) - self.assertRaises(ConnectionError, get_db, 'test1') + self.assertRaises(MongoEngineConnectionError, get_db, 'test1') # Authentication succeeds with "authSource" connect( @@ -253,7 +255,7 @@ def test_register_connection(self): """ register_connection('testdb', 'mongoenginetest2') - self.assertRaises(ConnectionError, get_connection) + self.assertRaises(MongoEngineConnectionError, get_connection) conn = get_connection('testdb') self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index c201a5fc0..0f6bf815e 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -1,5 +1,3 @@ -import sys -sys.path[0:0] = [""] import unittest from mongoengine import * @@ -79,7 +77,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 51): + for i in range(1, 51): User(name='user %s' % i).save() user = User.objects.first() @@ -117,7 +115,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 51): + for i in range(1, 51): User(name='user %s' % i).save() user = User.objects.first() @@ -195,7 +193,7 @@ def test_query_counter(self): with query_counter() as q: self.assertEqual(0, q) - for i in xrange(1, 51): + for i in range(1, 51): db.test.find({}).count() self.assertEqual(50, q) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index eb40e767f..6830a188f 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -23,7 +23,8 @@ def test_repr(self): self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}') def test_init_fails_on_nonexisting_attrs(self): - self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) + with self.assertRaises(AttributeError): + self.dtype(a=1, b=2, d=3) def test_eq(self): d = self.dtype(a=1, b=1, c=1) @@ -46,14 +47,12 @@ def test_setattr_getattr(self): d = self.dtype() d.a = 1 self.assertEqual(d.a, 1) - self.assertRaises(AttributeError, lambda: d.b) + self.assertRaises(AttributeError, getattr, d, 'b') def test_setattr_raises_on_nonexisting_attr(self): d = self.dtype() - - def _f(): + with self.assertRaises(AttributeError): d.x = 1 - self.assertRaises(AttributeError, _f) def test_setattr_getattr_special(self): d = self.strict_dict_class(["items"]) diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 11bdd6121..7f58a85b0 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] import unittest from bson import DBRef, ObjectId @@ -32,7 +30,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 51): + for i in range(1, 51): user = User(name='user %s' % i) user.save() @@ -90,7 +88,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 51): + for i in range(1, 51): user = User(name='user %s' % i) user.save() @@ -162,7 +160,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 26): + for i in range(1, 26): user = User(name='user %s' % i) user.save() @@ -440,7 +438,7 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): a = UserA(name='User A %s' % i) a.save() @@ -531,7 +529,7 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): a = UserA(name='User A %s' % i) a.save() @@ -614,15 +612,15 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): user = User(name='user %s' % i) user.save() members.append(user) - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() with query_counter() as q: @@ -687,7 +685,7 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): a = UserA(name='User A %s' % i) a.save() @@ -699,9 +697,9 @@ class Group(Document): members += [a, b, c] - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() with query_counter() as q: @@ -783,16 +781,16 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): a = UserA(name='User A %s' % i) a.save() members += [a] - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() with query_counter() as q: @@ -866,7 +864,7 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): a = UserA(name='User A %s' % i) a.save() @@ -878,9 +876,9 @@ class Group(Document): members += [a, b, c] - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() - group = Group(members=dict([(str(u.id), u) for u in members])) + group = Group(members={str(u.id): u for u in members}) group.save() with query_counter() as q: @@ -1103,7 +1101,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 51): + for i in range(1, 51): User(name='user %s' % i).save() Group(name="Test", members=User.objects).save() @@ -1132,7 +1130,7 @@ class Group(Document): User.drop_collection() Group.drop_collection() - for i in xrange(1, 51): + for i in range(1, 51): User(name='user %s' % i).save() Group(name="Test", members=User.objects).save() @@ -1169,7 +1167,7 @@ class Group(Document): Group.drop_collection() members = [] - for i in xrange(1, 51): + for i in range(1, 51): a = UserA(name='User A %s' % i).save() b = UserB(name='User B %s' % i).save() c = UserC(name='User C %s' % i).save() diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index 361cff412..a53f5903e 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,6 +1,3 @@ -import sys - -sys.path[0:0] = [""] import unittest from pymongo import ReadPreference @@ -18,7 +15,7 @@ import mongoengine from mongoengine import * -from mongoengine.connection import ConnectionError +from mongoengine.connection import MongoEngineConnectionError class ConnectionTest(unittest.TestCase): @@ -41,7 +38,7 @@ def test_replicaset_uri_passes_read_preference(self): conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=READ_PREF) - except ConnectionError, e: + except MongoEngineConnectionError as e: return if not isinstance(conn, CONN_CLASS): diff --git a/tests/test_signals.py b/tests/test_signals.py index 23da7cd4a..df687d0ee 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import sys -sys.path[0:0] = [""] import unittest from mongoengine import *