diff --git a/bittensor/_neuron/text/core_validator/__init__.py b/bittensor/_neuron/text/core_validator/__init__.py
index efef464611..6fcf60bea3 100644
--- a/bittensor/_neuron/text/core_validator/__init__.py
+++ b/bittensor/_neuron/text/core_validator/__init__.py
@@ -36,6 +36,7 @@
 from rich.console import Console
 from rich.style import Style
 from rich.table import Table
+from rich.errors import MarkupError
 from rich.traceback import install
 from typing import List, Tuple, Callable, Dict, Any, Union, Set
@@ -1232,6 +1233,11 @@ def _synergy(first, second, target, ext):
     logger.info(f'{str(synapse)} \t| Shapley synergy values [{time.time() - synergy_start_time:.3g}s]')
 
     if logging:
+        # === Response table ===
+        # Prints the query response table: top prediction probabilities and texts for batch tasks
+        batch_predictions = format_predictions(uids, query_responses, return_ops, inputs, validation_len, index_s)
+        response_table(batch_predictions, stats, sort_col='shapley_values_nxt', console_width=console_width)
+
         # === Synergy table ===
         # Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal)
         synergy_table(stats, syn_loss_diff, 'shapley_values_nxt', console_width)
@@ -1392,6 +1398,114 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
     return syn_loss_diff
 
+
+def format_predictions(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]],
+                       return_ops: List[torch.LongTensor], inputs: torch.FloatTensor,
+                       validation_len: int, index_s: int = 0, number_of_predictions: int = 3) -> List:
+    r""" Format batch task topk predictions for rich table print of query responses.
+    """
+    batch_predictions = []
+    std_tokenizer = bittensor.tokenizer()
+
+    # === Batch iteration ===
+    for batch_item in range(inputs.shape[0]):
+        # === Task formatting ===
+        context = inputs[batch_item][:-validation_len]
+        answer = inputs[batch_item][-validation_len:]
+
+        context = repr(std_tokenizer.decode(context))[1:-1][-30:]  # strip '' and truncate
+        answer = repr(std_tokenizer.decode(answer))[1:-1][:15]  # strip '' and truncate
+
+        task = f"[reverse]{context}[/reverse][bold]{answer}[/bold]"
+
+        # === Prediction formatting ===
+        predictions = {}
+        for index, uid in enumerate(uids.tolist()):
+            if return_ops[index][index_s] == bittensor.proto.ReturnCode.Success:
+                topk_tensor = query_responses[index][index_s]  # [batch_size, (topk + 1), max_len] (prob_k) + floor_prob
+                topk_tokens = topk_tensor[batch_item, :-1, 1:].int()  # [batch_size, topk, max_len - 1] Phrase tokens with ignore_index token for padding.
+                topk_probs = topk_tensor[batch_item, :-1, 0]  # [batch_size, topk] Probabilities for each phrase in topk
+
+                # === Topk iteration ===
+                topk_predictions = ''
+                for i in range(number_of_predictions):
+                    phrase = topk_tokens[i]
+                    phrase = phrase[phrase >= 0]  # strip negative ignore_index = -100
+                    phrase_str = repr(std_tokenizer.decode(phrase))[:15]  # decode, escape and truncate
+
+                    prob = f'{topk_probs[i]:.3f}'
+                    prob = prob[1:] if prob[0] == '0' else prob[:-1]  # remove obvious leading 0
+
+                    topk_predictions += f"[green]{prob}[/green]: {phrase_str} "
+
+                predictions[uid] = topk_predictions[:-1]  # strip trailing space
+
+        batch_predictions += [(task, predictions)]
+
+    return batch_predictions
+
+
+def response_table(batch_predictions: List, stats: Dict, sort_col: str, console_width: int,
+                   task_repeat: int = 4, tasks_per_server: int = 3):
+    r""" Prints the query response table: top prediction probabilities and texts for batch tasks.
+ """ + # === Batch permutation === + batch_size = len(batch_predictions) + if batch_size == 0: + return + batch_perm = torch.randperm(batch_size) # avoid restricting observation to predictable subsets + + # === Column selection === + columns = [c[:] for c in neuron_stats_columns if c[1] in ['uid', sort_col, 'loss_nxt', 'synergy_nxt']] + col_keys = [c[1] for c in columns] + + # === Sort rows === + sort = sorted([(uid, s[sort_col]) for uid, s in stats.items() if sort_col in s], reverse=True, key=lambda _row: _row[1]) + if sort_col in col_keys: + sort_idx = col_keys.index(sort_col) # sort column with key of sort_col + columns[sort_idx][0] += '\u2193' # ↓ downwards arrow (sort) + + for i, (uid, val) in enumerate(sort): + # === New table section === + if i % task_repeat == 0: + table = Table(width=console_width, box=None) + if i == 0: + table.title = f"[white bold] Query responses [/white bold] | " \ + f"[white]context[/white][bold]continuation[/bold] | .prob: 'prediction'" + + for col, _, _, stl in columns: # [Column_name, key_name, format_string, rich_style] + table.add_column(col, style=stl, justify='right') + + # === Last table section === + if i == len(sort) - 1: + table.caption = f'[bold]{len(sort)}[/bold]/{len(stats)} (respond/topk) | ' \ + f'[bold]{tasks_per_server}[/bold] tasks per server | ' \ + f'repeat tasks over [bold]{task_repeat}[/bold] servers ' \ + f'[white]\[{math.ceil(1. * len(sort) / task_repeat) * tasks_per_server}/' \ + f'{batch_size} batch tasks][/white]' + + # === Row addition === + row = [txt.format(stats[uid][key]) for _, key, txt, _ in columns] + for j in range(tasks_per_server): + batch_item = ((i // task_repeat) * tasks_per_server + j) % batch_size # repeat task over servers, do not exceed batch_size + task, predictions = batch_predictions[batch_perm[batch_item]] + row += [predictions[uid]] + + if i % task_repeat == 0: + table.add_column(task, header_style='not bold', style='', justify='left') + + table.add_row(*row) + + # === Table print === + if (i == len(sort) - 1) or (i % task_repeat == task_repeat - 1): + try: + print(table) + except MarkupError as e: + print(e) + else: + if i == len(sort) - 1: + print() + + def synergy_table(stats, syn_loss_diff, sort_col, console_width): r""" Prints the synergy loss diff matrix with pairwise loss reduction due to synergy (original loss on diagonal) """ @@ -1464,10 +1578,10 @@ def stats_table(stats, sort_col, console_width, title, caption, mark_uids=None): def synapse_table(name, stats, sort_col, console_width, start_time): r""" Prints the evaluation of the neuron responses to the validator request """ - stats_table(stats, sort_col, console_width, f'[white] \[{name}] responses [/white] | Validator forward', # title - f'[bold]{len([s for s in stats.values() if len(s)])}[/bold]/{len(stats)} (respond/topk) | ' + f'[bold]{len([s for s in stats.values() if len(s) and sort_col in s])}[/bold]/' + f'{len(stats)} (respond/topk) | ' f'[bold]Synapse[/bold] | [white]\[{time.time() - start_time:.3g}s][/white]' # caption )