Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 116 additions & 2 deletions bittensor/_neuron/text/core_validator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from rich.console import Console
from rich.style import Style
from rich.table import Table
from rich.errors import MarkupError
from rich.traceback import install
from typing import List, Tuple, Callable, Dict, Any, Union, Set

Expand Down Expand Up @@ -1232,6 +1233,11 @@ def _synergy(first, second, target, ext):
logger.info(f'{str(synapse)} \t| Shapley synergy values <dim>[{time.time() - synergy_start_time:.3g}s]</dim>')

if logging:
# === Response table ===
# Prints the query response table: top prediction probabilities and texts for batch tasks
batch_predictions = format_predictions(uids, query_responses, return_ops, inputs, validation_len, index_s)
response_table(batch_predictions, stats, sort_col='shapley_values_nxt', console_width=console_width)

# === Synergy table ===
# Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal)
synergy_table(stats, syn_loss_diff, 'shapley_values_nxt', console_width)
Expand Down Expand Up @@ -1392,6 +1398,114 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
return syn_loss_diff


def format_predictions(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]],
return_ops: List[torch.LongTensor], inputs: torch.FloatTensor,
validation_len: int, index_s: int = 0, number_of_predictions: int = 3) -> List:
r""" Format batch task topk predictions for rich table print of query responses.
"""
batch_predictions = []
std_tokenizer = bittensor.tokenizer()

# === Batch iteration ===
for batch_item in range(inputs.shape[0]):
# === Task formatting ===
context = inputs[batch_item][:-validation_len]
answer = inputs[batch_item][-validation_len:]

context = repr(std_tokenizer.decode(context))[1:-1][-30:] # strip '' and truncate
answer = repr(std_tokenizer.decode(answer))[1:-1][:15] # strip '' and truncate

task = f"[reverse]{context}[/reverse][bold]{answer}[/bold]"

# === Prediction formatting ===
predictions = {}
for index, uid in enumerate(uids.tolist()):
if return_ops[index][index_s] == bittensor.proto.ReturnCode.Success:
topk_tensor = query_responses[index][index_s] # [batch_size, (topk + 1), max_len] (prob_k) + floor_prob
topk_tokens = topk_tensor[batch_item, :-1, 1:].int() # [batch_size, topk, max_len - 1] Phrase tokens with ignore_index token for padding.
topk_probs = topk_tensor[batch_item, :-1, 0] # [batch_size, topk] Probabilities for each phrase in topk

# === Topk iteration ===
topk_predictions = ''
for i in range(number_of_predictions):
phrase = topk_tokens[i]
phrase = phrase[phrase >= 0] # strip negative ignore_index = -100
phrase_str = repr(std_tokenizer.decode(phrase))[:15] # decode, escape and truncate

prob = f'{topk_probs[i]:.3f}'
prob = prob[1:] if prob[0] == '0' else prob[:-1] # remove obvious leading 0

topk_predictions += f"[green]{prob}[/green]: {phrase_str} "

predictions[uid] = topk_predictions[:-1] # strip trailing space

batch_predictions += [(task, predictions)]

return batch_predictions


def response_table(batch_predictions: List, stats: Dict, sort_col: str, console_width: int,
task_repeat: int = 4, tasks_per_server: int = 3):
r""" Prints the query response table: top prediction probabilities and texts for batch tasks.
"""
# === Batch permutation ===
batch_size = len(batch_predictions)
if batch_size == 0:
return
batch_perm = torch.randperm(batch_size) # avoid restricting observation to predictable subsets

# === Column selection ===
columns = [c[:] for c in neuron_stats_columns if c[1] in ['uid', sort_col, 'loss_nxt', 'synergy_nxt']]
col_keys = [c[1] for c in columns]

# === Sort rows ===
sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s], reverse=True, key=lambda _row: _row[1])
if sort_col in col_keys:
sort_idx = col_keys.index(sort_col) # sort column with key of sort_col
columns[sort_idx][0] += '\u2193' # ↓ downwards arrow (sort)

for i, (uid, val) in enumerate(sort):
# === New table section ===
if i % task_repeat == 0:
table = Table(width=console_width, box=None)
if i == 0:
table.title = f"[white bold] Query responses [/white bold] | " \
f"[white]context[/white][bold]continuation[/bold] | .prob: 'prediction'"

for col, _, _, stl in columns: # [Column_name, key_name, format_string, rich_style]
table.add_column(col, style=stl, justify='right')

# === Last table section ===
if i == len(sort) - 1:
table.caption = f'[bold]{len(sort)}[/bold]/{len(stats)} (respond/topk) | ' \
f'[bold]{tasks_per_server}[/bold] tasks per server | ' \
f'repeat tasks over [bold]{task_repeat}[/bold] servers ' \
f'[white]\[{math.ceil(1. * len(sort) / task_repeat) * tasks_per_server}/' \
f'{batch_size} batch tasks][/white]'

# === Row addition ===
row = [txt.format(stats[uid][key]) for _, key, txt, _ in columns]
for j in range(tasks_per_server):
batch_item = ((i // task_repeat) * tasks_per_server + j) % batch_size # repeat task over servers, do not exceed batch_size
task, predictions = batch_predictions[batch_perm[batch_item]]
row += [predictions[uid]]

if i % task_repeat == 0:
table.add_column(task, header_style='not bold', style='', justify='left')

table.add_row(*row)

# === Table print ===
if (i == len(sort) - 1) or (i % task_repeat == task_repeat - 1):
try:
print(table)
except MarkupError as e:
print(e)
else:
if i == len(sort) - 1:
print()


def synergy_table(stats, syn_loss_diff, sort_col, console_width):
r""" Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal)
"""
Expand Down Expand Up @@ -1464,10 +1578,10 @@ def stats_table(stats, sort_col, console_width, title, caption, mark_uids=None):
def synapse_table(name, stats, sort_col, console_width, start_time):
r""" Prints the evaluation of the neuron responses to the validator request
"""

stats_table(stats, sort_col, console_width,
f'[white] \[{name}] responses [/white] | Validator forward', # title
f'[bold]{len([s for s in stats.values() if len(s)])}[/bold]/{len(stats)} (respond/topk) | '
f'[bold]{len([s for s in stats.values() if len(s) and sort_col in s])}[/bold]/'
f'{len(stats)} (respond/topk) | '
f'[bold]Synapse[/bold] | [white]\[{time.time() - start_time:.3g}s][/white]' # caption
)

Expand Down