diff --git a/bittensor/utils/tokenizer_utils.py b/bittensor/utils/tokenizer_utils.py
index 10a82df5b0..19911d96ff 100644
--- a/bittensor/utils/tokenizer_utils.py
+++ b/bittensor/utils/tokenizer_utils.py
@@ -872,19 +872,31 @@ def unravel_topk_token_phrases(compact_topk: torch.Tensor, topk: int, ignore_ind
     batch_size = len(prob_idx) // (topk + 1)  # (batch_size * (topk + floor)) / (topk + floor)
     assert batch_size * (topk + 1) == len(prob_idx), f'{batch_size} * ({topk} + 1) != {len(prob_idx)}'  # decoding irregularity otherwise
 
-    # split into topk token phrases with prob prepend [prob, tok_0, tok_1, ... tok_n]
-    phrases = [s.tolist() for s in torch.tensor_split(compact_topk, prob_idx)]  # tolist for faster list comprehension
-    phrases = phrases[1:]  # ignore first (empty) split
+    # Obtain phrase lengths and maximum phrase length
+    phrase_len = prob_idx[1:] - prob_idx[:-1]  # [batch_size * (topk + 1) - 1] length of each phrase
+    phrase_len = torch.cat((phrase_len, torch.tensor([1])))  # [batch_size * (topk + 1)] prob_floor is always len=1
+    max_len = phrase_len.max()  # determine width of topk_tensor as max len of all phrase lists (with prob in front)
 
-    # determine width of topk_tensor as max len of all phrase lists (with prob in front)
-    max_len = max([len(p) for p in phrases])  # max_{b,k}(len([prob_k, tok_0_k, tok_1_k, ...]))
+    # Initialize topk_tensor with ignore_index + 2, since decrement with 2 follows to remove token offset later
+    topk_tensor = (ignore_index + 2) * torch.ones((batch_size * (topk + 1), max_len))  # [batch_size * (topk + 1), max_len]
+
+    # Insert phrases of each unique length as block into topk_tensor
+    for unique_len in phrase_len.unique():
+        if unique_len <= 1:
+            continue  # skip probability column, will be added afterward
+
+        phrase_idx = torch.where(phrase_len == unique_len)[0]  # phrase indices where phrase_len is unique_len
+        compact_idx = prob_idx[phrase_idx]  # indices in compact_topk
+
+        # Create indexing block, add index for each phrase position, skip first (prob) position
+        block_idx = [compact_idx + position for position in range(1, unique_len)]  # incrementally add each position of phrase
+        # transpose .t() ensures correct interleaving of consecutive positions:
+        # [[phrase_a_1, phrase_a_2, ..., phrase_a_n], [phrase_b_1, phrase_b_2, ..., phrase_b_n], ...]
+        block_idx = torch.vstack(block_idx).t().reshape(-1, unique_len - 1)  # [-1, unique_len - 1] for all phrases with unique_len
 
-    ignore_index_2 = ignore_index + 2  # increment with 2, as decrement with 2 follows
+        topk_tensor[phrase_idx, 1:unique_len] = compact_topk[block_idx]  # slice selected phrases and copy into topk_tensor
 
-    # form single 2D tensor with topk token phrases with prob prepend [prob, tok_0, tok_1, ... tok_n]
-    topk_tensor = torch.tensor([p + [ignore_index_2] * (max_len - len(p))
-                                for p in phrases]).to(compact_topk.device)  # [batch_size * (topk + 1), max_len]
-    topk_tensor -= 2  # remove token offset
+    topk_tensor -= 2  # remove token offset, overwrites probability column, replace probabilities below
 
     # grafting probability tensors into first column to attach gradients
     topk_tensor[:, 0] = compact_topk[prob_idx]  # tensor([prob_k=0_b, prob_k=1_b, ..., prob_floor_b])
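The hunk above replaces the per-phrase Python splits with one vectorized copy per unique phrase length. A minimal standalone sketch of that indexing idea, using toy values (names such as toy_compact are illustrative, not part of the patch):

# Sketch only: groups phrases of equal length and copies each group in one fancy-indexing slice.
import torch

# toy compact layout [prob, tok, ...] per phrase; probabilities < 1, tokens offset to >= 2
toy_compact = torch.tensor([0.5, 7., 8., 9.,    # phrase 0: prob + 3 tokens
                            0.3, 4., 5.,        # phrase 1: prob + 2 tokens
                            0.2])               # floor probability, no tokens
prob_idx = torch.where(toy_compact < 1)[0]      # for this toy, probabilities are the only values < 1

phrase_len = prob_idx[1:] - prob_idx[:-1]                 # lengths of all but the last phrase
phrase_len = torch.cat((phrase_len, torch.tensor([1])))   # floor phrase is always length 1
max_len = int(phrase_len.max())

padded = torch.full((len(prob_idx), max_len), float('nan'))  # stand-in for ignore_index padding
for unique_len in phrase_len.unique():
    if unique_len <= 1:
        continue  # probability-only rows (the floor) carry no tokens
    unique_len = int(unique_len)
    rows = torch.where(phrase_len == unique_len)[0]  # all phrases of this length, handled at once
    starts = prob_idx[rows]                          # their probability positions in toy_compact
    # one row of source indices per phrase; .t() keeps each phrase's tokens contiguous
    block_idx = torch.vstack([starts + pos for pos in range(1, unique_len)]).t()
    padded[rows, 1:unique_len] = toy_compact[block_idx]

padded[:, 0] = toy_compact[prob_idx]  # graft probabilities back into the first column
print(padded)  # [[0.5, 7., 8., 9.], [0.3, 4., 5., nan], [0.2, nan, nan, nan]]

Only one indexed copy runs per distinct phrase length, rather than one Python-level list per phrase, which is presumably the motivation for the change.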
diff --git a/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py b/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py
index 32df6fa499..482d7f8c18 100644
--- a/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py
+++ b/tests/unit_tests/bittensor_tests/utils/test_tokenizer_utils.py
@@ -433,6 +433,77 @@ def test_topk_token_phrases():
         tokenizer_topk_phrases(sample_text[text_name], model_name, max_length, _enc_pre_logits, topk=128)
 
 
+def test_random_topk_token_phrases(single_token_ratios: Tuple = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
+                                   max_len_final: int = 10, batch_size: int = 32, topk: int = 4096,
+                                   ignore_index: int = -100, vocab_len: int = 50256):
+    r"""
+    Asserts that randomly instantiated compact_topk encodings can be correctly decoded
+    to recover the original topk_tensor, where:
+        topk_tensor:
+            [batch_size, (topk + 1), max_len] tensor includes topk token probabilities (prob_k) + floor_prob
+            in first column with gradients attached, with std_tokens in remaining columns with ignore_index padding.
+            Content structure:
+            [[[prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., ignore_index?],
+              [prob_k=1_b=0, tok_0_k=1_b=0, tok_1_k=1_b=0, ..., ignore_index?],
+              [...],
+              [prob_floor_b=0, ignore_index, ..., ignore_index]],
+             [[prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., ignore_index?],
+              [prob_k=1_b=1, tok_0_k=1_b=1, tok_1_k=1_b=1, ..., ignore_index?],
+              [...],
+              [prob_floor_b=1, ignore_index, ..., ignore_index]],
+             [...]]
+        compact_topk:
+            [sum_b(sum_k(len(phrase_k) + 1)_b)] Compacted 1-D tensor >= batch_size * (2 * topk + 1),
+            since 2 * topk + 1: topk x [probability, token sequence (at least one token)] +
+            floor probability (rest).
+            Content structure:
+                [prob_k=0_b=0, tok_0_k=0_b=0, tok_1_k=0_b=0, ..., prob_k=1_b=0, tok_0_k=1_b=0, ..., prob_floor_b=0,
+                 prob_k=0_b=1, tok_0_k=0_b=1, tok_1_k=0_b=1, ..., prob_k=1_b=1, tok_0_k=1_b=1, ..., prob_floor_b=1,
+                 ...]
+
+        Args:
+            single_token_ratios (:obj:`Tuple`, `optional`):
+                Series of ratios of single-token phrases to total phrases, to test individually.
+            max_len_final (:obj:`int`, `optional`):
+                The maximum phrase length to test.
+            batch_size (:obj:`int`, `optional`):
+                The batch_size of the test input.
+            topk (:obj:`int`, `optional`):
+                The topk of the test input, the amount of logits retained.
+            ignore_index (:obj:`int`, `optional`):
+                The padding value after the end of each phrase.
+            vocab_len (:obj:`int`, `optional`):
+                The tokenizer vocabulary length.
+
+        Returns:
+    """
+    for single_token_ratio in single_token_ratios:  # for each single token occurrence ratio
+        for _max_len in torch.arange(3, max_len_final):  # for each max_len in range 3 to max_len_final
+            longer_phrases = int(topk * (1 - single_token_ratio) / (_max_len - 2))  # number of multi-token phrases per length
+            max_len = _max_len if longer_phrases > 0 else 2  # change max_len if only single_phrases
+            single_phrases = topk - (max_len - 2) * longer_phrases  # number of [prob, token, ignore_index, ...] phrases
+
+            topk_tensor = ignore_index * torch.ones((batch_size, topk + 1, max_len))  # [batch_size, (topk + 1), max_len]
+
+            for batch in range(batch_size):  # construct each batch separately
+                permuted = torch.randperm(topk)
+
+                # add single token phrases: [prob, token, ignore_index, ..., ignore_index]
+                topk_tensor[batch, permuted[:single_phrases], 1:2] = 1. * torch.randint(vocab_len, (single_phrases, 1))
+
+                # add longer token phrases: [prob, token, token, ..., ignore_index?, ..., ignore_index]
+                for length in range(2, max_len):
+                    start = single_phrases + (length - 2) * longer_phrases
+                    phrase_idx = permuted[start:start + longer_phrases]
+                    topk_tensor[batch, phrase_idx, 1:length+1] = 1. * torch.randint(vocab_len, (longer_phrases, length))
+
+            topk_tensor[:, :, 0] = torch.rand((batch_size, topk + 1))  # assign random probabilities to first column
+
+            compact_topk = compact_topk_token_phrases(topk_tensor)  # [>= batch_size * (2 * topk + 1)]
+            _topk_tensor = unravel_topk_token_phrases(compact_topk, topk=topk)  # [batch_size, (topk + 1), max_len]
+            assert torch.all(torch.eq(_topk_tensor, topk_tensor))
+
+
 def topk_phrases_crossentropy(text_batch: List[str], model_name: str, max_length: int, last_indices: List[int],
                               enc_pre_logits: torch.FloatTensor = None,
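For orientation, a scaled-down version of the round trip the new test asserts, assuming compact_topk_token_phrases and unravel_topk_token_phrases are importable from bittensor.utils.tokenizer_utils; the tiny hand-built tensor is illustrative only:

import torch
from bittensor.utils.tokenizer_utils import compact_topk_token_phrases, unravel_topk_token_phrases

ignore_index = -100
topk_tensor = ignore_index * torch.ones((1, 3, 3))       # [batch_size=1, topk + 1=3, max_len=3]
topk_tensor[0, 0, 1:3] = torch.tensor([72., 101.])       # phrase 0: two tokens
topk_tensor[0, 1, 1:2] = torch.tensor([33.])             # phrase 1: one token
topk_tensor[:, :, 0] = torch.tensor([[0.6, 0.3, 0.1]])   # topk probabilities + floor probability

compact_topk = compact_topk_token_phrases(topk_tensor)           # 1-D compact encoding
_topk_tensor = unravel_topk_token_phrases(compact_topk, topk=2)  # [1, 3, 3] recovered
assert torch.all(torch.eq(_topk_tensor, topk_tensor))            # lossless round trip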