diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 3cd90360..af250afa 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -271,6 +271,14 @@ def _diff_segments( ): ... + def _resolve_key_range(self, key_range_res, usr_key_range): + key_range_res = list(key_range_res) + if usr_key_range[0] is not None: + key_range_res[0] = usr_key_range[0] + if usr_key_range[1] is not None: + key_range_res[1] = usr_key_range[1] + return tuple(key_range_res) + def _bisect_and_diff_tables(self, table1, table2, info_tree): if len(table1.key_columns) > 1: raise NotImplementedError("Composite key not supported yet!") @@ -290,11 +298,15 @@ def _bisect_and_diff_tables(self, table1, table2, info_tree): if key_type.python_type is not key_type2.python_type: raise TypeError(f"Incompatible key types: {key_type} and {key_type2}") - # Query min/max values - key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) + usr_key_range = (table1.min_key, table1.max_key) + if all(k is not None for k in [table1.min_key, table1.max_key, table2.min_key, table2.max_key]): + key_ranges = (kr for kr in [(table1.min_key, table1.max_key), (table2.min_key, table2.max_key)]) + else: + # Query min/max values + key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2]) # Start with the first completed value, so we don't waste time waiting - min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges)) + min_key1, max_key1 = self._parse_key_range_result(key_type, self._resolve_key_range(next(key_ranges), usr_key_range)) table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)] @@ -308,7 +320,7 @@ def _bisect_and_diff_tables(self, table1, table2, info_tree): ti.submit(self._bisect_and_diff_segments, ti, table1, table2, info_tree) # Now we check for the second min-max, to diff the portions we "missed". - min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges)) + min_key2, max_key2 = self._parse_key_range_result(key_type, self._resolve_key_range(next(key_ranges), usr_key_range)) if min_key2 < min_key1: pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)] diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 725beb39..1d7025b2 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -104,10 +104,15 @@ def _make_update_range(self): def source_table(self): return table(*self.table_path, schema=self._schema) - def make_select(self): - return self.source_table.where( - *self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP - ) + def make_select(self, include_key_range=True): + if include_key_range: + return self.source_table.where( + *self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP + ) + else: + return self.source_table.where( + *self._make_update_range(), Code(self._where()) if self.where else SKIP + ) def get_values(self) -> list: "Download all the relevant values of the segment from the database" @@ -187,7 +192,7 @@ def query_key_range(self) -> Tuple[int, int]: """Query database for minimum and maximum key. This is used for setting the initial bounds.""" # Normalizes the result (needed for UUIDs) after the min/max computation (k,) = self.key_columns - select = self.make_select().select( + select = self.make_select(include_key_range=False).select( ApplyFuncAndNormalizeAsString(this[k], min_), ApplyFuncAndNormalizeAsString(this[k], max_), ) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 2a8d2acc..f6b55b57 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -2,6 +2,7 @@ from typing import Callable import uuid import unittest +from unittest.mock import patch from sqeleton.queries import table, this, commit from sqeleton.utils import ArithAlphanumeric, numberToAlphanum @@ -299,6 +300,34 @@ def test_diff_sorted_by_key(self): } self.assertEqual(expected, diff) + @patch.object(TableSegment, 'query_key_range') + def test_key_bounds(self, mock_query_key_range): + # test range query when no min/max provided + mock_query_key_range.return_value = (0, 10) + _ = list(self.differ.diff_tables(self.table, self.table2)) + mock_query_key_range.assert_called() + + # test no range query + mock_query_key_range.reset_mock() + tbl1 = self.table.replace(min_key=1, max_key=100) + tbl2 = self.table2.replace(min_key=1, max_key=100) + _ = list(self.differ.diff_tables(tbl1, tbl2)) + mock_query_key_range.assert_not_called() + + # test query min only + mock_query_key_range.reset_mock() + tbl1 = self.table.replace(max_key=100) + tbl2 = self.table2.replace(max_key=100) + _ = list(self.differ.diff_tables(tbl1, tbl2)) + mock_query_key_range.assert_called() + + # test query min only + mock_query_key_range.reset_mock() + tbl1 = self.table.replace(min_key=0) + tbl2 = self.table2.replace(min_key=0) + _ = list(self.differ.diff_tables(tbl1, tbl2)) + mock_query_key_range.assert_called() + @test_each_database class TestDiffTables2(DiffTestCase):