From 60d3fe212f338e1303268759a049c83924944afd Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+behxyz@users.noreply.github.com> Date: Tue, 6 Apr 2021 21:01:10 -0400 Subject: [PATCH 1/2] Update garbage collection assertion to a more reliable one Signed-off-by: Behrooz <3968947+behxyz@users.noreply.github.com> --- tests/test_handler_garbage_collector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 5e6bd7275c..024a3ac846 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -63,14 +63,14 @@ def _train_func(engine, batch): print(gb_count_dict) first_count = 0 - for epoch, gb_count in gb_count_dict.items(): - # At least one zero-generation object + for iter, gb_count in gb_count_dict.items(): + # At least one zero-generation object is collected self.assertGreater(gb_count[0], 0) - if epoch == 1: - first_count = gb_count[0] - else: - # The should be less number of collected objects in the next calls. - self.assertLess(gb_count[0], first_count) + if iter > 1: + # Since we are collecting all objects from all generations manually at each call, + # starting from the second call, there shouldn't be any 1st and 2nd generation objects available to collect. + self.assertEqual(gb_count[1], first_count) + self.assertEqual(gb_count[2], first_count) if __name__ == "__main__": From 3973881771b936ef942bcf5b67ff32ff425f77c2 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+behxyz@users.noreply.github.com> Date: Tue, 6 Apr 2021 21:04:54 -0400 Subject: [PATCH 2/2] Remove print Signed-off-by: Behrooz <3968947+behxyz@users.noreply.github.com> --- tests/test_handler_garbage_collector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 024a3ac846..9f63211a13 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -60,7 +60,6 @@ def _train_func(engine, batch): GarbageCollector(trigger_event=trigger_event, log_level=30).attach(engine) engine.run(data_loader, max_epochs=5) - print(gb_count_dict) first_count = 0 for iter, gb_count in gb_count_dict.items():