diff --git a/bittensor/_neuron/text/core_validator/__init__.py b/bittensor/_neuron/text/core_validator/__init__.py index 4329f9265b..ab9b86acf3 100644 --- a/bittensor/_neuron/text/core_validator/__init__.py +++ b/bittensor/_neuron/text/core_validator/__init__.py @@ -365,14 +365,19 @@ def run_epoch( self ): # Each block length lasts blocks_per_epoch blocks. # This gives us a consistent network wide timer. # Here we run until blocks_per_epochs have progressed. - self.metagraph_sync() # Reset metagraph. + if self.epoch > 0: # skip first epoch: already synced at start of run + self.metagraph_sync() # Reset metagraph. + + self.nucleus.permute_uids = [] # clear nucleus permutation before epoch epoch_steps = 0 epoch_responsive_uids = set() epoch_queried_uids = set() + epoch_start_time = time.time() start_block = self.subtensor.block - while self.subtensor.block < start_block + blocks_per_epoch: + while (self.subtensor.block < start_block + blocks_per_epoch or + len(epoch_queried_uids) < self.metagraph.n): # ensure each UID is queried at least once - assumes nucleus samples without replacement start_time = time.time() # === Forward === @@ -434,8 +439,9 @@ def run_epoch( self ): f'[dim] Epoch {self.epoch}[/dim] | ' f'[bright_green not bold]{len(responsive_uids)}[/bright_green not bold]/' f'[white]{len(queried_uids)}[/white] ' - f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] ' - f'[[yellow]{step_time:.3g}[/yellow]s]') + f'[[yellow]{step_time:.3g}[/yellow]s] ' + f'[dim white not bold][green]{len(epoch_responsive_uids)}[/green]/' + f'{len(epoch_queried_uids)}[/dim white not bold]') if self.config.logging.debug or self.config.logging.trace: # === Print stats update (table) === @@ -485,6 +491,20 @@ def run_epoch( self ): if self.config.logging.debug or self.config.logging.trace: self.weights_table(sample_uids, sample_weights) # print weights table + # set weights console message (every epoch) + print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | " + f"{f'[bright_white]Set weights[/bright_white]'.center(16 + len('[bright_white][/bright_white]'))} | " + f'[bright_green not bold]{len(sample_weights)}[/bright_green not bold] [dim]weights set[/dim] | ' + f'[bright_green not bold]{len(epoch_responsive_uids)}[/bright_green not bold]/' + f'[white]{len(epoch_queried_uids)}[/white] ' + f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] ' + f'[[yellow]{time.time() - epoch_start_time:.0f}[/yellow]s] | ' + f'[dim]weights[/dim] sum:{sample_weights.sum().item():.2g} ' + f'[white] max:[bold]{sample_weights.max().item():.4g}[/bold] / ' + f'min:[bold]{sample_weights.min().item():.4g}[/bold] [/white] ' + f'\[{sample_weights.max().item() / sample_weights.min().item():.1f}:1] ' + f'({max_allowed_ratio} allowed)') + self.subtensor.set_weights( uids=sample_uids.detach().to('cpu'), weights=sample_weights.detach().to('cpu'), @@ -543,6 +563,7 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]): zkey = key + '!' # zeroing key stats.setdefault(zkey, 0.) # initialize zkey val to zero to gradually increase with observations if key in _stats and not math.isnan(_stats[key]): + responsive_uids += [_uid] stats[zkey] = (1 - self.alpha) * stats[zkey] + self.alpha * _stats[key] else: stats[zkey] = (1 - self.alpha) * stats[zkey] # + self.alpha * 0 @@ -555,7 +576,6 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]): updates = 'updates_' + key if updates in stats: stats[updates] += 1 # increment number of normal EMA updates made - responsive_uids += [_uid] else: stats.setdefault(updates, 1) # add updates fields for new uid entries @@ -660,6 +680,7 @@ def __init__( self, config, device, subtensor ): self.config = config self.device = device self.max_n = subtensor.max_n + self.permute_uids = [] # iterable of next UIDs to query, reset to permuted UIDs when empty tokenizer = bittensor.tokenizer() self.pad_token = tokenizer(tokenizer.pad_token)['input_ids'][0] @@ -702,6 +723,7 @@ def add_args( cls, parser ): parser.add_argument('--nucleus.noise_multiplier', type=float, help='Standard deviation multipler on weights', default=2 ) parser.add_argument('--nucleus.dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False ) parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.', default=0.5) + parser.add_argument('--nucleus.synergy_scaling_law_power', type=float, help='Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.', default=0.6) @classmethod def config ( cls ): @@ -794,18 +816,26 @@ def forward( # Ensure number of queried neurons does not exceed metagraph.n num_endpoints = min([self.config.nucleus.topk, metagraph.n]) - logger.info(f'Forward \t| Routing forward [{time.time() - start_time:.3g}s]') - logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}') - request_start_time = time.time() + # === Ensure each UID is queried once === + # Persist object variable self.permute_uids across forward calls. + # Reset to new permutation of all UIDs once empty. + if len(self.permute_uids) == 0: # no more UIDs to query + self.permute_uids = torch.randperm(metagraph.n) # reset to new permutation of all UIDs # === Randomly select num_endpoints UIDs === - random_uids = torch.randperm(metagraph.n)[:num_endpoints] + random_uids = self.permute_uids[:num_endpoints] # newest selection of UIDs to query + self.permute_uids = self.permute_uids[num_endpoints:] # slice out remaining selection # === Get endpoint information for the selected UIDs === # We index into the metagraph's endpoints and return a list of the filtered set of endpoints we wish to query. # random_endpoints: List[bittensor.endpoints]: endpoint information for filtered uids. # len(neurons) == self.config.nucleus.topk random_endpoints = [metagraph.endpoints[uid] for uid in random_uids] + num_endpoints = len(random_endpoints) # in case len(self.permute_uids) < num_endpoints during random_uids select + + logger.info(f'Forward \t| Routing forward [{time.time() - start_time:.3g}s]') + logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}') + request_start_time = time.time() # === Define which synapse we want to use === # The synapse defines the task we are sending to the neurons @@ -847,8 +877,9 @@ def forward( # === Prepare validation parameter set === console_width = self.config.get('width', None) # console width for rich table displays of synapse measures validation_params = (random_uids, query_responses, return_ops, times, routing_score, - inputs, val_len, self.loss_fct, self.config.nucleus.scaling_law_power, console_width, - self.config.logging.debug or self.config.logging.trace) + inputs, val_len, self.loss_fct, + self.config.nucleus.scaling_law_power, self.config.nucleus.synergy_scaling_law_power, + console_width, self.config.logging.debug or self.config.logging.trace) loss = torch.tensor(0.).to(self.device) # to accumulate neuron_loss and routing_loss over synapses neuron_stats = {} # to gather neuron synapse validation measures and statistics @@ -876,7 +907,8 @@ def scaling_law_loss_to_params(loss): def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor], times: List[torch.FloatTensor], routing_score: torch.FloatTensor, - inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float, + inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, + scaling_law_power: float, synergy_scaling_law_power: float, console_width: int, logging, synapse: 'bittensor.TextCausalLM' = None, index_s: int = 0 ) -> Tuple[torch.FloatTensor, Dict]: r""" @@ -901,6 +933,8 @@ def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTenso CrossEntropy loss function to use. scaling_law_power (:obj:`float`, `required`): Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. + synergy_scaling_law_power (:obj:`float`, `required`): + Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. console_width (:obj:`int`, `required`): Config console width for table print. logging (:obj:`bool`, `required`): @@ -957,9 +991,9 @@ def _synergy(first, second, target, _ext): synergy_start_time = time.time() syn_loss_diff = shapley_synergy(stats, _synergy, ext='', target=inputs_seq[:, 1:], - scaling_law_power=scaling_law_power) + scaling_law_power=synergy_scaling_law_power) syn_loss_diff_val = shapley_synergy(stats, _synergy, ext='_val', target=inputs_val, - scaling_law_power=scaling_law_power) + scaling_law_power=synergy_scaling_law_power) # === Shapley value combination === # Combine base values with synergy approximation to get final Shapley values. @@ -998,7 +1032,8 @@ def _synergy(first, second, target, _ext): def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor], times: List[torch.FloatTensor], routing_score: torch.FloatTensor, - inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float, + inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, + scaling_law_power: float, synergy_scaling_law_power: float, console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0 ) -> Tuple[torch.FloatTensor, Dict]: r""" @@ -1023,6 +1058,8 @@ def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatT CrossEntropy loss function to use. scaling_law_power (:obj:`float`, `required`): Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. + synergy_scaling_law_power (:obj:`float`, `required`): + Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5. console_width (:obj:`int`, `required`): Config console width for table print. logging (:obj:`bool`, `required`): @@ -1075,7 +1112,7 @@ def _synergy(first, second, target, ext): synergy_start_time = time.time() - syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=scaling_law_power) + syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=synergy_scaling_law_power) # === Shapley value combination === # Combine base values with synergy approximation to get final Shapley values. @@ -1212,6 +1249,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens # Synergy = measured performance above expected performance # Measured in effective number of model parameters, just like base Shapley values. syn_loss_diff = {} # expected_loss - measured_loss (where > 0) + responsives = [uid for uid, stat in stats.items() if 'loss' + ext in stat] for _first, first in stats.items(): if 'loss' + ext not in first: continue @@ -1229,6 +1267,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens measured_loss = synergy(first, second, target, ext) loss_diff_share = torch.clamp(expected_loss - measured_loss, 0) / 2 # record direct loss diff + loss_diff_share /= len(responsives) # average over responsives first['synergy_loss_diff' + ext] += loss_diff_share second['synergy_loss_diff' + ext] += loss_diff_share @@ -1244,6 +1283,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens pow_expected_params = torch.pow(expected_params, scaling_law_power) synergy_share = torch.clamp(pow_measured_params - pow_expected_params, 0) / 2 + synergy_share /= len(responsives) # average over responsives first['synergy' + ext] += synergy_share # share synergy amongst coalition members second['synergy' + ext] += synergy_share diff --git a/requirements.txt b/requirements.txt index 1c7f9b8fd6..c96f9a27e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ pytest-rerunfailures coveralls pytest-cov pyyaml -rich +rich>=12.5.1 retry requests>=2.25.0 scalecodec>=1.0.35