From e0afd3a93cd22c961814238f6ff7e36f31a42ef3 Mon Sep 17 00:00:00 2001 From: Thomas Gjerde Date: Sun, 17 Aug 2025 13:31:09 +0200 Subject: [PATCH] Set new reporters on restored species object --- neat/population.py | 1 + tests/test_population.py | 46 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/neat/population.py b/neat/population.py index 706a76e4..d3819e09 100644 --- a/neat/population.py +++ b/neat/population.py @@ -48,6 +48,7 @@ def __init__(self, config, initial_state=None): self.species.speciate(config, self.population, self.generation) else: self.population, self.species, self.generation = initial_state + self.species.reporters = self.reporters # If the reproduction object has a genome indexer, # set it to continue from the last genome ID. if hasattr(self.reproduction, "genome_indexer"): diff --git a/tests/test_population.py b/tests/test_population.py index 9275dcbd..5ff70d4b 100644 --- a/tests/test_population.py +++ b/tests/test_population.py @@ -69,6 +69,52 @@ def eval_genomes(genomes, config): last_genome_key + 1 ) + def test_reporter_consistency_after_checkpoint_restore(self): + """ + Test that ReportSets in the different objects in population are the same + after restoring from a checkpoint. + """ + # Load configuration. + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + p = neat.Population(config) + filename_prefix = 'neat-checkpoint-test_population' + checkpointer = neat.Checkpointer(1, 5, filename_prefix=filename_prefix) + p.add_reporter(checkpointer) + + reporter_set = p.reporters + self.assertEqual(reporter_set, p.reproduction.reporters) + self.assertEqual(reporter_set, p.species.reporters) + + def eval_genomes(genomes, config): + for genome_id, genome in genomes: + genome.fitness = 0.5 + + p.run(eval_genomes, 5) + + filename = '{0}{1}'.format( + filename_prefix, checkpointer.last_generation_checkpoint + ) + restored_population = neat.Checkpointer.restore_checkpoint(filename) + + # Check that the reporters are consistent + restored_reporter_set = restored_population.reporters + self.assertEqual( + restored_reporter_set, + restored_population.reproduction.reporters, + msg="Reproduction reporters do not match after restore" + ) + self.assertEqual( + restored_reporter_set, + restored_population.species.reporters, + msg="Species reporters do not match after restore" + ) + + # def test_minimal(): # # sample fitness function # def eval_fitness(population):