|
12 | 12 |
|
13 | 13 | import pandas as pd |
14 | 14 | from pandas import (Categorical, Series, DataFrame, |
15 | | - Index, MultiIndex, Timedelta) |
| 15 | + Index, MultiIndex, Timedelta, lib) |
16 | 16 | from pandas.core.frame import _merge_doc |
17 | 17 | from pandas.types.common import (is_datetime64tz_dtype, |
18 | 18 | is_datetime64_dtype, |
19 | 19 | needs_i8_conversion, |
20 | 20 | is_int64_dtype, |
| 21 | + is_categorical_dtype, |
21 | 22 | is_integer_dtype, |
22 | 23 | is_float_dtype, |
| 24 | + is_numeric_dtype, |
23 | 25 | is_integer, |
24 | 26 | is_int_or_datetime_dtype, |
25 | 27 | is_dtype_equal, |
@@ -567,6 +569,10 @@ def __init__(self, left, right, how='inner', on=None, |
567 | 569 | self.right_join_keys, |
568 | 570 | self.join_names) = self._get_merge_keys() |
569 | 571 |
|
| 572 | + # validate the merge keys dtypes. We may need to coerce |
| 573 | + # to avoid incompat dtypes |
| 574 | + self._maybe_coerce_merge_keys() |
| 575 | + |
570 | 576 | def get_result(self): |
571 | 577 | if self.indicator: |
572 | 578 | self.left, self.right = self._indicator_pre_merge( |
@@ -757,26 +763,6 @@ def _get_join_info(self): |
757 | 763 | join_index = join_index.astype(object) |
758 | 764 | return join_index, left_indexer, right_indexer |
759 | 765 |
|
760 | | - def _get_merge_data(self): |
761 | | - """ |
762 | | - Handles overlapping column names etc. |
763 | | - """ |
764 | | - ldata, rdata = self.left._data, self.right._data |
765 | | - lsuf, rsuf = self.suffixes |
766 | | - |
767 | | - llabels, rlabels = items_overlap_with_suffix( |
768 | | - ldata.items, lsuf, rdata.items, rsuf) |
769 | | - |
770 | | - if not llabels.equals(ldata.items): |
771 | | - ldata = ldata.copy(deep=False) |
772 | | - ldata.set_axis(0, llabels) |
773 | | - |
774 | | - if not rlabels.equals(rdata.items): |
775 | | - rdata = rdata.copy(deep=False) |
776 | | - rdata.set_axis(0, rlabels) |
777 | | - |
778 | | - return ldata, rdata |
779 | | - |
780 | 766 | def _get_merge_keys(self): |
781 | 767 | """ |
782 | 768 | Note: has side effects (copy/delete key columns) |
@@ -888,6 +874,51 @@ def _get_merge_keys(self): |
888 | 874 |
|
889 | 875 | return left_keys, right_keys, join_names |
890 | 876 |
|
| 877 | + def _maybe_coerce_merge_keys(self): |
| 878 | + # we have valid mergee's but we may have to further |
| 879 | + # coerce these if they are originally incompatible types |
| 880 | + # |
| 881 | + # for example if these are categorical, but are not dtype_equal |
| 882 | + # or if we have object and integer dtypes |
| 883 | + |
| 884 | + for lk, rk, name in zip(self.left_join_keys, |
| 885 | + self.right_join_keys, |
| 886 | + self.join_names): |
| 887 | + if (len(lk) and not len(rk)) or (not len(lk) and len(rk)): |
| 888 | + continue |
| 889 | + |
| 890 | + # if either left or right is a categorical |
| 891 | + # then the must match exactly in categories & ordered |
| 892 | + if is_categorical_dtype(lk) and is_categorical_dtype(rk): |
| 893 | + if lk.is_dtype_equal(rk): |
| 894 | + continue |
| 895 | + elif is_categorical_dtype(lk) or is_categorical_dtype(rk): |
| 896 | + pass |
| 897 | + |
| 898 | + elif is_dtype_equal(lk.dtype, rk.dtype): |
| 899 | + continue |
| 900 | + |
| 901 | + # if we are numeric, then allow differing |
| 902 | + # kinds to proceed, eg. int64 and int8 |
| 903 | + # further if we are object, but we infer to |
| 904 | + # the same, then proceed |
| 905 | + if (is_numeric_dtype(lk) and is_numeric_dtype(rk)): |
| 906 | + if lk.dtype.kind == rk.dtype.kind: |
| 907 | + continue |
| 908 | + |
| 909 | + # let's infer and see if we are ok |
| 910 | + if lib.infer_dtype(lk) == lib.infer_dtype(rk): |
| 911 | + continue |
| 912 | + |
| 913 | + # Houston, we have a problem! |
| 914 | + # let's coerce to object |
| 915 | + if name in self.left.columns: |
| 916 | + self.left = self.left.assign( |
| 917 | + **{name: self.left[name].astype(object)}) |
| 918 | + if name in self.right.columns: |
| 919 | + self.right = self.right.assign( |
| 920 | + **{name: self.right[name].astype(object)}) |
| 921 | + |
891 | 922 | def _validate_specification(self): |
892 | 923 | # Hm, any way to make this logic less complicated?? |
893 | 924 | if self.on is None and self.left_on is None and self.right_on is None: |
@@ -939,9 +970,15 @@ def _get_join_indexers(left_keys, right_keys, sort=False, how='inner', |
939 | 970 |
|
940 | 971 | Parameters |
941 | 972 | ---------- |
| 973 | + left_keys: ndarray, Index, Series |
| 974 | + right_keys: ndarray, Index, Series |
| 975 | + sort: boolean, default False |
| 976 | + how: string {'inner', 'outer', 'left', 'right'}, default 'inner' |
942 | 977 |
|
943 | 978 | Returns |
944 | 979 | ------- |
| 980 | + tuple of (left_indexer, right_indexer) |
| 981 | + indexers into the left_keys, right_keys |
945 | 982 |
|
946 | 983 | """ |
947 | 984 | from functools import partial |
@@ -1345,6 +1382,13 @@ def _factorize_keys(lk, rk, sort=True): |
1345 | 1382 | if is_datetime64tz_dtype(lk) and is_datetime64tz_dtype(rk): |
1346 | 1383 | lk = lk.values |
1347 | 1384 | rk = rk.values |
| 1385 | + |
| 1386 | + # if we exactly match in categories, allow us to use codes |
| 1387 | + if (is_categorical_dtype(lk) and |
| 1388 | + is_categorical_dtype(rk) and |
| 1389 | + lk.is_dtype_equal(rk)): |
| 1390 | + return lk.codes, rk.codes, len(lk.categories) |
| 1391 | + |
1348 | 1392 | if is_int_or_datetime_dtype(lk) and is_int_or_datetime_dtype(rk): |
1349 | 1393 | klass = _hash.Int64Factorizer |
1350 | 1394 | lk = _ensure_int64(com._values_from_object(lk)) |
|
0 commit comments