diff --git a/openvalidators/neuron.py b/openvalidators/neuron.py
index 6ba439b..29e0a1c 100644
--- a/openvalidators/neuron.py
+++ b/openvalidators/neuron.py
@@ -201,7 +201,7 @@ def __init__(self):
RelevanceRewardModel(device=self.device) if not self.config.neuron.relevance_off
else MockRewardModel(RewardModelType.relevance.value)
)
- diversity_model = (
+ self.diversity_model = (
DiversityRewardModel(device=self.device) if not self.config.neuron.diversity_off
else MockRewardModel(RewardModelType.diversity.value)
)
@@ -210,7 +210,7 @@ def __init__(self):
else MockRewardModel(RewardModelType.nsfw.value)
)
- self.masking_functions = [self.blacklist, task_validator, relevance_model, diversity_model, nsfw_model]
+ self.masking_functions = [self.blacklist, task_validator, relevance_model, self.diversity_model, nsfw_model]
bt.logging.debug(str(self.reward_functions))
bt.logging.debug(str(self.masking_functions))
diff --git a/openvalidators/utils.py b/openvalidators/utils.py
index 340a020..d008f55 100644
--- a/openvalidators/utils.py
+++ b/openvalidators/utils.py
@@ -194,7 +194,10 @@ def save_state(self):
prefix="Saved model",
sufix=f"{ self.config.neuron.full_path }/model.torch",
)
+ except Exception as e:
+ bt.logging.warning(f"Failed to save model with error: {e}")
+ try:
# Save the gating model.
gating_model_linear_layer_dict = self.gating_model.linear.state_dict()
gating_model_name = self.config.gating.model_name.replace("/", "_")
@@ -205,7 +208,7 @@ def save_state(self):
wandb.log({
"step": self.step,
"block": ttl_get_block(self),
- **neuron_state_dict
+ **neuron_state_dict
})
if not self.config.wandb.off and self.config.wandb.track_gating_model:
model_artifact = wandb.Artifact(f"{gating_model_name}_gating_linear_layer", type="model")
@@ -213,12 +216,23 @@ def save_state(self):
self.wandb.log_artifact(model_artifact)
bt.logging.success(prefix="Saved gating model", sufix=f"{gating_model_file_path}")
+ except Exception as e:
+ bt.logging.warning(f"Failed to save gating model with error: {e}")
- #empty cache
- torch.cuda.empty_cache()
-
+ try:
+ # Save diversity model.
+ diversity_model_dict = {"historic_embeddings": self.diversity_model.historic_embeddings.to('cpu')}
+ diversity_model_file_path = f"{self.config.neuron.full_path}/diversity_model.pth"
+ torch.save(diversity_model_dict, diversity_model_file_path)
+ bt.logging.success(
+ prefix="Saved diversity model",
+ sufix=f"{diversity_model_file_path} {list(self.diversity_model.historic_embeddings.shape)}",
+ )
except Exception as e:
- bt.logging.warning(f"Failed to save model with error: {e}")
+ bt.logging.warning(f"Failed to save diversity model with error: {e}")
+
+ # empty cache
+ torch.cuda.empty_cache()
def load_state(self):
@@ -227,8 +241,9 @@ def load_state(self):
try:
state_dict = torch.load(f"{self.config.neuron.full_path}/model.torch")
# Check for nans in saved state dict
- if not torch.isnan(state_dict["neuron_weights"]).any():
- self.moving_averaged_scores = state_dict["neuron_weights"].clone().detach()
+ neuron_weights = torch.tensor(state_dict["neuron_weights"])
+ if not torch.isnan(neuron_weights).any():
+ self.moving_averaged_scores = neuron_weights.to(self.device)
self.hotkeys = state_dict["neuron_hotkeys"]
bt.logging.success(
prefix="Reloaded model",
@@ -236,3 +251,15 @@ def load_state(self):
)
except Exception as e:
bt.logging.warning(f"Failed to load model with error: {e}")
+
+ try:
+ # Load diversity model.
+ diversity_model_file_path = f"{self.config.neuron.full_path}/diversity_model.pth"
+ diversity_model_dict = torch.load(diversity_model_file_path)
+ self.diversity_model.historic_embeddings = diversity_model_dict["historic_embeddings"].to(self.device)
+ bt.logging.success(
+ prefix="Reloaded diversity model",
+ sufix=f"{diversity_model_file_path} {list(self.diversity_model.historic_embeddings.shape)}",
+ )
+ except Exception as e:
+ bt.logging.warning(f"Failed to load diversity model with error: {e}")