Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions bittensor/utils/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def process_weights_for_netuid(
"""

logging.debug("process_weights_for_netuid()")
logging.debug(f"weights: {weights}")
logging.debug(f"netuid {netuid}")
logging.debug(f"subtensor: {subtensor}")
logging.debug(f"metagraph: {metagraph}")
Expand All @@ -254,6 +253,48 @@ def process_weights_for_netuid(
if metagraph is None:
metagraph = subtensor.metagraph(netuid)

return process_weights(
uids=uids,
weights=weights,
num_neurons=metagraph.n,
min_allowed_weights=subtensor.min_allowed_weights(netuid=netuid),
max_weight_limit=subtensor.max_weight_limit(netuid=netuid),
exclude_quantile=exclude_quantile,
)


def process_weights(
uids: Union[NDArray[np.int64], "torch.Tensor"],
weights: Union[NDArray[np.float32], "torch.Tensor"],
num_neurons: int,
min_allowed_weights: Optional[int],
max_weight_limit: Optional[float],
exclude_quantile: int = 0,
) -> Union[
tuple["torch.Tensor", "torch.FloatTensor"],
tuple[NDArray[np.int64], NDArray[np.float32]],
]:
"""
Processes weight tensors for a given weights and UID arrays and hyperparams, applying constraints
and normalization based on the subtensor and metagraph data. This function can handle both NumPy arrays and PyTorch
tensors.

Args:
uids (Union[NDArray[np.int64], "torch.Tensor"]): Array of unique identifiers of the neurons.
weights (Union[NDArray[np.float32], "torch.Tensor"]): Array of weights associated with the user IDs.
num_neurons (int): The number of neurons in the network.
min_allowed_weights (Optional[int]): Subnet hyperparam Minimum number of allowed weights.
max_weight_limit (Optional[float]): Subnet hyperparam Maximum weight limit.
exclude_quantile (int): Quantile threshold for excluding lower weights. Defaults to ``0``.

Returns:
Union[tuple["torch.Tensor", "torch.FloatTensor"], tuple[NDArray[np.int64], NDArray[np.float32]]]: tuple
containing the array of user IDs and the corresponding normalized weights. The data type of the return
matches the type of the input weights (NumPy or PyTorch).
"""
logging.debug("process_weights()")
logging.debug(f"weights: {weights}")

# Cast weights to floats.
if use_torch():
if not isinstance(weights, torch.FloatTensor):
Expand All @@ -265,8 +306,6 @@ def process_weights_for_netuid(
# Network configuration parameters from an subtensor.
# These parameters determine the range of acceptable weights for each neuron.
quantile = exclude_quantile / U16_MAX
min_allowed_weights = subtensor.min_allowed_weights(netuid=netuid)
max_weight_limit = subtensor.max_weight_limit(netuid=netuid)
logging.debug(f"quantile: {quantile}")
logging.debug(f"min_allowed_weights: {min_allowed_weights}")
logging.debug(f"max_weight_limit: {max_weight_limit}")
Expand All @@ -280,12 +319,12 @@ def process_weights_for_netuid(
non_zero_weight_uids = uids[non_zero_weight_idx]
non_zero_weights = weights[non_zero_weight_idx]
nzw_size = non_zero_weights.numel() if use_torch() else non_zero_weights.size
if nzw_size == 0 or metagraph.n < min_allowed_weights:
if nzw_size == 0 or num_neurons < min_allowed_weights:
logging.warning("No non-zero weights returning all ones.")
final_weights = (
torch.ones((metagraph.n)).to(metagraph.n) / metagraph.n
torch.ones(num_neurons).to(num_neurons) / num_neurons
if use_torch()
else np.ones((metagraph.n), dtype=np.int64) / metagraph.n
else np.ones(num_neurons, dtype=np.int64) / num_neurons
)
logging.debug(f"final_weights: {final_weights}")
final_weights_count = (
Expand All @@ -303,11 +342,11 @@ def process_weights_for_netuid(
logging.warning(
"No non-zero weights less then min allowed weight, returning all ones."
)
# ( const ): Should this be np.zeros( ( metagraph.n ) ) to reset everyone to build up weight?
# ( const ): Should this be np.zeros( ( num_neurons ) ) to reset everyone to build up weight?
weights = (
torch.ones((metagraph.n)).to(metagraph.n) * 1e-5
torch.ones(num_neurons).to(num_neurons) * 1e-5
if use_torch()
else np.ones((metagraph.n), dtype=np.int64) * 1e-5
else np.ones(num_neurons, dtype=np.int64) * 1e-5
) # creating minimum even non-zero weights
weights[non_zero_weight_idx] += non_zero_weights
logging.debug(f"final_weights: {weights}")
Expand Down
Loading