@@ -19,37 +19,41 @@ def test_set_ops_error_cases(idx, case, sort, method):
1919
2020
2121@pytest .mark .parametrize ("sort" , [None , False ])
22- def test_intersection_base (idx , sort ):
22+ @pytest .mark .parametrize ("klass" , [MultiIndex , np .array , Series , list ])
23+ def test_intersection_base (idx , sort , klass ):
2324 first = idx [2 ::- 1 ] # first 3 elements reversed
2425 second = idx [:5 ]
2526
26- array_like_cases = [klass (second .values ) for klass in [np .array , Series , list ]]
27- for case in [second , * array_like_cases ]:
28- intersect = first .intersection (case , sort = sort )
29- if sort is None :
30- expected = first .sort_values ()
31- else :
32- expected = first
33- tm .assert_index_equal (intersect , expected )
27+ if klass is not MultiIndex :
28+ second = klass (second .values )
29+
30+ intersect = first .intersection (second , sort = sort )
31+ if sort is None :
32+ expected = first .sort_values ()
33+ else :
34+ expected = first
35+ tm .assert_index_equal (intersect , expected )
3436
3537 msg = "other must be a MultiIndex or a list of tuples"
3638 with pytest .raises (TypeError , match = msg ):
3739 first .intersection ([1 , 2 , 3 ], sort = sort )
3840
3941
4042@pytest .mark .parametrize ("sort" , [None , False ])
41- def test_union_base (idx , sort ):
43+ @pytest .mark .parametrize ("klass" , [MultiIndex , np .array , Series , list ])
44+ def test_union_base (idx , sort , klass ):
4245 first = idx [::- 1 ]
4346 second = idx [:5 ]
4447
45- array_like_cases = [klass (second .values ) for klass in [np .array , Series , list ]]
46- for case in [second , * array_like_cases ]:
47- union = first .union (case , sort = sort )
48- if sort is None :
49- expected = first .sort_values ()
50- else :
51- expected = first
52- tm .assert_index_equal (union , expected )
48+ if klass is not MultiIndex :
49+ second = klass (second .values )
50+
51+ union = first .union (second , sort = sort )
52+ if sort is None :
53+ expected = first .sort_values ()
54+ else :
55+ expected = first
56+ tm .assert_index_equal (union , expected )
5357
5458 msg = "other must be a MultiIndex or a list of tuples"
5559 with pytest .raises (TypeError , match = msg ):
0 commit comments