diff --git a/openvalidators/forward.py b/openvalidators/forward.py index 9fa8397..376fc8c 100644 --- a/openvalidators/forward.py +++ b/openvalidators/forward.py @@ -42,16 +42,23 @@ def get_random_uids(self, k: int, exclude: List[int] = None) -> torch.LongTensor If `k` is larger than the number of available `uids`, set `k` to the number of available `uids`. """ candidate_uids = [] + avail_uids = [] for uid in range(self.metagraph.n.item()): uid_is_available = check_uid_availability(self.metagraph, uid, self.config.neuron.vpermit_tao_limit) uid_is_not_excluded = exclude is None or uid not in exclude - if uid_is_available and uid_is_not_excluded: - candidate_uids.append(uid) - - available_uids = torch.tensor(candidate_uids, dtype=torch.int64).to(self.device) - uids = torch.tensor(random.sample(available_uids.tolist(), k), dtype=torch.int64) + if uid_is_available: + avail_uids.append(uid) + if uid_is_not_excluded: + candidate_uids.append(uid) + + # Check if candidate_uids contains enough uids for querying; if not, grab all available uids + available_uids = candidate_uids + if len(candidate_uids) < k: + available_uids += random.sample([uid for uid in avail_uids if uid not in candidate_uids], k-len(candidate_uids)) + + uids = torch.tensor(random.sample(available_uids, k), dtype=torch.int64) return uids diff --git a/openvalidators/gating.py b/openvalidators/gating.py index 482ad0e..2ca79b6 100644 --- a/openvalidators/gating.py +++ b/openvalidators/gating.py @@ -52,8 +52,7 @@ def add_args(cls, parser: argparse.ArgumentParser): parser.add_argument( "--gating.num_uids", type=int, - default=1024, - help="Number of uids to gate on", + help="Number of uids to gate on. 
Default is pulled from subtensor directly", ) parser.add_argument( "--gating.learning_rate", @@ -137,7 +136,7 @@ def __init__( config = GatingModel.config() if model_name is not None: config.gating.model_name = model_name - config.gating.num_uids = num_uids if num_uids is not None else metagraph.n + config.gating.num_uids = num_uids if num_uids is not None else config.gating.num_uids self.config = config self.num_uids = config.gating.num_uids self.device = torch.device(self.config.neuron.device) @@ -228,7 +227,7 @@ def __init__( config = SentenceEmbedGatingModel.config() if model_name is not None: config.gating.model_name = model_name - config.gating.num_uids = num_uids if num_uids is not None else metagraph.n + config.gating.num_uids = num_uids if num_uids is not None else config.gating.num_uids self.config = config self.num_uids = config.gating.num_uids self.device = torch.device(self.config.neuron.device) diff --git a/openvalidators/neuron.py b/openvalidators/neuron.py index 227d2c8..f1a2405 100644 --- a/openvalidators/neuron.py +++ b/openvalidators/neuron.py @@ -108,6 +108,9 @@ def __init__(self): # Init the gating model which learns which miners to select for each query. bt.logging.debug("loading", "gating_model") + if not self.config.gating.num_uids: + self.config.gating.num_uids = self.subtensor.max_n(self.config.netuid) + if self.config.neuron.mock_gating_model: self.gating_model = MockGatingModel(self.metagraph.n.item()) elif self.config.neuron.use_custom_gating_model: @@ -116,7 +119,7 @@ def __init__(self): self.gating_model = GatingModel(metagraph=self.metagraph, config=self.config).to(self.device) bt.logging.debug(str(self.gating_model)) - # Dendrite pool for querying the network during training. + # Dendrite pool for querying the network during training. 
bt.logging.debug("loading", "dendrite_pool") if self.config.neuron.mock_dendrite_pool: self.dendrite_pool = MockDendritePool() diff --git a/openvalidators/reward/diversity.py b/openvalidators/reward/diversity.py index 52eebd5..185c53b 100644 --- a/openvalidators/reward/diversity.py +++ b/openvalidators/reward/diversity.py @@ -90,7 +90,7 @@ def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch # Check if completions are empty, return 0 if so if len(completions) == 0: - return torch.tensor([]) + return torch.tensor([]).to(self.device) # Get embeddings for all completions. embeddings = self.get_embeddings( completions ) diff --git a/openvalidators/reward/reward.py b/openvalidators/reward/reward.py index 537d94b..570b7bf 100644 --- a/openvalidators/reward/reward.py +++ b/openvalidators/reward/reward.py @@ -51,7 +51,7 @@ def normalize_rewards( self, rewards: torch.FloatTensor ) -> torch.FloatTensor: - This function uses Welford's online algorithm to update the mean and variance. - It standardizes the reward values using the updated mean and variance. - It then scales the standardized values to the 0-1 range using the error function (erf) as a CDF. - """ + """ # Get the number of rewards (successful responses). new_count = rewards.numel() @@ -88,6 +88,7 @@ def apply( self, prompt: str, responses: List[ bt.DendriteCall ], name: str) -> """ Applies the reward model across each call. Unsuccessful responses are zeroed. """ # Get indices of correctly responding calls. + successful_completions_indices: List[int] = [ idx for idx, resp in enumerate(responses) if resp.is_success ] # Get all completions from responding calls. 
diff --git a/openvalidators/utils.py b/openvalidators/utils.py index 341fbde..5eca03e 100644 --- a/openvalidators/utils.py +++ b/openvalidators/utils.py @@ -209,6 +209,10 @@ def save_state(self): self.wandb.log_artifact(model_artifact) bt.logging.success(prefix="Saved gating model", sufix=f"{gating_model_file_path}") + + # empty CUDA cache + torch.cuda.empty_cache() + except Exception as e: bt.logging.warning(f"Failed to save model with error: {e}")