diff --git a/openvalidators/forward.py b/openvalidators/forward.py index 03aa2b4..6cc11d6 100644 --- a/openvalidators/forward.py +++ b/openvalidators/forward.py @@ -339,14 +339,11 @@ async def forward(self): rewards=answer_rewards, ) - # Compute forward pass rewards. - scattered_followup_rewards = ( - torch.zeros((self.metagraph.n), dtype=torch.float32).to(self.device).scatter(0, followup_uids, followup_rewards) - ) - scattered_answer_rewards = ( - torch.zeros((self.metagraph.n), dtype=torch.float32).to(self.device).scatter(0, answer_uids, answer_rewards) - ) - rewards = scattered_followup_rewards + scattered_answer_rewards + # Compute forward pass rewards, assumes followup_uids and answer_uids are mutually exclusive. + rewards = self.moving_averaged_scores.scatter(0, followup_uids, followup_rewards) + rewards = rewards.scatter(0, answer_uids, answer_rewards) + + # Update moving_averaged_scores with rewards. self.moving_averaged_scores = self.config.neuron.moving_average_alpha * rewards.to(self.device) + ( 1 - self.config.neuron.moving_average_alpha ) * self.moving_averaged_scores.to(self.device)