diff --git a/openvalidators/neuron.py b/openvalidators/neuron.py index 6ba439b..29e0a1c 100644 --- a/openvalidators/neuron.py +++ b/openvalidators/neuron.py @@ -201,7 +201,7 @@ def __init__(self): RelevanceRewardModel(device=self.device) if not self.config.neuron.relevance_off else MockRewardModel(RewardModelType.relevance.value) ) - diversity_model = ( + self.diversity_model = ( DiversityRewardModel(device=self.device) if not self.config.neuron.diversity_off else MockRewardModel(RewardModelType.diversity.value) ) @@ -210,7 +210,7 @@ def __init__(self): else MockRewardModel(RewardModelType.nsfw.value) ) - self.masking_functions = [self.blacklist, task_validator, relevance_model, diversity_model, nsfw_model] + self.masking_functions = [self.blacklist, task_validator, relevance_model, self.diversity_model, nsfw_model] bt.logging.debug(str(self.reward_functions)) bt.logging.debug(str(self.masking_functions)) diff --git a/openvalidators/utils.py b/openvalidators/utils.py index 340a020..d008f55 100644 --- a/openvalidators/utils.py +++ b/openvalidators/utils.py @@ -194,7 +194,10 @@ def save_state(self): prefix="Saved model", sufix=f"{ self.config.neuron.full_path }/model.torch", ) + except Exception as e: + bt.logging.warning(f"Failed to save model with error: {e}") + try: # Save the gating model. gating_model_linear_layer_dict = self.gating_model.linear.state_dict() gating_model_name = self.config.gating.model_name.replace("/", "_") @@ -205,7 +208,7 @@ def save_state(self): wandb.log({ "step": self.step, "block": ttl_get_block(self), - **neuron_state_dict + **neuron_state_dict }) if not self.config.wandb.off and self.config.wandb.track_gating_model: model_artifact = wandb.Artifact(f"{gating_model_name}_gating_linear_layer", type="model") @@ -213,12 +216,23 @@ def save_state(self): self.wandb.log_artifact(model_artifact) bt.logging.success(prefix="Saved gating model", sufix=f"{gating_model_file_path}") + except Exception as e: + bt.logging.warning(f"Failed to save gating model with error: {e}") - #empty cache - torch.cuda.empty_cache() - + try: + # Save diversity model. + diversity_model_dict = {"historic_embeddings": self.diversity_model.historic_embeddings.to('cpu')} + diversity_model_file_path = f"{self.config.neuron.full_path}/diversity_model.pth" + torch.save(diversity_model_dict, diversity_model_file_path) + bt.logging.success( + prefix="Saved diversity model", + sufix=f"{diversity_model_file_path} {list(self.diversity_model.historic_embeddings.shape)}", + ) except Exception as e: - bt.logging.warning(f"Failed to save model with error: {e}") + bt.logging.warning(f"Failed to save diversity model with error: {e}") + + # empty cache + torch.cuda.empty_cache() def load_state(self): @@ -227,8 +241,9 @@ def load_state(self): try: state_dict = torch.load(f"{self.config.neuron.full_path}/model.torch") # Check for nans in saved state dict - if not torch.isnan(state_dict["neuron_weights"]).any(): - self.moving_averaged_scores = state_dict["neuron_weights"].clone().detach() + neuron_weights = torch.tensor(state_dict["neuron_weights"]) + if not torch.isnan(neuron_weights).any(): + self.moving_averaged_scores = neuron_weights.to(self.device) self.hotkeys = state_dict["neuron_hotkeys"] bt.logging.success( prefix="Reloaded model", @@ -236,3 +251,15 @@ def load_state(self): ) except Exception as e: bt.logging.warning(f"Failed to load model with error: {e}") + + try: + # Load diversity model. + diversity_model_file_path = f"{self.config.neuron.full_path}/diversity_model.pth" + diversity_model_dict = torch.load(diversity_model_file_path) + self.diversity_model.historic_embeddings = diversity_model_dict["historic_embeddings"].to(self.device) + bt.logging.success( + prefix="Reloaded diversity model", + sufix=f"{diversity_model_file_path} {list(self.diversity_model.historic_embeddings.shape)}", + ) + except Exception as e: + bt.logging.warning(f"Failed to load diversity model with error: {e}")