diff --git a/langtest/langtest.py b/langtest/langtest.py index b3f5b614e..b5ba0d90f 100644 --- a/langtest/langtest.py +++ b/langtest/langtest.py @@ -1572,8 +1572,18 @@ def __multi_datasets_run( ): generated_results = {} + # temp_store_prompt + temp_store_prompt = self._config.get("model_parameters", {}).get( + "user_prompt", None + ) + # Run the testcases for each dataset for dataset_name, samples in testcases.items(): + # update user prompt for each dataset + if temp_store_prompt and isinstance(temp_store_prompt, dict): + self._config.get("model_parameters", {}).update( + {"user_prompt": temp_store_prompt.get(dataset_name)} + ) # Get the raw data for the dataset if isinstance(self.data, dict): raw_data = self.data.get(dataset_name) @@ -1597,6 +1607,12 @@ def __multi_datasets_run( print(f"{'':-^80}\n") + # resore user prompt + if temp_store_prompt: + self._config.get("model_parameters", {}).update( + {"user_prompt": temp_store_prompt} + ) + if ( self.is_multi_dataset and self._generated_results is None