Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ def _diff_segments(
):
...

def _resolve_key_range(self, key_range_res, usr_key_range):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extracted to a function since this logic is needed twice.

key_range_res = list(key_range_res)
if usr_key_range[0] is not None:
key_range_res[0] = usr_key_range[0]
if usr_key_range[1] is not None:
key_range_res[1] = usr_key_range[1]
return tuple(key_range_res)

def _bisect_and_diff_tables(self, table1, table2, info_tree):
if len(table1.key_columns) > 1:
raise NotImplementedError("Composite key not supported yet!")
Expand All @@ -290,11 +298,15 @@ def _bisect_and_diff_tables(self, table1, table2, info_tree):
if key_type.python_type is not key_type2.python_type:
raise TypeError(f"Incompatible key types: {key_type} and {key_type2}")

# Query min/max values
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])
usr_key_range = (table1.min_key, table1.max_key)
if all(k is not None for k in [table1.min_key, table1.max_key, table2.min_key, table2.max_key]):
key_ranges = (kr for kr in [(table1.min_key, table1.max_key), (table2.min_key, table2.max_key)])
else:
# Query min/max values
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])

# Start with the first completed value, so we don't waste time waiting
min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges))
min_key1, max_key1 = self._parse_key_range_result(key_type, self._resolve_key_range(next(key_ranges), usr_key_range))

table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]

Expand All @@ -308,7 +320,7 @@ def _bisect_and_diff_tables(self, table1, table2, info_tree):
ti.submit(self._bisect_and_diff_segments, ti, table1, table2, info_tree)

# Now we check for the second min-max, to diff the portions we "missed".
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))
min_key2, max_key2 = self._parse_key_range_result(key_type, self._resolve_key_range(next(key_ranges), usr_key_range))

if min_key2 < min_key1:
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
Expand Down
15 changes: 10 additions & 5 deletions data_diff/table_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,15 @@ def _make_update_range(self):
def source_table(self):
return table(*self.table_path, schema=self._schema)

def make_select(self):
return self.source_table.where(
*self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP
)
def make_select(self, include_key_range=True):
    """Build the filtered selection over ``self.source_table``.

    The conditions from ``self._make_update_range()`` are always applied,
    followed by ``Code(self._where())`` when ``self.where`` is truthy
    (``SKIP`` is passed in its place otherwise).

    include_key_range: when true (the default), the conditions from
        ``self._make_key_range()`` are applied first. Passing False omits
        them; ``query_key_range`` calls it this way.
    """
    source = self.source_table
    # Only the key-range conditions differ between the two modes, so pick
    # them up front and issue a single where() call.
    key_conditions = self._make_key_range() if include_key_range else ()
    return source.where(
        *key_conditions,
        *self._make_update_range(),
        Code(self._where()) if self.where else SKIP,
    )

def get_values(self) -> list:
"Download all the relevant values of the segment from the database"
Expand Down Expand Up @@ -187,7 +192,7 @@ def query_key_range(self) -> Tuple[int, int]:
"""Query database for minimum and maximum key. This is used for setting the initial bounds."""
# Normalizes the result (needed for UUIDs) after the min/max computation
(k,) = self.key_columns
select = self.make_select().select(
select = self.make_select(include_key_range=False).select(
ApplyFuncAndNormalizeAsString(this[k], min_),
ApplyFuncAndNormalizeAsString(this[k], max_),
)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable
import uuid
import unittest
from unittest.mock import patch

from sqeleton.queries import table, this, commit
from sqeleton.utils import ArithAlphanumeric, numberToAlphanum
Expand Down Expand Up @@ -299,6 +300,34 @@ def test_diff_sorted_by_key(self):
}
self.assertEqual(expected, diff)

@patch.object(TableSegment, 'query_key_range')
def test_key_bounds(self, mock_query_key_range):
    """Check when diffing issues a key-range query.

    ``TableSegment.query_key_range`` is patched, so each scenario only
    asserts whether the query was (or was not) called.
    """
    # No min/max provided: the range must be queried.
    mock_query_key_range.return_value = (0, 10)
    _ = list(self.differ.diff_tables(self.table, self.table2))
    mock_query_key_range.assert_called()

    # Both bounds provided for both tables: no range query.
    mock_query_key_range.reset_mock()
    tbl1 = self.table.replace(min_key=1, max_key=100)
    tbl2 = self.table2.replace(min_key=1, max_key=100)
    _ = list(self.differ.diff_tables(tbl1, tbl2))
    mock_query_key_range.assert_not_called()

    # Only max provided: the range is still queried (for the min).
    mock_query_key_range.reset_mock()
    tbl1 = self.table.replace(max_key=100)
    tbl2 = self.table2.replace(max_key=100)
    _ = list(self.differ.diff_tables(tbl1, tbl2))
    mock_query_key_range.assert_called()

    # Only min provided: the range is still queried (for the max).
    mock_query_key_range.reset_mock()
    tbl1 = self.table.replace(min_key=0)
    tbl2 = self.table2.replace(min_key=0)
    _ = list(self.differ.diff_tables(tbl1, tbl2))
    mock_query_key_range.assert_called()


@test_each_database
class TestDiffTables2(DiffTestCase):
Expand Down