diff --git a/openvalidators/utils.py b/openvalidators/utils.py index 341fbde..9be9165 100644 --- a/openvalidators/utils.py +++ b/openvalidators/utils.py @@ -182,7 +182,7 @@ def save_state(self): bt.logging.info("save_state()") try: neuron_state_dict = { - "neuron_weights": self.moving_averaged_scores, + "neuron_weights": self.moving_averaged_scores.to('cpu').tolist(), "neuron_hotkeys": self.hotkeys, } torch.save(neuron_state_dict, f"{self.config.neuron.full_path}/model.torch") @@ -202,7 +202,7 @@ def save_state(self): "step": self.step, "block": ttl_get_block(self), **neuron_state_dict - }) + }) if not self.config.wandb.off and self.config.wandb.track_gating_model: model_artifact = wandb.Artifact(f"{gating_model_name}_gating_linear_layer", type="model") model_artifact.add_file(gating_model_file_path)