diff --git a/bittensor/_neuron/text/core_validator/__init__.py b/bittensor/_neuron/text/core_validator/__init__.py
index 4329f9265b..ab9b86acf3 100644
--- a/bittensor/_neuron/text/core_validator/__init__.py
+++ b/bittensor/_neuron/text/core_validator/__init__.py
@@ -365,14 +365,19 @@ def run_epoch( self ):
# Each block length lasts blocks_per_epoch blocks.
# This gives us a consistent network wide timer.
# Here we run until blocks_per_epochs have progressed.
- self.metagraph_sync() # Reset metagraph.
+ if self.epoch > 0: # skip first epoch: already synced at start of run
+ self.metagraph_sync() # Reset metagraph.
+
+ self.nucleus.permute_uids = [] # clear nucleus permutation before epoch
epoch_steps = 0
epoch_responsive_uids = set()
epoch_queried_uids = set()
+ epoch_start_time = time.time()
start_block = self.subtensor.block
- while self.subtensor.block < start_block + blocks_per_epoch:
+ while (self.subtensor.block < start_block + blocks_per_epoch or
+ len(epoch_queried_uids) < self.metagraph.n): # ensure each UID is queried at least once (assumes the nucleus samples without replacement)
start_time = time.time()
# === Forward ===
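The new loop condition above keeps the epoch running until the block budget is spent and every UID has been queried at least once. A minimal sketch of that control flow, using hypothetical `chain`, `n_uids`, and `query_batch` stand-ins for the subtensor, metagraph, and nucleus (not the validator's actual API):

```python
def run_epoch_loop(chain, n_uids: int, query_batch, blocks_per_epoch: int) -> set:
    """Sketch: exit only after blocks_per_epoch blocks pass AND every UID was queried."""
    epoch_queried_uids = set()
    start_block = chain.block
    while (chain.block < start_block + blocks_per_epoch or
           len(epoch_queried_uids) < n_uids):
        uids = query_batch()  # assumed to sample UIDs without replacement
        epoch_queried_uids.update(int(u) for u in uids)
    return epoch_queried_uids
```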
@@ -434,8 +439,9 @@ def run_epoch( self ):
f'[dim] Epoch {self.epoch}[/dim] | '
f'[bright_green not bold]{len(responsive_uids)}[/bright_green not bold]/'
f'[white]{len(queried_uids)}[/white] '
- f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] '
- f'[[yellow]{step_time:.3g}[/yellow]s]')
+ f'[[yellow]{step_time:.3g}[/yellow]s] '
+ f'[dim white not bold][green]{len(epoch_responsive_uids)}[/green]/'
+ f'{len(epoch_queried_uids)}[/dim white not bold]')
if self.config.logging.debug or self.config.logging.trace:
# === Print stats update (table) ===
@@ -485,6 +491,20 @@ def run_epoch( self ):
if self.config.logging.debug or self.config.logging.trace:
self.weights_table(sample_uids, sample_weights) # print weights table
+ # === Set weights console message (every epoch) ===
+ print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | "
+ f"{f'[bright_white]Set weights[/bright_white]'.center(16 + len('[bright_white][/bright_white]'))} | "
+ f'[bright_green not bold]{len(sample_weights)}[/bright_green not bold] [dim]weights set[/dim] | '
+ f'[bright_green not bold]{len(epoch_responsive_uids)}[/bright_green not bold]/'
+ f'[white]{len(epoch_queried_uids)}[/white] '
+ f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] '
+ f'[[yellow]{time.time() - epoch_start_time:.0f}[/yellow]s] | '
+ f'[dim]weights[/dim] sum:{sample_weights.sum().item():.2g} '
+ f'[white] max:[bold]{sample_weights.max().item():.4g}[/bold] / '
+ f'min:[bold]{sample_weights.min().item():.4g}[/bold] [/white] '
+ f'\[{sample_weights.max().item() / sample_weights.min().item():.1f}:1] '
+ f'({max_allowed_ratio} allowed)')
+
self.subtensor.set_weights(
uids=sample_uids.detach().to('cpu'),
weights=sample_weights.detach().to('cpu'),
@@ -543,6 +563,7 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
zkey = key + '!' # zeroing key
stats.setdefault(zkey, 0.) # initialize zkey val to zero to gradually increase with observations
if key in _stats and not math.isnan(_stats[key]):
+ responsive_uids += [_uid]
stats[zkey] = (1 - self.alpha) * stats[zkey] + self.alpha * _stats[key]
else:
stats[zkey] = (1 - self.alpha) * stats[zkey] # + self.alpha * 0
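The move above ties `responsive_uids` to the zeroing-key ('!') EMA update, which decays toward zero whenever a neuron fails to return a valid value and toward the new sample otherwise. A minimal sketch of that EMA rule, with an illustrative `alpha` and key rather than the validator's full stats machinery:

```python
def zkey_ema_update(stats: dict, key: str, sample: float = None, alpha: float = 0.1):
    """Zeroing-key EMA sketch: decay toward 0 on a miss, toward the sample on a hit."""
    zkey = key + '!'
    stats.setdefault(zkey, 0.)
    if sample is not None:  # responsive: a valid (non-NaN) measurement arrived
        stats[zkey] = (1 - alpha) * stats[zkey] + alpha * sample
    else:                   # unresponsive: decay toward zero
        stats[zkey] = (1 - alpha) * stats[zkey]

stats = {}
zkey_ema_update(stats, 'shapley_values_nxt', sample=1.0)  # hit: 0.0 -> 0.1
zkey_ema_update(stats, 'shapley_values_nxt')              # miss: 0.1 -> 0.09
print(stats)
```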
@@ -555,7 +576,6 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
updates = 'updates_' + key
if updates in stats:
stats[updates] += 1 # increment number of normal EMA updates made
- responsive_uids += [_uid]
else:
stats.setdefault(updates, 1) # add updates fields for new uid entries
@@ -660,6 +680,7 @@ def __init__( self, config, device, subtensor ):
self.config = config
self.device = device
self.max_n = subtensor.max_n
+ self.permute_uids = [] # iterable of next UIDs to query, reset to permuted UIDs when empty
tokenizer = bittensor.tokenizer()
self.pad_token = tokenizer(tokenizer.pad_token)['input_ids'][0]
@@ -702,6 +723,7 @@ def add_args( cls, parser ):
parser.add_argument('--nucleus.noise_multiplier', type=float, help='Standard deviation multiplier on weights', default=2 )
parser.add_argument('--nucleus.dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False )
parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.', default=0.5)
+ parser.add_argument('--nucleus.synergy_scaling_law_power', type=float, help='Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.', default=0.6)
@classmethod
def config ( cls ):
@@ -794,18 +816,26 @@ def forward(
# Ensure number of queried neurons does not exceed metagraph.n
num_endpoints = min([self.config.nucleus.topk, metagraph.n])
- logger.info(f'Forward \t| Routing forward [{time.time() - start_time:.3g}s]')
- logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}')
- request_start_time = time.time()
+ # === Ensure each UID is queried once ===
+ # self.permute_uids persists across forward calls, and is reset
+ # to a new permutation of all UIDs once it has been fully consumed.
+ if len(self.permute_uids) == 0: # no more UIDs to query
+ self.permute_uids = torch.randperm(metagraph.n) # reset to new permutation of all UIDs
# === Randomly select num_endpoints UIDs ===
- random_uids = torch.randperm(metagraph.n)[:num_endpoints]
+ random_uids = self.permute_uids[:num_endpoints] # take the next slice of the permutation to query
+ self.permute_uids = self.permute_uids[num_endpoints:] # keep the remaining UIDs for subsequent forward calls
# === Get endpoint information for the selected UIDs ===
# We index into the metagraph's endpoints and return a list of the filtered set of endpoints we wish to query.
# random_endpoints: List[bittensor.endpoints]: endpoint information for filtered uids.
# len(neurons) == self.config.nucleus.topk
random_endpoints = [metagraph.endpoints[uid] for uid in random_uids]
+ num_endpoints = len(random_endpoints) # may be fewer than requested if self.permute_uids had fewer UIDs remaining
+
+ logger.info(f'Forward \t| Routing forward [{time.time() - start_time:.3g}s]')
+ logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}')
+ request_start_time = time.time()
# === Define which synapse we want to use ===
# The synapse defines the task we are sending to the neurons
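The permutation bookkeeping above implements sampling without replacement across forward calls: a shuffled tensor of all UIDs is consumed in `topk`-sized slices and reshuffled only once exhausted, so every UID is queried exactly once per pass. A self-contained sketch of the pattern (class and method names here are illustrative, not the validator's API):

```python
import torch

class PermutedUIDSampler:
    """Consume a random permutation of range(n) in fixed-size slices,
    reshuffling only when empty, so each UID appears once per pass."""

    def __init__(self, n: int):
        self.n = n
        self.permute_uids = torch.tensor([], dtype=torch.long)

    def next_batch(self, k: int) -> torch.Tensor:
        if len(self.permute_uids) == 0:            # permutation exhausted
            self.permute_uids = torch.randperm(self.n)
        batch = self.permute_uids[:k]              # may be shorter than k at the tail
        self.permute_uids = self.permute_uids[k:]  # keep the rest for later calls
        return batch

sampler = PermutedUIDSampler(n=10)
print([sampler.next_batch(4).tolist() for _ in range(3)])  # e.g. sizes 4, 4, 2: all 10 UIDs once
```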
@@ -847,8 +877,9 @@ def forward(
# === Prepare validation parameter set ===
console_width = self.config.get('width', None) # console width for rich table displays of synapse measures
validation_params = (random_uids, query_responses, return_ops, times, routing_score,
- inputs, val_len, self.loss_fct, self.config.nucleus.scaling_law_power, console_width,
- self.config.logging.debug or self.config.logging.trace)
+ inputs, val_len, self.loss_fct,
+ self.config.nucleus.scaling_law_power, self.config.nucleus.synergy_scaling_law_power,
+ console_width, self.config.logging.debug or self.config.logging.trace)
loss = torch.tensor(0.).to(self.device) # to accumulate neuron_loss and routing_loss over synapses
neuron_stats = {} # to gather neuron synapse validation measures and statistics
@@ -876,7 +907,8 @@ def scaling_law_loss_to_params(loss):
def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
- inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float,
+ inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
+ scaling_law_power: float, synergy_scaling_law_power: float,
console_width: int, logging, synapse: 'bittensor.TextCausalLM' = None, index_s: int = 0
) -> Tuple[torch.FloatTensor, Dict]:
r"""
@@ -901,6 +933,8 @@ def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTenso
CrossEntropy loss function to use.
scaling_law_power (:obj:`float`, `required`):
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
+ synergy_scaling_law_power (:obj:`float`, `required`):
+ Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
@@ -957,9 +991,9 @@ def _synergy(first, second, target, _ext):
synergy_start_time = time.time()
syn_loss_diff = shapley_synergy(stats, _synergy, ext='', target=inputs_seq[:, 1:],
- scaling_law_power=scaling_law_power)
+ scaling_law_power=synergy_scaling_law_power)
syn_loss_diff_val = shapley_synergy(stats, _synergy, ext='_val', target=inputs_val,
- scaling_law_power=scaling_law_power)
+ scaling_law_power=synergy_scaling_law_power)
# === Shapley value combination ===
# Combine base values with synergy approximation to get final Shapley values.
@@ -998,7 +1032,8 @@ def _synergy(first, second, target, _ext):
def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
- inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float,
+ inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
+ scaling_law_power: float, synergy_scaling_law_power: float,
console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0
) -> Tuple[torch.FloatTensor, Dict]:
r"""
@@ -1023,6 +1058,8 @@ def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatT
CrossEntropy loss function to use.
scaling_law_power (:obj:`float`, `required`):
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
+ synergy_scaling_law_power (:obj:`float`, `required`):
+ Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
@@ -1075,7 +1112,7 @@ def _synergy(first, second, target, ext):
synergy_start_time = time.time()
- syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=scaling_law_power)
+ syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=synergy_scaling_law_power)
# === Shapley value combination ===
# Combine base values with synergy approximation to get final Shapley values.
@@ -1212,6 +1249,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
# Synergy = measured performance above expected performance
# Measured in effective number of model parameters, just like base Shapley values.
syn_loss_diff = {} # expected_loss - measured_loss (where > 0)
+ responsives = [uid for uid, stat in stats.items() if 'loss' + ext in stat]
for _first, first in stats.items():
if 'loss' + ext not in first:
continue
@@ -1229,6 +1267,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
measured_loss = synergy(first, second, target, ext)
loss_diff_share = torch.clamp(expected_loss - measured_loss, 0) / 2 # record direct loss diff
+ loss_diff_share /= len(responsives) # average over responsives
first['synergy_loss_diff' + ext] += loss_diff_share
second['synergy_loss_diff' + ext] += loss_diff_share
@@ -1244,6 +1283,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
pow_expected_params = torch.pow(expected_params, scaling_law_power)
synergy_share = torch.clamp(pow_measured_params - pow_expected_params, 0) / 2
+ synergy_share /= len(responsives) # average over responsives
first['synergy' + ext] += synergy_share # share synergy amongst coalition members
second['synergy' + ext] += synergy_share
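The two `/= len(responsives)` lines added above normalize each pairwise synergy share by the number of responsive UIDs, so a UID's accumulated synergy no longer grows simply because it appears in more pairs. A toy sketch of that normalization, where the losses and the `min` stand-in for the measured coalition loss are illustrative:

```python
from itertools import combinations

losses = {0: 2.0, 1: 3.0, 2: 2.5}  # hypothetical per-UID validation losses
responsives = list(losses)          # UIDs that actually returned a loss
synergy_loss_diff = {uid: 0.0 for uid in responsives}

for a, b in combinations(responsives, 2):
    expected_loss = (losses[a] + losses[b]) / 2        # expected coalition performance
    measured_loss = min(losses[a], losses[b])          # illustrative measured coalition loss
    share = max(expected_loss - measured_loss, 0) / 2  # split the improvement between the pair
    share /= len(responsives)                          # the patch's averaging step
    synergy_loss_diff[a] += share
    synergy_loss_diff[b] += share

print(synergy_loss_diff)  # each share is averaged over the responsive set
```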
diff --git a/requirements.txt b/requirements.txt
index 1c7f9b8fd6..c96f9a27e6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,7 +32,7 @@ pytest-rerunfailures
coveralls
pytest-cov
pyyaml
-rich
+rich>=12.5.1
retry
requests>=2.25.0
scalecodec>=1.0.35