diff --git a/openvalidators/neuron.py b/openvalidators/neuron.py index d09ba27..f1a2405 100644 --- a/openvalidators/neuron.py +++ b/openvalidators/neuron.py @@ -109,7 +109,8 @@ def __init__(self): # Init the gating model which learns which miners to select for each query. bt.logging.debug("loading", "gating_model") if not self.config.gating.num_uids: - self.config.gating.num_uids = self.subtensor.subnetwork_n(self.config.netuid) + self.config.gating.num_uids = self.subtensor.max_n(self.config.netuid) + if self.config.neuron.mock_gating_model: self.gating_model = MockGatingModel(self.metagraph.n.item()) elif self.config.neuron.use_custom_gating_model: @@ -118,7 +119,7 @@ def __init__(self): self.gating_model = GatingModel(metagraph=self.metagraph, config=self.config).to(self.device) bt.logging.debug(str(self.gating_model)) - # Dendrite pool for querying the network during training. + # Dendrite pool for querying the network during training. bt.logging.debug("loading", "dendrite_pool") if self.config.neuron.mock_dendrite_pool: self.dendrite_pool = MockDendritePool() diff --git a/openvalidators/reward/diversity.py b/openvalidators/reward/diversity.py index 52eebd5..185c53b 100644 --- a/openvalidators/reward/diversity.py +++ b/openvalidators/reward/diversity.py @@ -90,7 +90,7 @@ def get_rewards( self, prompt: str, completions: List[str], name: str ) -> torch # Check if completions are empty, return 0 if so if len(completions) == 0: - return torch.tensor([]) + return torch.tensor([]).to(self.device) # Get embeddings for all completions. embeddings = self.get_embeddings( completions )