diff --git a/bittensor/utils/tokenizer_utils.py b/bittensor/utils/tokenizer_utils.py
index 19911d96ff..9192556097 100644
--- a/bittensor/utils/tokenizer_utils.py
+++ b/bittensor/utils/tokenizer_utils.py
@@ -878,7 +878,8 @@ def unravel_topk_token_phrases(compact_topk: torch.Tensor, topk: int, ignore_ind
     max_len = phrase_len.max()  # determine width of topk_tensor as max len of all phrase lists (with prob in front)

     # Initialize topk_tensor with ignore_index + 2, since decrement with 2 follows to remove token offset later
-    topk_tensor = (ignore_index + 2) * torch.ones((batch_size * (topk + 1), max_len))  # [batch_size * (topk + 1), max_len]
+    topk_tensor = torch.ones((batch_size * (topk + 1), max_len), device=compact_topk.device)
+    topk_tensor *= ignore_index + 2  # [batch_size * (topk + 1), max_len]

     # Insert phrases of each unique length as block into topk_tensor
    for unique_len in phrase_len.unique():
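
The sketch below illustrates the intent of this change; it is not part of the repository, and the helper name init_topk_tensor plus the example shapes are assumptions. Allocating topk_tensor directly on compact_topk.device keeps the placeholder buffer on the same device as the incoming compact top-k tensor, so later writes into it from a CUDA input do not fail with a device-mismatch error.

    import torch

    def init_topk_tensor(compact_topk: torch.Tensor, batch_size: int, topk: int,
                         max_len: int, ignore_index: int = -100) -> torch.Tensor:
        # Allocate the buffer on the same device as the input, then fill it with
        # ignore_index + 2 (the +2 token offset is removed later by the caller).
        topk_tensor = torch.ones((batch_size * (topk + 1), max_len), device=compact_topk.device)
        topk_tensor *= ignore_index + 2
        return topk_tensor

    if __name__ == "__main__":
        compact = torch.randn(100)  # stand-in for a compacted top-k phrase tensor
        t = init_topk_tensor(compact, batch_size=2, topk=4, max_len=6)
        print(t.shape, t.device)  # torch.Size([10, 6]) on the same device as `compact`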