Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions bittensor/_neuron/text/core_validator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,19 @@ def run_epoch( self ):
# Each block length lasts blocks_per_epoch blocks.
# This gives us a consistent network wide timer.
# Here we run until blocks_per_epochs have progressed.
self.metagraph_sync() # Reset metagraph.
if self.epoch > 0: # skip first epoch: already synced at start of run
self.metagraph_sync() # Reset metagraph.

self.nucleus.permute_uids = [] # clear nucleus permutation before epoch

epoch_steps = 0
epoch_responsive_uids = set()
epoch_queried_uids = set()
epoch_start_time = time.time()

start_block = self.subtensor.block
while self.subtensor.block < start_block + blocks_per_epoch:
while (self.subtensor.block < start_block + blocks_per_epoch or
len(epoch_queried_uids) < self.metagraph.n): # ensure each UID is queried at least once - assumes nucleus samples without replacement
start_time = time.time()

# === Forward ===
Expand Down Expand Up @@ -434,8 +439,9 @@ def run_epoch( self ):
f'[dim] Epoch {self.epoch}[/dim] | '
f'[bright_green not bold]{len(responsive_uids)}[/bright_green not bold]/'
f'[white]{len(queried_uids)}[/white] '
f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] '
f'[[yellow]{step_time:.3g}[/yellow]s]')
f'[[yellow]{step_time:.3g}[/yellow]s] '
f'[dim white not bold][green]{len(epoch_responsive_uids)}[/green]/'
f'{len(epoch_queried_uids)}[/dim white not bold]')

if self.config.logging.debug or self.config.logging.trace:
# === Print stats update (table) ===
Expand Down Expand Up @@ -485,6 +491,20 @@ def run_epoch( self ):
if self.config.logging.debug or self.config.logging.trace:
self.weights_table(sample_uids, sample_weights) # print weights table

# set weights console message (every epoch)
print(f"[white not bold]{datetime.datetime.now():%Y-%m-%d %H:%M:%S}[/white not bold]{' ' * 4} | "
f"{f'[bright_white]Set weights[/bright_white]'.center(16 + len('[bright_white][/bright_white]'))} | "
f'[bright_green not bold]{len(sample_weights)}[/bright_green not bold] [dim]weights set[/dim] | '
f'[bright_green not bold]{len(epoch_responsive_uids)}[/bright_green not bold]/'
f'[white]{len(epoch_queried_uids)}[/white] '
f'[dim white not bold][green]responsive[/green]/queried[/dim white not bold] '
f'[[yellow]{time.time() - epoch_start_time:.0f}[/yellow]s] | '
f'[dim]weights[/dim] sum:{sample_weights.sum().item():.2g} '
f'[white] max:[bold]{sample_weights.max().item():.4g}[/bold] / '
f'min:[bold]{sample_weights.min().item():.4g}[/bold] [/white] '
f'\[{sample_weights.max().item() / sample_weights.min().item():.1f}:1] '
f'({max_allowed_ratio} allowed)')

self.subtensor.set_weights(
uids=sample_uids.detach().to('cpu'),
weights=sample_weights.detach().to('cpu'),
Expand Down Expand Up @@ -543,6 +563,7 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
zkey = key + '!' # zeroing key
stats.setdefault(zkey, 0.) # initialize zkey val to zero to gradually increase with observations
if key in _stats and not math.isnan(_stats[key]):
responsive_uids += [_uid]
stats[zkey] = (1 - self.alpha) * stats[zkey] + self.alpha * _stats[key]
else:
stats[zkey] = (1 - self.alpha) * stats[zkey] # + self.alpha * 0
Expand All @@ -555,7 +576,6 @@ def neuron_stats_update(self, neuron_stats: Dict[int, Dict[str, Any]]):
updates = 'updates_' + key
if updates in stats:
stats[updates] += 1 # increment number of normal EMA updates made
responsive_uids += [_uid]
else:
stats.setdefault(updates, 1) # add updates fields for new uid entries

Expand Down Expand Up @@ -660,6 +680,7 @@ def __init__( self, config, device, subtensor ):
self.config = config
self.device = device
self.max_n = subtensor.max_n
self.permute_uids = [] # iterable of next UIDs to query, reset to permuted UIDs when empty

tokenizer = bittensor.tokenizer()
self.pad_token = tokenizer(tokenizer.pad_token)['input_ids'][0]
Expand Down Expand Up @@ -702,6 +723,7 @@ def add_args( cls, parser ):
parser.add_argument('--nucleus.noise_multiplier', type=float, help='Standard deviation multipler on weights', default=2 )
parser.add_argument('--nucleus.dendrite_backward', action='store_true', help='Pass backward request to the server side or not', default=False )
parser.add_argument('--nucleus.scaling_law_power', type=float, help='Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.', default=0.5)
parser.add_argument('--nucleus.synergy_scaling_law_power', type=float, help='Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.', default=0.6)

@classmethod
def config ( cls ):
Expand Down Expand Up @@ -794,18 +816,26 @@ def forward(
# Ensure number of queried neurons does not exceed metagraph.n
num_endpoints = min([self.config.nucleus.topk, metagraph.n])

logger.info(f'Forward \t| Routing forward <dim>[{time.time() - start_time:.3g}s]</dim>')
logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}')
request_start_time = time.time()
# === Ensure each UID is queried once ===
# Persist object variable self.permute_uids across forward calls.
# Reset to new permutation of all UIDs once empty.
if len(self.permute_uids) == 0: # no more UIDs to query
self.permute_uids = torch.randperm(metagraph.n) # reset to new permutation of all UIDs

# === Randomly select num_endpoints UIDs ===
random_uids = torch.randperm(metagraph.n)[:num_endpoints]
random_uids = self.permute_uids[:num_endpoints] # newest selection of UIDs to query
self.permute_uids = self.permute_uids[num_endpoints:] # slice out remaining selection

# === Get endpoint information for the selected UIDs ===
# We index into the metagraph's endpoints and return a list of the filtered set of endpoints we wish to query.
# random_endpoints: List[bittensor.endpoints]: endpoint information for filtered uids.
# len(neurons) == self.config.nucleus.topk
random_endpoints = [metagraph.endpoints[uid] for uid in random_uids]
num_endpoints = len(random_endpoints) # in case len(self.permute_uids) < num_endpoints during random_uids select

logger.info(f'Forward \t| Routing forward <dim>[{time.time() - start_time:.3g}s]</dim>')
logger.info(f'Dendrite \t| Request {num_endpoints} x {list(inputs_seq.shape)}')
request_start_time = time.time()

# === Define which synapse we want to use ===
# The synapse defines the task we are sending to the neurons
Expand Down Expand Up @@ -847,8 +877,9 @@ def forward(
# === Prepare validation parameter set ===
console_width = self.config.get('width', None) # console width for rich table displays of synapse measures
validation_params = (random_uids, query_responses, return_ops, times, routing_score,
inputs, val_len, self.loss_fct, self.config.nucleus.scaling_law_power, console_width,
self.config.logging.debug or self.config.logging.trace)
inputs, val_len, self.loss_fct,
self.config.nucleus.scaling_law_power, self.config.nucleus.synergy_scaling_law_power,
console_width, self.config.logging.debug or self.config.logging.trace)

loss = torch.tensor(0.).to(self.device) # to accumulate neuron_loss and routing_loss over synapses
neuron_stats = {} # to gather neuron synapse validation measures and statistics
Expand Down Expand Up @@ -876,7 +907,8 @@ def scaling_law_loss_to_params(loss):

def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
scaling_law_power: float, synergy_scaling_law_power: float,
console_width: int, logging, synapse: 'bittensor.TextCausalLM' = None, index_s: int = 0
) -> Tuple[torch.FloatTensor, Dict]:
r"""
Expand All @@ -901,6 +933,8 @@ def textcausallm(uids: torch.Tensor, query_responses: List[List[torch.FloatTenso
CrossEntropy loss function to use.
scaling_law_power (:obj:`float`, `required`):
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
synergy_scaling_law_power (:obj:`float`, `required`):
Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
Expand Down Expand Up @@ -957,9 +991,9 @@ def _synergy(first, second, target, _ext):
synergy_start_time = time.time()

syn_loss_diff = shapley_synergy(stats, _synergy, ext='', target=inputs_seq[:, 1:],
scaling_law_power=scaling_law_power)
scaling_law_power=synergy_scaling_law_power)
syn_loss_diff_val = shapley_synergy(stats, _synergy, ext='_val', target=inputs_val,
scaling_law_power=scaling_law_power)
scaling_law_power=synergy_scaling_law_power)

# === Shapley value combination ===
# Combine base values with synergy approximation to get final Shapley values.
Expand Down Expand Up @@ -998,7 +1032,8 @@ def _synergy(first, second, target, _ext):

def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatTensor]], return_ops: List[torch.LongTensor],
times: List[torch.FloatTensor], routing_score: torch.FloatTensor,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable, scaling_law_power: float,
inputs: torch.FloatTensor, validation_len: int, loss_fct: Callable,
scaling_law_power: float, synergy_scaling_law_power: float,
console_width: int, logging, synapse: 'bittensor.TextCausalLMNext' = None, index_s: int = 0
) -> Tuple[torch.FloatTensor, Dict]:
r"""
Expand All @@ -1023,6 +1058,8 @@ def textcausallmnext(uids: torch.Tensor, query_responses: List[List[torch.FloatT
CrossEntropy loss function to use.
scaling_law_power (:obj:`float`, `required`):
Power for modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
synergy_scaling_law_power (:obj:`float`, `required`):
Power for synergy modified scaling law, powered down to improve dynamic range, e.g. 3 → 6 nats for 0.5.
console_width (:obj:`int`, `required`):
Config console width for table print.
logging (:obj:`bool`, `required`):
Expand Down Expand Up @@ -1075,7 +1112,7 @@ def _synergy(first, second, target, ext):

synergy_start_time = time.time()

syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=scaling_law_power)
syn_loss_diff = shapley_synergy(stats, _synergy, '_nxt', scaling_law_power=synergy_scaling_law_power)

# === Shapley value combination ===
# Combine base values with synergy approximation to get final Shapley values.
Expand Down Expand Up @@ -1212,6 +1249,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
# Synergy = measured performance above expected performance
# Measured in effective number of model parameters, just like base Shapley values.
syn_loss_diff = {} # expected_loss - measured_loss (where > 0)
responsives = [uid for uid, stat in stats.items() if 'loss' + ext in stat]
for _first, first in stats.items():
if 'loss' + ext not in first:
continue
Expand All @@ -1229,6 +1267,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
measured_loss = synergy(first, second, target, ext)

loss_diff_share = torch.clamp(expected_loss - measured_loss, 0) / 2 # record direct loss diff
loss_diff_share /= len(responsives) # average over responsives
first['synergy_loss_diff' + ext] += loss_diff_share
second['synergy_loss_diff' + ext] += loss_diff_share

Expand All @@ -1244,6 +1283,7 @@ def shapley_synergy(stats: Dict, synergy: Callable, ext: str, target: torch.Tens
pow_expected_params = torch.pow(expected_params, scaling_law_power)

synergy_share = torch.clamp(pow_measured_params - pow_expected_params, 0) / 2
synergy_share /= len(responsives) # average over responsives
first['synergy' + ext] += synergy_share # share synergy amongst coalition members
second['synergy' + ext] += synergy_share

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pytest-rerunfailures
coveralls
pytest-cov
pyyaml
rich
rich>=12.5.1
retry
requests>=2.25.0
scalecodec>=1.0.35
Expand Down