From cbabfebf9df3d1f4688b834165ee692a4345dfbf Mon Sep 17 00:00:00 2001 From: sidulova Date: Fri, 8 Mar 2024 12:08:28 -0500 Subject: [PATCH] new argument in arg parser that is bassed to b task. --- domainlab/arg_parser.py | 5 ++++- domainlab/tasks/b_task.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index f8f78c28b..62e002350 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -294,7 +294,10 @@ def mk_parser_main(): arg_group_task.add_argument( "--loglevel", type=str, default="DEBUG", help="sets the loglevel of the logger" ) - + arg_group_task.add_argument( + "--shuffling_off", action="store_true", default=False, + help="disable shuffling of the training dataloader for the dataset" + ) # args for variational auto encoder arg_group_vae = parser.add_argument_group("vae") arg_group_vae = add_args2parser_vae(arg_group_vae) diff --git a/domainlab/tasks/b_task.py b/domainlab/tasks/b_task.py index d8b778689..56ac4cd79 100644 --- a/domainlab/tasks/b_task.py +++ b/domainlab/tasks/b_task.py @@ -54,7 +54,9 @@ def init_business(self, args, trainer=None): self.dict_loader_tr.update({na_domain: mk_loader(ddset_tr, args.bs)}) self.dict_dset_val.update({na_domain: ddset_val}) ddset_mix = ConcatDataset(tuple(self.dict_dset_tr.values())) - self._loader_tr = mk_loader(ddset_mix, args.bs) + flag_shuffling = not args.shuffling_off + # args.shuffling_off default is False -> not False -> True + self._loader_tr = mk_loader(ddset_mix, args.bs, shuffle=flag_shuffling) ddset_mix_val = ConcatDataset(tuple(self.dict_dset_val.values())) self._loader_val = mk_loader(