diff --git a/sqlalchemydiff/comparer.py b/sqlalchemydiff/comparer.py index 2fbc506..8c9ccb0 100644 --- a/sqlalchemydiff/comparer.py +++ b/sqlalchemydiff/comparer.py @@ -278,21 +278,35 @@ def _get_foreign_keys(inspector, table_name): def _get_primary_keys_info( left_inspector, right_inspector, table_name, ignores ): - left_pk_list = _get_primary_keys(left_inspector, table_name) - right_pk_list = _get_primary_keys(right_inspector, table_name) - - left_pk_list = _discard_ignores(left_pk_list, ignores) - right_pk_list = _discard_ignores(right_pk_list, ignores) - - # process into dict - left_pk = dict((elem, elem) for elem in left_pk_list) - right_pk = dict((elem, elem) for elem in right_pk_list) + left_pk_constraint = _get_primary_keys(left_inspector, table_name) + right_pk_constraint = _get_primary_keys(right_inspector, table_name) + + pk_constraint_has_name = ('name' in left_pk_constraint and + left_pk_constraint['name'] is not None) + + if pk_constraint_has_name: + left_pk = ({left_pk_constraint['name']: left_pk_constraint} + if _discard_ignores_by_name([left_pk_constraint], ignores) + else {}) + right_pk = ({right_pk_constraint['name']: right_pk_constraint} + if _discard_ignores_by_name([right_pk_constraint], ignores) + else {}) + else: + left_pk_list = left_pk_constraint['constrained_columns'] + right_pk_list = right_pk_constraint['constrained_columns'] + + left_pk_list = _discard_ignores(left_pk_list, ignores) + right_pk_list = _discard_ignores(right_pk_list, ignores) + + # process into dict + left_pk = dict((elem, elem) for elem in left_pk_list) + right_pk = dict((elem, elem) for elem in right_pk_list) return _diff_dicts(left_pk, right_pk) def _get_primary_keys(inspector, table_name): - return inspector.get_primary_keys(table_name) + return inspector.get_pk_constraint(table_name) def _get_indexes_info(left_inspector, right_inspector, table_name, ignores): diff --git a/test/unit/test_comparer.py b/test/unit/test_comparer.py index 4002459..d99266c 100644 --- a/test/unit/test_comparer.py +++ b/test/unit/test_comparer.py @@ -431,8 +431,8 @@ def test__get_foreign_keys(self): def test__get_primary_keys_info( self, _diff_dicts_mock, _get_primary_keys_mock): _get_primary_keys_mock.side_effect = [ - ['pk_left_1', 'pk_left_2'], - ['pk_right_1'] + {'constrained_columns': ['pk_left_1', 'pk_left_2']}, + {'constrained_columns': ['pk_right_1']} ] left_inspector, right_inspector = Mock(), Mock() @@ -449,8 +449,8 @@ def test__get_primary_keys_info( def test__get_primary_keys_info_ignores( self, _diff_dicts_mock, _get_primary_keys_mock): _get_primary_keys_mock.side_effect = [ - ['pk_left_1', 'pk_left_2'], - ['pk_right_1', 'pk_right_2'] + {'constrained_columns': ['pk_left_1', 'pk_left_2']}, + {'constrained_columns': ['pk_right_1', 'pk_right_2']}, ] left_inspector, right_inspector = Mock(), Mock() ignores = ['pk_left_1', 'pk_right_2'] @@ -465,13 +465,58 @@ def test__get_primary_keys_info_ignores( assert _diff_dicts_mock.return_value == result + def test__get_primary_keys_info_with_pk_constraint_name( + self, _diff_dicts_mock, _get_primary_keys_mock): + _get_primary_keys_mock.side_effect = [ + {'name': 'left', 'constrained_columns': ['pk_left_1']}, + {'name': 'right', 'constrained_columns': ['pk_right_1']} + ] + left_inspector, right_inspector = Mock(), Mock() + + result = _get_primary_keys_info( + left_inspector, right_inspector, 'table_A', []) + + _diff_dicts_mock.assert_called_once_with( + { + 'left': {'name': 'left', + 'constrained_columns': ['pk_left_1']} + }, + { + 'right': {'name': 'right', + 'constrained_columns': ['pk_right_1']} + } + ) + assert _diff_dicts_mock.return_value == result + + def test__get_primary_keys_info_ignores_with_pk_constraint_name( + self, _diff_dicts_mock, _get_primary_keys_mock): + _get_primary_keys_mock.side_effect = [ + {'name': 'left_1', 'constrained_columns': ['pk_left_1']}, + {'name': 'right_1', 'constrained_columns': ['pk_right_1']}, + ] + left_inspector, right_inspector = Mock(), Mock() + ignores = ['left_1', 'left_2', 'right_2'] + + result = _get_primary_keys_info( + left_inspector, right_inspector, 'table_A', ignores) + + _diff_dicts_mock.assert_called_once_with( + dict(), + { + 'right_1': {'name': 'right_1', + 'constrained_columns': ['pk_right_1']}, + } + ) + + assert _diff_dicts_mock.return_value == result + def test__get_primary_keys(self): inspector = Mock() result = _get_primary_keys(inspector, 'table_A') - inspector.get_primary_keys.assert_called_once_with('table_A') - assert inspector.get_primary_keys.return_value == result + inspector.get_pk_constraint.assert_called_once_with('table_A') + assert inspector.get_pk_constraint.return_value == result def test__get_indexes_info( self, _diff_dicts_mock, _get_indexes_mock):