diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index f15912bff9..2c31ccf6b1 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -552,6 +552,7 @@ def _process_root_weights( """ data_array = [] n_subnets = subtensor.get_total_subnets() + subnets = subtensor.get_subnets() for item in data: if len(item) == 0: data_array.append(torch.zeros(n_subnets)) @@ -559,8 +560,8 @@ def _process_root_weights( uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor data_array.append( - bittensor.utils.weight_utils.convert_weight_uids_and_vals_to_tensor( - n_subnets, uids, values + bittensor.utils.weight_utils.convert_root_weight_uids_and_vals_to_tensor( + n_subnets, uids, values, subnets ) ) diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index cc80b975c1..d07caf22aa 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -100,6 +100,39 @@ def convert_weight_uids_and_vals_to_tensor( return row_weights +def convert_root_weight_uids_and_vals_to_tensor( + n: int, uids: List[int], weights: List[int], subnets: List[int] +) -> "torch.FloatTensor": + r"""Converts root weights and uids from chain representation into a torch tensor (inverse operation from convert_weights_and_uids_for_emit) + Args: + n: int: + number of neurons on network. + uids (:obj:`List[int],`): + Tensor of uids as destinations for passed weights. + weights (:obj:`List[int],`): + Tensor of weights. + subnets (:obj:`List[int],`): + list of subnets on the network + Returns: + row_weights ( torch.FloatTensor ): + Converted row weights. + """ + + row_weights = torch.zeros([n], dtype=torch.float32) + for uid_j, wij in list(zip(uids, weights)): + if uid_j in subnets: + index_s = subnets.index(uid_j) + else: + raise Exception("Incorrect Subnet {uid_j} in {subnets}") + row_weights[index_s] = float( + wij + ) # assumes max-upscaled values (w_max = U16_MAX). + row_sum = row_weights.sum() + if row_sum > 0: + row_weights /= row_sum # normalize + return row_weights + + def convert_bond_uids_and_vals_to_tensor( n: int, uids: List[int], bonds: List[int] ) -> "torch.LongTensor":