11from __future__ import annotations
22
3+ import collections
34import itertools
45from typing import (
56 TYPE_CHECKING ,
@@ -162,8 +163,15 @@ def _indexer_and_to_sort(
162163 ]:
163164 v = self .level
164165
165- codes = list (self .index .codes )
166166 levs = list (self .index .levels )
167+ codes = list (self .index .codes )
168+
169+ if not self .sort :
170+ # Create new codes considering that labels are already sorted
171+ for i in range (len (codes )):
172+ dd = collections .defaultdict (itertools .count ().__next__ )
173+ codes [i ] = np .array ([dd [c ] for c in codes [i ]], dtype = codes [i ].dtype )
174+
167175 to_sort = codes [:v ] + codes [v + 1 :] + [codes [v ]]
168176 sizes = tuple (len (x ) for x in levs [:v ] + levs [v + 1 :] + [levs [v ]])
169177
@@ -174,25 +182,33 @@ def _indexer_and_to_sort(
174182 return indexer , to_sort
175183
176184 @cache_readonly
177- def sorted_labels (self ) -> list [np .ndarray ]:
185+ def labels (self ) -> list [np .ndarray ]:
178186 indexer , to_sort = self ._indexer_and_to_sort
179187 if self .sort :
180188 return [line .take (indexer ) for line in to_sort ]
181189 return to_sort
182190
183- def _make_sorted_values (self , values : np .ndarray ) -> np .ndarray :
191+ @cache_readonly
192+ def sorted_labels (self ) -> list [np .ndarray ]:
184193 if self .sort :
185- indexer , _ = self ._indexer_and_to_sort
194+ return self .labels
186195
187- sorted_values = algos .take_nd (values , indexer , axis = 0 )
188- return sorted_values
189- return values
196+ v = self .level
197+ codes = list (self .index .codes )
198+ to_sort = codes [:v ] + codes [v + 1 :] + [codes [v ]]
199+ return to_sort
200+
201+ def _make_sorted_values (self , values : np .ndarray ) -> np .ndarray :
202+ indexer , _ = self ._indexer_and_to_sort
203+ sorted_values = algos .take_nd (values , indexer , axis = 0 )
204+ return sorted_values
190205
191206 def _make_selectors (self ):
192207 new_levels = self .new_index_levels
193208
194209 # make the mask
195- remaining_labels = self .sorted_labels [:- 1 ]
210+ remaining_labels = self .labels [:- 1 ]
211+ choosen_labels = self .labels [- 1 ]
196212 level_sizes = tuple (len (x ) for x in new_levels )
197213
198214 comp_index , obs_ids = get_compressed_ids (remaining_labels , level_sizes )
@@ -202,7 +218,7 @@ def _make_selectors(self):
202218 stride = self .index .levshape [self .level ] + self .lift
203219 self .full_shape = ngroups , stride
204220
205- selector = self . sorted_labels [ - 1 ] + stride * comp_index + self .lift
221+ selector = choosen_labels + stride * comp_index + self .lift
206222 mask = np .zeros (np .prod (self .full_shape ), dtype = bool )
207223 mask .put (selector , True )
208224
0 commit comments