diff --git a/openvalidators/forward.py b/openvalidators/forward.py index 376fc8c..59589b2 100644 --- a/openvalidators/forward.py +++ b/openvalidators/forward.py @@ -62,7 +62,11 @@ def get_random_uids(self, k: int, exclude: List[int] = None) -> torch.LongTensor return uids -async def run_step(self, prompt: str, k: int, timeout: float, name: str, exclude: list = []): +async def run_step(self, prompt: str, k: int, timeout: float, name: str, exclude: list = [], base_prompt = None): + + if base_prompt == None: + base_prompt = prompt + bt.logging.debug("run_step", name) # Record event start time. @@ -90,7 +94,7 @@ async def run_step(self, prompt: str, k: int, timeout: float, name: str, exclude bt.logging.trace(str(reward_fn_i.name), reward_i.tolist()) for masking_fn_i in self.masking_functions: - mask_i = masking_fn_i.apply(prompt, responses, name).to(self.device) + mask_i = masking_fn_i.apply(base_prompt, responses, name).to(self.device) rewards *= mask_i # includes diversity if not self.config.neuron.disable_log_rewards: event[masking_fn_i.name] = mask_i.tolist() @@ -168,6 +172,7 @@ async def forward(self): ) base_text = augment_event["best"] + base_prompt = augment_event["best"] exclude = augment_event["uids"] for k in range(self.config.neuron.num_followup_steps): @@ -180,6 +185,7 @@ async def forward(self): k=self.config.neuron.followup_sample_size, timeout=self.config.neuron.followup_timeout, exclude=exclude, + base_prompt=base_prompt ) exclude += followup_event["uids"] @@ -192,6 +198,7 @@ async def forward(self): k=self.config.neuron.answer_sample_size, timeout=self.config.neuron.answer_timeout, exclude=exclude, + base_prompt=followup_event["best"] ) exclude += answer_event["uids"] @@ -205,3 +212,4 @@ async def forward(self): ) else: base_text = base_text + "\nQuestion:" + followup_event["best"] + "\nAnswer:" + answer_event["best"] + \ No newline at end of file