diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 5e6bd7275c..9f63211a13 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -60,17 +60,16 @@ def _train_func(engine, batch): GarbageCollector(trigger_event=trigger_event, log_level=30).attach(engine) engine.run(data_loader, max_epochs=5) - print(gb_count_dict) first_count = 0 - for epoch, gb_count in gb_count_dict.items(): - # At least one zero-generation object + for iter, gb_count in gb_count_dict.items(): + # At least one zero-generation object is collected self.assertGreater(gb_count[0], 0) - if epoch == 1: - first_count = gb_count[0] - else: - # The should be less number of collected objects in the next calls. - self.assertLess(gb_count[0], first_count) + if iter > 1: + # Since we are collecting all objects from all generations manually at each call, + # starting from the second call, there shouldn't be any 1st and 2nd generation objects available to collect. + self.assertEqual(gb_count[1], first_count) + self.assertEqual(gb_count[2], first_count) if __name__ == "__main__":