diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py index 6f26b9a777..0ac7ed4930 100644 --- a/monai/apps/auto3dseg/ensemble_builder.py +++ b/monai/apps/auto3dseg/ensemble_builder.py @@ -175,8 +175,19 @@ def __call__(self, pred_param: dict | None = None) -> list: pred = infer_instance.predict(predict_files=[file], predict_params=param) preds.append(pred[0]) if "image_save_func" in param: - res = img_saver(self.ensemble_pred(preds, sigmoid=sigmoid)) + try: + ensemble_preds = self.ensemble_pred(preds, sigmoid=sigmoid) + except BaseException: + ensemble_preds = self.ensemble_pred([_.to("cpu") for _ in preds], sigmoid=sigmoid) + res = img_saver(ensemble_preds) + # res is the path to the saved results + if hasattr(res, "meta") and "saved_to" in res.meta.keys(): + res = res.meta["saved_to"] + else: + warn("Image save path not returned.") + res = None else: + warn("Prediction returned in list instead of disk, provide image_save_func to avoid out of memory.") res = self.ensemble_pred(preds, sigmoid=sigmoid) outputs.append(res) return outputs @@ -451,6 +462,7 @@ def set_image_save_transform(self, **kwargs): "output_dtype": output_dtype, "resample": resample, "print_log": False, + "savepath_in_metadict": True, } if kwargs: self.save_image.update(kwargs) @@ -483,7 +495,7 @@ def ensemble(self): if history_untrained: if self.rank == 0: warn( - f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos." + f"Ensembling step will skip {[h[AlgoKeys.ID] for h in history_untrained]} untrained algos." "Generally it means these algos did not complete training." ) history = [h for h in history if h[AlgoKeys.IS_TRAINED]] @@ -497,7 +509,9 @@ def ensemble(self): builder.set_ensemble_method(self.ensemble_method) self.ensembler = builder.get_ensemble() infer_files = self.ensembler.infer_files - infer_files = partition_dataset(data=infer_files, shuffle=False, num_partitions=self.world_size)[self.rank] + infer_files = partition_dataset( + data=infer_files, shuffle=False, num_partitions=self.world_size, even_divisible=True + )[self.rank] # TO DO: Add some function in ensembler for infer_files update? self.ensembler.infer_files = infer_files # self.kwargs has poped out args for set_image_save_transform diff --git a/monai/apps/auto3dseg/utils.py b/monai/apps/auto3dseg/utils.py index 90de5e8f75..0e5c734cc6 100644 --- a/monai/apps/auto3dseg/utils.py +++ b/monai/apps/auto3dseg/utils.py @@ -50,6 +50,12 @@ def import_bundle_algo_history( algo.template_path = algo_meta_data["template_path"] best_metric = algo_meta_data.get(AlgoKeys.SCORE, None) + if best_metric is None: + try: + best_metric = algo.get_score() + except BaseException: + pass + is_trained = best_metric is not None if (only_trained and is_trained) or not only_trained: diff --git a/tests/test_auto3dseg_hpo.py b/tests/test_auto3dseg_hpo.py index f8b9ebb2fb..08a51752e5 100644 --- a/tests/test_auto3dseg_hpo.py +++ b/tests/test_auto3dseg_hpo.py @@ -177,9 +177,8 @@ def test_get_history(self) -> None: obj_filename = nni_gen.get_obj_filename() NNIGen().run_algo(obj_filename, self.work_dir) - history = import_bundle_algo_history(self.work_dir, only_trained=True) - assert len(history) == 1 + assert len(history) == 3 def tearDown(self) -> None: self.test_dir.cleanup()