diff --git a/bittensor/__init__.py b/bittensor/__init__.py
index a3ecd64793..429ca61f2b 100644
--- a/bittensor/__init__.py
+++ b/bittensor/__init__.py
@@ -23,7 +23,7 @@ nest_asyncio.apply()
 
 # Bittensor code and protocol version.
-__version__ = '3.4.2'
+__version__ = '3.4.3'
 version_split = __version__.split(".")
 __version_as_int__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))
 
diff --git a/bittensor/_synapse/text_causallmnext_impl.py b/bittensor/_synapse/text_causallmnext_impl.py
index 284b018d9e..9516ce3c22 100644
--- a/bittensor/_synapse/text_causallmnext_impl.py
+++ b/bittensor/_synapse/text_causallmnext_impl.py
@@ -123,6 +123,10 @@ def check_forward_response_tensor(self, forward_request_tensor, forward_response
                              f"[>={forward_request_tensor.shape[0]} x (2 x {self.topk} + 1)], "
                              f"got: {forward_response_tensor.size(0)} for synapse: {self}")
 
+        atol = 1e-6  # absolute tolerance
+        if (forward_response_tensor < -atol).any():
+            raise ValueError("forward_response_tensor values below tolerance.")
+
     def check_backward_request_gradient(self, forward_request_tensor, backward_request_gradient):
         # forward_request_tensor: [batch_size, sequence_len]
         # backward_request_gradient: [batch_size, (topk + 1), max_len]
diff --git a/bittensor/utils/tokenizer_utils.py b/bittensor/utils/tokenizer_utils.py
index 9192556097..471d4f822c 100644
--- a/bittensor/utils/tokenizer_utils.py
+++ b/bittensor/utils/tokenizer_utils.py
@@ -866,11 +866,17 @@ def unravel_topk_token_phrases(compact_topk: torch.Tensor, topk: int, ignore_ind
                  [...]]
     """
+    atol = 1e-6  # absolute tolerance
     # Find probability markers (per batch item: topk phrase probabilities + floor_prob)
-    prob_idx = torch.where(compact_topk <= 1.5)[0]  # 0 <= prob <= 1 [batch_size * (topk + 1)], expect token_ids >= 2
+    prob_idx = torch.where((-atol < compact_topk) & (compact_topk < 1 + atol))[0]  # 0 <= prob <= 1 [batch_size * (topk + 1)], expect token_ids >= 2
 
     batch_size = len(prob_idx) // (topk + 1)  # (batch_size * (topk + floor)) / (topk + floor)
-    assert batch_size * (topk + 1) == len(prob_idx), f'{batch_size} * ({topk} + 1) != {len(prob_idx)}'  # decoding irregularity otherwise
+    assert batch_size * (topk + 1) == len(prob_idx), f'unravel_topk_token_phrases() probability marker failure: ' \
+                                                     f'{batch_size} * ({topk} + 1) != {len(prob_idx)}'  # decoding irregularity otherwise
+
+    probs = torch.clamp(compact_topk[prob_idx], 0, 1)  # [batch_size * (topk + 1)] ensure probabilities within [0, 1]
+    probs_sum = probs.reshape(batch_size, topk + 1).sum(dim=1)  # [batch_size]
+    assert torch.all((-atol < probs_sum) & (probs_sum < 1 + atol)), f'unravel_topk_token_phrases(): probs_sum not in [0, 1]'
 
     # Obtain phrase lengths and maximum phrase length
     phrase_len = prob_idx[1:] - prob_idx[:-1]  # [batch_size * (topk + 1) - 1] length of each phrase
@@ -900,7 +906,7 @@ def unravel_topk_token_phrases(compact_topk: torch.Tensor, topk: int, ignore_ind
     topk_tensor -= 2  # remove token offset, overwrites probability column, replace probabilities below
 
     # grafting probability tensors into first column to attach gradients
-    topk_tensor[:, 0] = compact_topk[prob_idx]  # tensor([prob_k=0_b, prob_k=1_b, ..., prob_floor_b])
+    topk_tensor[:, 0] = probs  # tensor([prob_k=0_b, prob_k=1_b, ..., prob_floor_b])
 
     topk_tensor = topk_tensor.reshape(batch_size, topk + 1, max_len)  # [batch_size, (topk + 1), max_len] reshaped
 
@@ -953,6 +959,9 @@ def phrase_cross_entropy(target_phrases: Union[List[List[int]], torch.Tensor],
     topk_probs = topk_tensor[:, :-1, 0]  # [batch_size, topk] Probabilities for each phrase in topk
     floor_probs = topk_tensor[:, -1, 0]  # [batch_size] Floor probabilities as mean probability for non-topk tokens
 
+    topk_probs = torch.clamp(topk_probs, 0, 1)  # [batch_size, topk] ensure probabilities within [0, 1]
+    floor_probs = torch.clamp(floor_probs, 0, 1)  # [batch_size] ensure floor probabilities within [0, 1]
+
     # === Ensure total probability is 1 ===
     total_probs = topk_probs.sum(dim=-1) + max(0, vocab_size_min - topk) * floor_probs  # [batch_size] total probs
     n_topk_probs = topk_probs / total_probs[:, None]  # [batch_size, topk] normalized topk_probs
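The hunks above share one pattern: values that should be probabilities are validated against a small absolute tolerance and clamped back into [0, 1] before use, so benign float drift (for example from serialization round-trips) no longer breaks decoding, while genuinely malformed tensors still fail the assertions. Below is a minimal standalone sketch of that behaviour; the toy compact_topk values and the single-item probs_sum are illustrative and not taken from the repository.

import torch

atol = 1e-6  # absolute tolerance, matching the patch

# Toy compact representation: token ids are offset by +2, so any value inside
# [0 - atol, 1 + atol] is treated as a probability marker; the last entry has
# a tiny negative drift that the clamp repairs.
compact_topk = torch.tensor([0.7, 5.0, 9.0, 0.3, 7.0, -1e-7])

prob_idx = torch.where((-atol < compact_topk) & (compact_topk < 1 + atol))[0]
print(prob_idx)  # tensor([0, 3, 5]) -> probability markers only, token ids excluded

probs = torch.clamp(compact_topk[prob_idx], 0, 1)  # drift pulled back into [0, 1]
print(probs)  # tensor([0.7000, 0.3000, 0.0000])

# Sanity check in the spirit of the patch: probabilities should sum to at most ~1.
probs_sum = probs.sum()
assert torch.all((-atol < probs_sum) & (probs_sum < 1 + atol)), 'probs_sum not in [0, 1]'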