diff --git a/.github/workflows/development.yaml b/.github/workflows/development.yaml index 9785580b3..c4c2b3475 100644 --- a/.github/workflows/development.yaml +++ b/.github/workflows/development.yaml @@ -15,7 +15,7 @@ jobs: strategy: matrix: py_ver: ["3.8"] - mysql_ver: ["8.0", "5.7", "5.6"] + mysql_ver: ["8.0", "5.7"] include: - py_ver: "3.7" mysql_ver: "5.7" diff --git a/CHANGELOG.md b/CHANGELOG.md index 6625d1d5b..d59027904 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,23 @@ ## Release notes +### 0.13.1 -- TBD +* Add `None` as an alias for `IS NULL` comparison in `dict` restrictions (#824) PR #893 +* Drop support for MySQL 5.6 since it has reached EOL PR #893 +* Bugfix - `schema.list_tables()` is not topologically sorted (#838) PR #893 +* Bugfix - Diagram part tables do not show proper class name (#882) PR #893 +* Bugfix - Error in complex restrictions (#892) PR #893 +* Bugfix - WHERE and GROUP BY clauses are dropped on joins with aggregation (#898, #899) PR #893 + ### 0.13.0 -- Mar 24, 2021 -* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484). PR #754 -* Re-implement cascading deletes for better performance. PR #839. -* Add table method `.update1` to update a row in the table with new values PR #763 -* Python datatypes are now enabled by default in blobs (#761). PR #785 +* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484, #558). PR #754 +* Re-implement cascading deletes for better performance. PR #839 +* Add support for deferred schema activation to allow for greater modularity. (#834) PR #839 +* Add query caching mechanism for offline development (#550) PR #839 +* Add table method `.update1` to update a row in the table with new values (#867) PR #763, #889 +* Python datatypes are now enabled by default in blobs (#761). 
PR #859 * Added permissive join and restriction operators `@` and `^` (#785) PR #754 * Support DataJoint datatype and connection plugins (#715, #729) PR 730, #735 -* Add `dj.key_hash` alias to `dj.hash.key_hash` +* Add `dj.key_hash` alias to `dj.hash.key_hash` (#804) PR #862 * Default enable_python_native_blobs to True * Bugfix - Regression error on joins with same attribute name (#857) PR #878 * Bugfix - Error when `fetch1('KEY')` when `dj.config['fetch_format']='frame'` set (#876) PR #880, #878 @@ -15,7 +25,7 @@ * Add deprecation warning for `_update`. PR #889 * Add `purge_query_cache` utility. PR #889 * Add tests for query caching and permissive join and restriction. PR #889 -* Drop support for Python 3.5 +* Drop support for Python 3.5 (#829) PR #861 ### 0.12.9 -- Mar 12, 2021 * Fix bug with fetch1 with `dj.config['fetch_format']="frame"`. (#876) PR #880 diff --git a/datajoint/condition.py b/datajoint/condition.py index 7d921be4f..fed138cf1 100644 --- a/datajoint/condition.py +++ b/datajoint/condition.py @@ -21,8 +21,9 @@ def __init__(self, operand): class AndList(list): """ - A list of conditions to by applied to a query expression by logical conjunction: the conditions are AND-ed. - All other collections (lists, sets, other entity sets, etc) are applied by logical disjunction (OR). + A list of conditions to be applied to a query expression by logical conjunction: the + conditions are AND-ed. All other collections (lists, sets, other entity sets, etc) are + applied by logical disjunction (OR). + Example: expr2 = expr & dj.AndList((cond1, cond2, cond3)) @@ -49,6 +50,7 @@ def assert_join_compatibility(expr1, expr2): the matching attributes in the two expressions must be in the primary key of one or the other expression. Raises an exception if not compatible. 
+ :param expr1: A QueryExpression object :param expr2: A QueryExpression object """ @@ -56,7 +58,8 @@ def assert_join_compatibility(expr1, expr2): for rel in (expr1, expr2): if not isinstance(rel, (U, QueryExpression)): - raise DataJointError('Object %r is not a QueryExpression and cannot be joined.' % rel) + raise DataJointError( + 'Object %r is not a QueryExpression and cannot be joined.' % rel) if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible try: raise DataJointError( @@ -70,9 +73,11 @@ def assert_join_compatibility(expr1, expr2): def make_condition(query_expression, condition, columns): """ Translate the input condition into the equivalent SQL condition (a string) + :param query_expression: a dj.QueryExpression object to apply condition :param condition: any valid restriction object. - :param columns: a set passed by reference to collect all column names used in the condition. + :param columns: a set passed by reference to collect all column names used in the + condition. :return: an SQL condition string or a boolean value. 
""" from .expression import QueryExpression, Aggregation, U @@ -102,12 +107,13 @@ def prep_value(k, v): # restrict by string if isinstance(condition, str): columns.update(extract_column_names(condition)) - return template % condition.strip().replace("%", "%%") # escape % in strings, see issue #376 + return template % condition.strip().replace("%", "%%") # escape %, see issue #376 # restrict by AndList if isinstance(condition, AndList): # omit all conditions that evaluate to True - items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition) + items = [item for item in (make_condition(query_expression, cond, columns) + for cond in condition) if item is not True] if any(item is False for item in items): return negate # if any item is False, the whole thing is False @@ -123,18 +129,21 @@ def prep_value(k, v): if isinstance(condition, bool): return negate != condition - # restrict by a mapping such as a dict -- convert to an AndList of string equality conditions + # restrict by a mapping/dict -- convert to an AndList of string equality conditions if isinstance(condition, collections.abc.Mapping): common_attributes = set(condition).intersection(query_expression.heading.names) if not common_attributes: return not negate # no matching attributes -> evaluates to True columns.update(common_attributes) return template % ('(' + ') AND ('.join( - '`%s`=%s' % (k, prep_value(k, condition[k])) for k in common_attributes) + ')') + '`%s`%s' % (k, ' IS NULL' if condition[k] is None + else f'={prep_value(k, condition[k])}') + for k in common_attributes) + ')') # restrict by a numpy record -- convert to an AndList of string equality conditions if isinstance(condition, numpy.void): - common_attributes = set(condition.dtype.fields).intersection(query_expression.heading.names) + common_attributes = set(condition.dtype.fields).intersection( + query_expression.heading.names) if not common_attributes: return not negate # no matching attributes -> 
evaluate to True columns.update(common_attributes) @@ -154,7 +163,8 @@ def prep_value(k, v): if isinstance(condition, QueryExpression): if check_compatibility: assert_join_compatibility(query_expression, condition) - common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names] + common_attributes = [q for q in condition.heading.names + if q in query_expression.heading.names] columns.update(common_attributes) if isinstance(condition, Aggregation): condition = condition.make_subquery() @@ -176,15 +186,17 @@ def prep_value(k, v): except TypeError: raise DataJointError('Invalid restriction type %r' % condition) else: - or_list = [item for item in or_list if item is not False] # ignore all False conditions - if any(item is True for item in or_list): # if any item is True, the whole thing is True + or_list = [item for item in or_list if item is not False] # ignore False conditions + if any(item is True for item in or_list): # if any item is True, entirely True return not negate - return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate # an empty or list is False + return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate def extract_column_names(sql_expression): """ - extract all presumed column names from an sql expression such as the WHERE clause, for example. + extract all presumed column names from an sql expression such as the WHERE clause, + for example. + :param sql_expression: a string containing an SQL expression :return: set of extracted column names This may be MySQL-specific for now. 
@@ -206,5 +218,8 @@ def extract_column_names(sql_expression): s = re.sub(r"(\b[a-z][a-z_0-9]*)\(", "(", s) remaining_tokens = set(re.findall(r"\b[a-z][a-z_0-9]*\b", s)) # update result removing reserved words - result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null", "not"}) + result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null", + "not", "interval", "second", "minute", "hour", "day", + "month", "week", "year" + }) return result diff --git a/datajoint/connection.py b/datajoint/connection.py index 9db3dcb77..c82ca5b3e 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -203,7 +203,7 @@ def connect(self): self._conn = client.connect( init_command=self.init_fun, sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION", + "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", charset=config['connection.charset'], **{k: v for k, v in self.conn_info.items() if k not in ['ssl_input', 'host_input']}) @@ -211,7 +211,7 @@ def connect(self): self._conn = client.connect( init_command=self.init_fun, sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION", + "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", charset=config['connection.charset'], **{k: v for k, v in self.conn_info.items() if not(k in ['ssl_input', 'host_input'] or diff --git a/datajoint/diagram.py b/datajoint/diagram.py index dd48e7b17..c4f823035 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -219,24 +219,32 @@ def _make_graph(self): """ Make the self.graph - a graph object ready for drawing """ - # mark "distinguished" tables, i.e. those that introduce new primary key attributes + # mark "distinguished" tables, i.e. 
those that introduce new primary key + # attributes for name in self.nodes_to_show: foreign_attributes = set( - attr for p in self.in_edges(name, data=True) for attr in p[2]['attr_map'] if p[2]['primary']) + attr for p in self.in_edges(name, data=True) + for attr in p[2]['attr_map'] if p[2]['primary']) self.nodes[name]['distinguished'] = ( - 'primary_key' in self.nodes[name] and foreign_attributes < self.nodes[name]['primary_key']) + 'primary_key' in self.nodes[name] and + foreign_attributes < self.nodes[name]['primary_key']) # include aliased nodes that are sandwiched between two displayed nodes - gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection( - nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show)) + gaps = set(nx.algorithms.boundary.node_boundary( + self, self.nodes_to_show)).intersection( + nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), + self.nodes_to_show)) nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit) # construct subgraph and rename nodes to class names graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) - nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n) for n in graph}) + nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n) + for n in graph}) # relabel nodes to class names - mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()} + mapping = {node: lookup_class_name(node, self.context) or node + for node in graph.nodes()} new_names = [mapping.values()] if len(new_names) > len(set(new_names)): - raise DataJointError('Some classes have identical names. The Diagram cannot be plotted.') + raise DataJointError( + 'Some classes have identical names. 
The Diagram cannot be plotted.') nx.relabel_nodes(graph, mapping, copy=False) return graph diff --git a/datajoint/expression.py b/datajoint/expression.py index 6d07784f6..507af14f3 100644 --- a/datajoint/expression.py +++ b/datajoint/expression.py @@ -84,14 +84,15 @@ def restriction_attributes(self): def primary_key(self): return self.heading.primary_key - _subquery_alias_count = count() # count for alias names used in from_clause + _subquery_alias_count = count() # count for alias names used in the FROM clause def from_clause(self): - support = ('(' + src.make_sql() + ') as `_s%x`' % next( - self._subquery_alias_count) if isinstance(src, QueryExpression) else src for src in self.support) + support = ('(' + src.make_sql() + ') as `$%x`' % next( + self._subquery_alias_count) if isinstance(src, QueryExpression) + else src for src in self.support) clause = next(support) for s, left in zip(support, self._left): - clause += 'NATURAL{left} JOIN {clause}'.format( + clause += ' NATURAL{left} JOIN {clause}'.format( left=" LEFT" if left else "", clause=s) return clause @@ -264,8 +265,10 @@ def join(self, other, semantic_check=True, left=False): (set(self.original_heading.names) & set(other.original_heading.names)) - join_attributes) # need subquery if any of the join attributes are derived - need_subquery1 = need_subquery1 or any(n in self.heading.new_attributes for n in join_attributes) - need_subquery2 = need_subquery2 or any(n in other.heading.new_attributes for n in join_attributes) + need_subquery1 = (need_subquery1 or isinstance(self, Aggregation) or + any(n in self.heading.new_attributes for n in join_attributes)) + need_subquery2 = (need_subquery2 or isinstance(other, Aggregation) or + any(n in other.heading.new_attributes for n in join_attributes)) if need_subquery1: self = self.make_subquery() if need_subquery2: @@ -721,8 +724,9 @@ def __and__(self, other): def join(self, other, left=False): """ - Joining U with a query expression has the effect of promoting the 
attributes of U to the primary key of - the other query expression. + Joining U with a query expression has the effect of promoting the attributes of U to + the primary key of the other query expression. + :param other: the other query expression to join with. :param left: ignored. dj.U always acts as if left=False :return: a copy of the other query expression with the primary key extended. @@ -733,12 +737,14 @@ def join(self, other, left=False): raise DataJointError('Set U can only be joined with a QueryExpression.') try: raise DataJointError( - 'Attribute `%s` not found' % next(k for k in self.primary_key if k not in other.heading.names)) + 'Attribute `%s` not found' % next(k for k in self.primary_key + if k not in other.heading.names)) except StopIteration: pass # all ok result = copy.copy(other) result._heading = result.heading.set_primary_key( - other.primary_key + [k for k in self.primary_key if k not in other.primary_key]) + other.primary_key + [k for k in self.primary_key + if k not in other.primary_key]) return result def __mul__(self, other): diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 2c050a915..d108caef2 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -29,9 +29,7 @@ def ordered_dir(class_): """ attr_list = list() for c in reversed(class_.mro()): - attr_list.extend(e for e in ( - c._ordered_class_members if hasattr(c, '_ordered_class_members') else c.__dict__) - if e not in attr_list) + attr_list.extend(e for e in c.__dict__ if e not in attr_list) return attr_list @@ -374,9 +372,9 @@ def list_tables(self): as ~logs and ~job :return: A list of table names from the database schema. 
""" - return [table_name for (table_name,) in self.connection.query(""" - SELECT table_name FROM information_schema.tables - WHERE table_schema = %s and table_name NOT LIKE '~%%'""", args=(self.database,))] + return [t for d, t in (full_t.replace('`', '').split('.') + for full_t in Diagram(self).topological_sort()) + if d == self.database] class VirtualModule(types.ModuleType): diff --git a/datajoint/table.py b/datajoint/table.py index d79c07a75..9a322aafb 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -720,14 +720,18 @@ def lookup_class_name(name, context, depth=3): if member.full_table_name == name: # found it! return '.'.join([node['context_name'], member_name]).lstrip('.') try: # look for part tables - parts = member._ordered_class_members + parts = member.__dict__ except AttributeError: pass # not a UserTable -- cannot have part tables. else: - for part in (getattr(member, p) for p in parts if p[0].isupper() and hasattr(member, p)): - if inspect.isclass(part) and issubclass(part, Table) and part.full_table_name == name: - return '.'.join([node['context_name'], member_name, part.__name__]).lstrip('.') - elif node['depth'] > 0 and inspect.ismodule(member) and member.__name__ != 'datajoint': + for part in (getattr(member, p) for p in parts + if p[0].isupper() and hasattr(member, p)): + if inspect.isclass(part) and issubclass(part, Table) and \ + part.full_table_name == name: + return '.'.join([node['context_name'], + member_name, part.__name__]).lstrip('.') + elif node['depth'] > 0 and inspect.ismodule(member) and \ + member.__name__ != 'datajoint': try: nodes.append( dict(context=dict(inspect.getmembers(member)), diff --git a/datajoint/version.py b/datajoint/version.py index a7571b6c4..403e38347 100644 --- a/datajoint/version.py +++ b/datajoint/version.py @@ -1,3 +1,3 @@ -__version__ = "0.13.0" +__version__ = "0.13.1" assert len(__version__) <= 10 # The log table limits version to the 10 characters diff --git a/docs-parts/intro/Releases_lang1.rst 
b/docs-parts/intro/Releases_lang1.rst index 3dc72f2ab..ca6decaff 100644 --- a/docs-parts/intro/Releases_lang1.rst +++ b/docs-parts/intro/Releases_lang1.rst @@ -1,12 +1,23 @@ +0.13.1 -- TBD +---------------------- +* Add `None` as an alias for `IS NULL` comparison in `dict` restrictions (#824) PR #893 +* Drop support for MySQL 5.6 since it has reached EOL PR #893 +* Bugfix - `schema.list_tables()` is not topologically sorted (#838) PR #893 +* Bugfix - Diagram part tables do not show proper class name (#882) PR #893 +* Bugfix - Error in complex restrictions (#892) PR #893 +* Bugfix - WHERE and GROUP BY clauses are dropped on joins with aggregation (#898, #899) PR #893 + 0.13.0 -- Mar 24, 2021 ---------------------- -* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484). PR #754 -* Re-implement cascading deletes for better performance. PR #839. -* Add table method `.update1` to update a row in the table with new values PR #763 -* Python datatypes are now enabled by default in blobs (#761). PR #785 +* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484, #558). PR #754 +* Re-implement cascading deletes for better performance. PR #839 +* Add support for deferred schema activation to allow for greater modularity. (#834) PR #839 +* Add query caching mechanism for offline development (#550) PR #839 +* Add table method `.update1` to update a row in the table with new values (#867) PR #763, #889 +* Python datatypes are now enabled by default in blobs (#761). 
PR #859 * Added permissive join and restriction operators `@` and `^` (#785) PR #754 * Support DataJoint datatype and connection plugins (#715, #729) PR 730, #735 -* Add `dj.key_hash` alias to `dj.hash.key_hash` +* Add `dj.key_hash` alias to `dj.hash.key_hash` (#804) PR #862 * Default enable_python_native_blobs to True * Bugfix - Regression error on joins with same attribute name (#857) PR #878 * Bugfix - Error when `fetch1('KEY')` when `dj.config['fetch_format']='frame'` set (#876) PR #880, #878 @@ -14,7 +25,7 @@ * Add deprecation warning for `_update`. PR #889 * Add `purge_query_cache` utility. PR #889 * Add tests for query caching and permissive join and restriction. PR #889 -* Drop support for Python 3.5 +* Drop support for Python 3.5 (#829) PR #861 0.12.9 -- Mar 12, 2021 ---------------------- diff --git a/tests/__init__.py b/tests/__init__.py index 1a48a3a91..6b802e332 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -101,7 +101,7 @@ def setup_package(): conn_root.query( "GRANT SELECT ON `djtest%%`.* TO 'djssl'@'%%';") else: - # grant permissions. For mysql5.6/5.7 this also automatically creates user + # grant permissions. 
For MySQL 5.7 this also automatically creates user # if not exists conn_root.query(""" GRANT ALL PRIVILEGES ON `djtest%%`.* TO 'datajoint'@'%%' diff --git a/tests/schema.py b/tests/schema.py index 1fd187637..a0b336a1c 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -379,3 +379,63 @@ class ComplexChild(dj.Lookup): definition = '\n'.join(['-> ComplexParent'] + ['child_id_{}: int'.format(i+1) for i in range(1)]) contents = [tuple(i for i in range(9))] + + +@schema +class SubjectA(dj.Lookup): + definition = """ + subject_id: varchar(32) + --- + dob : date + sex : enum('M', 'F', 'U') + """ + contents = [ + ('mouse1', '2020-09-01', 'M'), + ('mouse2', '2020-03-19', 'F'), + ('mouse3', '2020-08-23', 'F') + ] + + +@schema +class SessionA(dj.Lookup): + definition = """ + -> SubjectA + session_start_time: datetime + --- + session_dir='' : varchar(32) + """ + contents = [ + ('mouse1', '2020-12-01 12:32:34', ''), + ('mouse1', '2020-12-02 12:32:34', ''), + ('mouse1', '2020-12-03 12:32:34', ''), + ('mouse1', '2020-12-04 12:32:34', '') + ] + + +@schema +class SessionStatusA(dj.Lookup): + definition = """ + -> SessionA + --- + status: enum('in_training', 'trained_1a', 'trained_1b', 'ready4ephys') + """ + contents = [ + ('mouse1', '2020-12-01 12:32:34', 'in_training'), + ('mouse1', '2020-12-02 12:32:34', 'trained_1a'), + ('mouse1', '2020-12-03 12:32:34', 'trained_1b'), + ('mouse1', '2020-12-04 12:32:34', 'ready4ephys'), + ] + + +@schema +class SessionDateA(dj.Lookup): + definition = """ + -> SubjectA + session_date: date + """ + contents = [ + ('mouse1', '2020-12-01'), + ('mouse1', '2020-12-02'), + ('mouse1', '2020-12-03'), + ('mouse1', '2020-12-04') + ] diff --git a/tests/schema_simple.py b/tests/schema_simple.py index c7aebaa45..c4ec45e00 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -7,6 +7,7 @@ from . 
import PREFIX, CONN_INFO import numpy as np +from datetime import date, timedelta schema = dj.Schema(PREFIX + '_relational', locals(), connection=dj.conn(**CONN_INFO)) @@ -195,3 +196,22 @@ class ReservedWord(dj.Manual): int : int select : varchar(25) """ + + +@schema +class OutfitLaunch(dj.Lookup): + definition = """ + # Monthly released designer outfits + release_id: int + --- + day: date + """ + contents = [(0, date.today() - timedelta(days=15))] + + class OutfitPiece(dj.Part, dj.Lookup): + definition = """ + # Outfit piece associated with outfit + -> OutfitLaunch + piece: varchar(20) + """ + contents = [(0, 'jeans'), (0, 'sneakers'), (0, 'polo')] diff --git a/tests/test_erd.py b/tests/test_erd.py index 0939ca254..6c4ae24b7 100644 --- a/tests/test_erd.py +++ b/tests/test_erd.py @@ -1,6 +1,6 @@ from nose.tools import assert_false, assert_true import datajoint as dj -from .schema_simple import A, B, D, E, L, schema +from .schema_simple import A, B, D, E, L, schema, OutfitLaunch from . import schema_advanced namespace = locals() @@ -64,5 +64,10 @@ def test_make_image(): img = erd.make_image() assert_true(img.ndim == 3 and img.shape[2] in (3, 4)) - - + @staticmethod + def test_part_table_parsing(): + # https://github.com/datajoint/datajoint-python/issues/882 + erd = dj.Di(schema) + graph = erd._make_graph() + assert 'OutfitLaunch' in graph.nodes() + assert 'OutfitLaunch.OutfitPiece' in graph.nodes() diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index f37dafb31..108bf895b 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -4,11 +4,14 @@ import datetime import numpy as np -from nose.tools import assert_equal, assert_false, assert_true, raises, assert_set_equal, assert_list_equal +from nose.tools import (assert_equal, assert_false, assert_true, raises, assert_set_equal, + assert_list_equal) import datajoint as dj -from .schema_simple import A, B, D, E, F, L, DataA, DataB, TTestUpdate, IJ, JI, 
ReservedWord -from .schema import Experiment, TTest3, Trial, Ephys, Child, Parent +from .schema_simple import (A, B, D, E, F, L, DataA, DataB, TTestUpdate, IJ, JI, + ReservedWord, OutfitLaunch) +from .schema import (Experiment, TTest3, Trial, Ephys, Child, Parent, SubjectA, SessionA, + SessionStatusA, SessionDateA) def setup(): @@ -459,3 +462,46 @@ def test_permissive_join_basic(): def test_permissive_restriction_basic(): """Verify join compatibility check is skipped for restriction""" Child ^ Parent + + @staticmethod + def test_complex_date_restriction(): + # https://github.com/datajoint/datajoint-python/issues/892 + """Test a complex date restriction""" + q = OutfitLaunch & 'day between curdate() - interval 30 day and curdate()' + assert len(q) == 1 + q = OutfitLaunch & 'day between curdate() - interval 4 week and curdate()' + assert len(q) == 1 + q = OutfitLaunch & 'day between curdate() - interval 1 month and curdate()' + assert len(q) == 1 + q = OutfitLaunch & 'day between curdate() - interval 1 year and curdate()' + assert len(q) == 1 + q = OutfitLaunch & '`day` between curdate() - interval 30 day and curdate()' + assert len(q) == 1 + q.delete() + + @staticmethod + def test_null_dict_restriction(): + # https://github.com/datajoint/datajoint-python/issues/824 + """Test a restriction for null using dict""" + F.insert([dict(id=5)]) + q = F & dj.AndList([dict(id=5), 'date is NULL']) + assert len(q) == 1 + q = F & dict(id=5, date=None) + assert len(q) == 1 + + @staticmethod + def test_joins_with_aggregation(): + # https://github.com/datajoint/datajoint-python/issues/898 + # https://github.com/datajoint/datajoint-python/issues/899 + subjects = SubjectA.aggr( + SessionStatusA & 'status="trained_1a" or status="trained_1b"', + date_trained='min(date(session_start_time))') + assert len(SessionDateA * subjects) == 4 + assert len(subjects * SessionDateA) == 4 + + subj_query = SubjectA.aggr( + SessionA * SessionStatusA & 'status="trained_1a" or status="trained_1b"', + 
date_trained='min(date(session_start_time))') + session_dates = ((SessionDateA * (subj_query & 'date_trained<"2020-12-21"')) & + 'session_date Schema_A.Subject + id: smallint + """ + + schema2.drop() + schema1.drop()