Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e0d68ce
Add interval and day to reserved key words.
guzman-raphael Mar 25, 2021
edd2961
Fix styling and add to MySQL reserved words.
guzman-raphael Mar 25, 2021
2045fc8
Add erd part table parsing test.
guzman-raphael Mar 25, 2021
c21f722
Fix style in list_tables test.
guzman-raphael Mar 25, 2021
7802afb
Add debug statements.
guzman-raphael Mar 25, 2021
2b49fac
Add debug of parts.
guzman-raphael Mar 25, 2021
0008645
Remove usage of _ordered_class_members, ordered_dir since upgrading t…
guzman-raphael Mar 25, 2021
65b4c46
Clean up.
guzman-raphael Mar 25, 2021
84c5c9d
Clean up2.
guzman-raphael Mar 25, 2021
4c2613c
Add topological sort to list_tables and additional reserved MySQL key…
guzman-raphael Mar 25, 2021
202ecd3
Add test for uppercase schema.
guzman-raphael Mar 25, 2021
33e95db
Update test for uppercased schema.
guzman-raphael Mar 25, 2021
2547748
Use root user for uppercase schema test due to permissions.
guzman-raphael Mar 25, 2021
f773991
Allow None to be used in dict restrictions.
guzman-raphael Mar 26, 2021
8c0b8a3
Restrict test to just the created id.
guzman-raphael Mar 26, 2021
0c98e86
Update release log and bump version.
guzman-raphael Mar 26, 2021
ed16dc8
Incorporate feedback and add test for #898.
guzman-raphael Apr 7, 2021
f71fe38
Fix join with aggregations.
guzman-raphael Apr 9, 2021
c7e9331
Update changelog.
guzman-raphael Apr 9, 2021
efeb90b
Update test.
guzman-raphael Apr 9, 2021
614937a
Drop tests for MySQL 5.6 since it has reached EOL.
guzman-raphael Apr 12, 2021
4fdebb5
Update changelog.
guzman-raphael Apr 12, 2021
80db98d
Adjust wording.
guzman-raphael Apr 12, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/development.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
py_ver: ["3.8"]
mysql_ver: ["8.0", "5.7", "5.6"]
mysql_ver: ["8.0", "5.7"]
include:
- py_ver: "3.7"
mysql_ver: "5.7"
Expand Down
22 changes: 16 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
## Release notes

### 0.13.1 -- TBD
* Add `None` as an alias for `IS NULL` comparison in `dict` restrictions (#824) PR #893
* Drop support for MySQL 5.6 since it has reached EOL PR #893
* Bugfix - `schema.list_tables()` is not topologically sorted (#838) PR #893
* Bugfix - Diagram part tables do not show proper class name (#882) PR #893
* Bugfix - Error in complex restrictions (#892) PR #893
* Bugfix - WHERE and GROUP BY clases are dropped on joins with aggregation (#898, #899) PR #893

### 0.13.0 -- Mar 24, 2021
* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484). PR #754
* Re-implement cascading deletes for better performance. PR #839.
* Add table method `.update1` to update a row in the table with new values PR #763
* Python datatypes are now enabled by default in blobs (#761). PR #785
* Re-implement query transpilation into SQL, fixing issues (#386, #449, #450, #484, #558). PR #754
* Re-implement cascading deletes for better performance. PR #839
* Add support for deferred schema activation to allow for greater modularity. (#834) PR #839
* Add query caching mechanism for offline development (#550) PR #839
* Add table method `.update1` to update a row in the table with new values (#867) PR #763, #889
* Python datatypes are now enabled by default in blobs (#761). PR #859
* Added permissive join and restriction operators `@` and `^` (#785) PR #754
* Support DataJoint datatype and connection plugins (#715, #729) PR 730, #735
* Add `dj.key_hash` alias to `dj.hash.key_hash`
* Add `dj.key_hash` alias to `dj.hash.key_hash` (#804) PR #862
* Default enable_python_native_blobs to True
* Bugfix - Regression error on joins with same attribute name (#857) PR #878
* Bugfix - Error when `fetch1('KEY')` when `dj.config['fetch_format']='frame'` set (#876) PR #880, #878
* Bugfix - Error when cascading deletes in tables with many, complex keys (#883, #886) PR #839
* Add deprecation warning for `_update`. PR #889
* Add `purge_query_cache` utility. PR #889
* Add tests for query caching and permissive join and restriction. PR #889
* Drop support for Python 3.5
* Drop support for Python 3.5 (#829) PR #861

### 0.12.9 -- Mar 12, 2021
* Fix bug with fetch1 with `dj.config['fetch_format']="frame"`. (#876) PR #880
Expand Down
45 changes: 30 additions & 15 deletions datajoint/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ def __init__(self, operand):

class AndList(list):
"""
A list of conditions to by applied to a query expression by logical conjunction: the conditions are AND-ed.
All other collections (lists, sets, other entity sets, etc) are applied by logical disjunction (OR).
A list of conditions to by applied to a query expression by logical conjunction: the
conditions are AND-ed. All other collections (lists, sets, other entity sets, etc) are
applied by logical disjunction (OR).

Example:
expr2 = expr & dj.AndList((cond1, cond2, cond3))
Expand All @@ -49,14 +50,16 @@ def assert_join_compatibility(expr1, expr2):
the matching attributes in the two expressions must be in the primary key of one or the
other expression.
Raises an exception if not compatible.

:param expr1: A QueryExpression object
:param expr2: A QueryExpression object
"""
from .expression import QueryExpression, U

for rel in (expr1, expr2):
if not isinstance(rel, (U, QueryExpression)):
raise DataJointError('Object %r is not a QueryExpression and cannot be joined.' % rel)
raise DataJointError(
'Object %r is not a QueryExpression and cannot be joined.' % rel)
if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible
try:
raise DataJointError(
Expand All @@ -70,9 +73,11 @@ def assert_join_compatibility(expr1, expr2):
def make_condition(query_expression, condition, columns):
"""
Translate the input condition into the equivalent SQL condition (a string)

:param query_expression: a dj.QueryExpression object to apply condition
:param condition: any valid restriction object.
:param columns: a set passed by reference to collect all column names used in the condition.
:param columns: a set passed by reference to collect all column names used in the
condition.
:return: an SQL condition string or a boolean value.
"""
from .expression import QueryExpression, Aggregation, U
Expand Down Expand Up @@ -102,12 +107,13 @@ def prep_value(k, v):
# restrict by string
if isinstance(condition, str):
columns.update(extract_column_names(condition))
return template % condition.strip().replace("%", "%%") # escape % in strings, see issue #376
return template % condition.strip().replace("%", "%%") # escape %, see issue #376

# restrict by AndList
if isinstance(condition, AndList):
# omit all conditions that evaluate to True
items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition)
items = [item for item in (make_condition(query_expression, cond, columns)
for cond in condition)
if item is not True]
if any(item is False for item in items):
return negate # if any item is False, the whole thing is False
Expand All @@ -123,18 +129,21 @@ def prep_value(k, v):
if isinstance(condition, bool):
return negate != condition

# restrict by a mapping such as a dict -- convert to an AndList of string equality conditions
# restrict by a mapping/dict -- convert to an AndList of string equality conditions
if isinstance(condition, collections.abc.Mapping):
common_attributes = set(condition).intersection(query_expression.heading.names)
if not common_attributes:
return not negate # no matching attributes -> evaluates to True
columns.update(common_attributes)
return template % ('(' + ') AND ('.join(
'`%s`=%s' % (k, prep_value(k, condition[k])) for k in common_attributes) + ')')
'`%s`%s' % (k, ' IS NULL' if condition[k] is None
else f'={prep_value(k, condition[k])}')
for k in common_attributes) + ')')

# restrict by a numpy record -- convert to an AndList of string equality conditions
if isinstance(condition, numpy.void):
common_attributes = set(condition.dtype.fields).intersection(query_expression.heading.names)
common_attributes = set(condition.dtype.fields).intersection(
query_expression.heading.names)
if not common_attributes:
return not negate # no matching attributes -> evaluate to True
columns.update(common_attributes)
Expand All @@ -154,7 +163,8 @@ def prep_value(k, v):
if isinstance(condition, QueryExpression):
if check_compatibility:
assert_join_compatibility(query_expression, condition)
common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names]
common_attributes = [q for q in condition.heading.names
if q in query_expression.heading.names]
columns.update(common_attributes)
if isinstance(condition, Aggregation):
condition = condition.make_subquery()
Expand All @@ -176,15 +186,17 @@ def prep_value(k, v):
except TypeError:
raise DataJointError('Invalid restriction type %r' % condition)
else:
or_list = [item for item in or_list if item is not False] # ignore all False conditions
if any(item is True for item in or_list): # if any item is True, the whole thing is True
or_list = [item for item in or_list if item is not False] # ignore False conditions
if any(item is True for item in or_list): # if any item is True, entirely True
return not negate
return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate # an empty or list is False
return template % ('(%s)' % ' OR '.join(or_list)) if or_list else negate


def extract_column_names(sql_expression):
"""
extract all presumed column names from an sql expression such as the WHERE clause, for example.
extract all presumed column names from an sql expression such as the WHERE clause,
for example.

:param sql_expression: a string containing an SQL expression
:return: set of extracted column names
This may be MySQL-specific for now.
Expand All @@ -206,5 +218,8 @@ def extract_column_names(sql_expression):
s = re.sub(r"(\b[a-z][a-z_0-9]*)\(", "(", s)
remaining_tokens = set(re.findall(r"\b[a-z][a-z_0-9]*\b", s))
# update result removing reserved words
result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null", "not"})
result.update(remaining_tokens - {"is", "in", "between", "like", "and", "or", "null",
"not", "interval", "second", "minute", "hour", "day",
"month", "week", "year"
})
return result
4 changes: 2 additions & 2 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ def connect(self):
self._conn = client.connect(
init_command=self.init_fun,
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
charset=config['connection.charset'],
**{k: v for k, v in self.conn_info.items()
if k not in ['ssl_input', 'host_input']})
except client.err.InternalError:
self._conn = client.connect(
init_command=self.init_fun,
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
charset=config['connection.charset'],
**{k: v for k, v in self.conn_info.items()
if not(k in ['ssl_input', 'host_input'] or
Expand Down
24 changes: 16 additions & 8 deletions datajoint/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,24 +219,32 @@ def _make_graph(self):
"""
Make the self.graph - a graph object ready for drawing
"""
# mark "distinguished" tables, i.e. those that introduce new primary key attributes
# mark "distinguished" tables, i.e. those that introduce new primary key
# attributes
for name in self.nodes_to_show:
foreign_attributes = set(
attr for p in self.in_edges(name, data=True) for attr in p[2]['attr_map'] if p[2]['primary'])
attr for p in self.in_edges(name, data=True)
for attr in p[2]['attr_map'] if p[2]['primary'])
self.nodes[name]['distinguished'] = (
'primary_key' in self.nodes[name] and foreign_attributes < self.nodes[name]['primary_key'])
'primary_key' in self.nodes[name] and
foreign_attributes < self.nodes[name]['primary_key'])
# include aliased nodes that are sandwiched between two displayed nodes
gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection(
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show))
gaps = set(nx.algorithms.boundary.node_boundary(
self, self.nodes_to_show)).intersection(
nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(),
self.nodes_to_show))
nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit)
# construct subgraph and rename nodes to class names
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n) for n in graph})
nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n)
for n in graph})
# relabel nodes to class names
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
mapping = {node: lookup_class_name(node, self.context) or node
for node in graph.nodes()}
new_names = [mapping.values()]
if len(new_names) > len(set(new_names)):
raise DataJointError('Some classes have identical names. The Diagram cannot be plotted.')
raise DataJointError(
'Some classes have identical names. The Diagram cannot be plotted.')
nx.relabel_nodes(graph, mapping, copy=False)
return graph

Expand Down
26 changes: 16 additions & 10 deletions datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,15 @@ def restriction_attributes(self):
def primary_key(self):
return self.heading.primary_key

_subquery_alias_count = count() # count for alias names used in from_clause
_subquery_alias_count = count() # count for alias names used in the FROM clause

def from_clause(self):
support = ('(' + src.make_sql() + ') as `_s%x`' % next(
self._subquery_alias_count) if isinstance(src, QueryExpression) else src for src in self.support)
support = ('(' + src.make_sql() + ') as `$%x`' % next(
self._subquery_alias_count) if isinstance(src, QueryExpression)
else src for src in self.support)
clause = next(support)
for s, left in zip(support, self._left):
clause += 'NATURAL{left} JOIN {clause}'.format(
clause += ' NATURAL{left} JOIN {clause}'.format(
left=" LEFT" if left else "",
clause=s)
return clause
Expand Down Expand Up @@ -264,8 +265,10 @@ def join(self, other, semantic_check=True, left=False):
(set(self.original_heading.names) & set(other.original_heading.names))
- join_attributes)
# need subquery if any of the join attributes are derived
need_subquery1 = need_subquery1 or any(n in self.heading.new_attributes for n in join_attributes)
need_subquery2 = need_subquery2 or any(n in other.heading.new_attributes for n in join_attributes)
need_subquery1 = (need_subquery1 or isinstance(self, Aggregation) or
any(n in self.heading.new_attributes for n in join_attributes))
need_subquery2 = (need_subquery2 or isinstance(other, Aggregation) or
any(n in other.heading.new_attributes for n in join_attributes))
if need_subquery1:
self = self.make_subquery()
if need_subquery2:
Expand Down Expand Up @@ -721,8 +724,9 @@ def __and__(self, other):

def join(self, other, left=False):
"""
Joining U with a query expression has the effect of promoting the attributes of U to the primary key of
the other query expression.
Joining U with a query expression has the effect of promoting the attributes of U to
the primary key of the other query expression.

:param other: the other query expression to join with.
:param left: ignored. dj.U always acts as if left=False
:return: a copy of the other query expression with the primary key extended.
Expand All @@ -733,12 +737,14 @@ def join(self, other, left=False):
raise DataJointError('Set U can only be joined with a QueryExpression.')
try:
raise DataJointError(
'Attribute `%s` not found' % next(k for k in self.primary_key if k not in other.heading.names))
'Attribute `%s` not found' % next(k for k in self.primary_key
if k not in other.heading.names))
except StopIteration:
pass # all ok
result = copy.copy(other)
result._heading = result.heading.set_primary_key(
other.primary_key + [k for k in self.primary_key if k not in other.primary_key])
other.primary_key + [k for k in self.primary_key
if k not in other.primary_key])
return result

def __mul__(self, other):
Expand Down
10 changes: 4 additions & 6 deletions datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def ordered_dir(class_):
"""
attr_list = list()
for c in reversed(class_.mro()):
attr_list.extend(e for e in (
c._ordered_class_members if hasattr(c, '_ordered_class_members') else c.__dict__)
if e not in attr_list)
attr_list.extend(e for e in c.__dict__ if e not in attr_list)
return attr_list


Expand Down Expand Up @@ -374,9 +372,9 @@ def list_tables(self):
as ~logs and ~job
:return: A list of table names from the database schema.
"""
return [table_name for (table_name,) in self.connection.query("""
SELECT table_name FROM information_schema.tables
WHERE table_schema = %s and table_name NOT LIKE '~%%'""", args=(self.database,))]
return [t for d, t in (full_t.replace('`', '').split('.')
for full_t in Diagram(self).topological_sort())
if d == self.database]


class VirtualModule(types.ModuleType):
Expand Down
14 changes: 9 additions & 5 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,14 +720,18 @@ def lookup_class_name(name, context, depth=3):
if member.full_table_name == name: # found it!
return '.'.join([node['context_name'], member_name]).lstrip('.')
try: # look for part tables
parts = member._ordered_class_members
parts = member.__dict__
except AttributeError:
pass # not a UserTable -- cannot have part tables.
else:
for part in (getattr(member, p) for p in parts if p[0].isupper() and hasattr(member, p)):
if inspect.isclass(part) and issubclass(part, Table) and part.full_table_name == name:
return '.'.join([node['context_name'], member_name, part.__name__]).lstrip('.')
elif node['depth'] > 0 and inspect.ismodule(member) and member.__name__ != 'datajoint':
for part in (getattr(member, p) for p in parts
if p[0].isupper() and hasattr(member, p)):
if inspect.isclass(part) and issubclass(part, Table) and \
part.full_table_name == name:
return '.'.join([node['context_name'],
member_name, part.__name__]).lstrip('.')
elif node['depth'] > 0 and inspect.ismodule(member) and \
member.__name__ != 'datajoint':
try:
nodes.append(
dict(context=dict(inspect.getmembers(member)),
Expand Down
2 changes: 1 addition & 1 deletion datajoint/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.13.0"
__version__ = "0.13.1"

assert len(__version__) <= 10 # The log table limits version to the 10 characters
Loading