From 01810c7189ddbdf320c2538c457e3437e33aff93 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 16 Nov 2023 17:18:27 +0100 Subject: [PATCH 1/6] not sure if we really need wrapper --- domainlab/algos/builder_match_hduva.py | 4 ++-- domainlab/models/model_wrapper_matchdg4vae.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/domainlab/algos/builder_match_hduva.py b/domainlab/algos/builder_match_hduva.py index 1c1c77bfe..612a578db 100644 --- a/domainlab/algos/builder_match_hduva.py +++ b/domainlab/algos/builder_match_hduva.py @@ -63,10 +63,10 @@ def init_business(self, exp): beta_y=args.beta_y, beta_d=args.beta_d) - model = ModelWrapMatchDGVAE(model, list_str_y=task.list_str_y) + # model = ModelWrapMatchDGVAE(model, list_str_y=task.list_str_y) model = model.to(device) - ctr_model = ModelWrapMatchDGVAE(model_ctr, list_str_y=task.list_str_y) + # ctr_model = ModelWrapMatchDGVAE(model_ctr, list_str_y=task.list_str_y) ctr_model = ctr_model.to(device) model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es)) diff --git a/domainlab/models/model_wrapper_matchdg4vae.py b/domainlab/models/model_wrapper_matchdg4vae.py index 20634f378..ea2b62221 100644 --- a/domainlab/models/model_wrapper_matchdg4vae.py +++ b/domainlab/models/model_wrapper_matchdg4vae.py @@ -27,3 +27,9 @@ def extract_semantic_feat(self, tensor_x): """ feat = self.net.extract_semantic_features(tensor_x) return feat + + def hyper_init(self, functor_scheduler, trainer=None): + self.net.hyper_init(functor_scheduler, trainer) + + def hyper_update(self, epoch, fun_scheduler): + self.net.hyper_update(epoch, fun_scheduler) From 088cc9c5cfa3fac6a226ba2cba4beb02c41cdcbe Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 16 Nov 2023 17:28:28 +0100 Subject: [PATCH 2/6] removed wrapper, memory too small --- domainlab/algos/builder_match_hduva.py | 3 ++- domainlab/models/model_hduva.py | 3 +++ test_match_duva.sh | 4 ++++ 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 test_match_duva.sh diff --git a/domainlab/algos/builder_match_hduva.py b/domainlab/algos/builder_match_hduva.py index 612a578db..44ca2f2e4 100644 --- a/domainlab/algos/builder_match_hduva.py +++ b/domainlab/algos/builder_match_hduva.py @@ -67,7 +67,8 @@ def init_business(self, exp): model = model.to(device) # ctr_model = ModelWrapMatchDGVAE(model_ctr, list_str_y=task.list_str_y) - ctr_model = ctr_model.to(device) + # ctr_model = ctr_model.to(device) + ctr_model = model_ctr.to(device) model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es)) observer = ObVisitor(model_sel, diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index 08779f49d..8dc83899e 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -195,4 +195,7 @@ def extract_semantic_features(self, tensor_x): zy_q_loc = self.encoder.infer_zy_loc(tensor_x) return zy_q_loc + def extract_semantic_feat(self, tensor_x): + return self.extract_semantic_features(tensor_x) + return ModelHDUVA diff --git a/test_match_duva.sh b/test_match_duva.sh new file mode 100644 index 000000000..86926284a --- /dev/null +++ b/test_match_duva.sh @@ -0,0 +1,4 @@ +python main_out.py --te_d=caltech --task=mini_vlcs --debug --bs=2 --aname=matchhduva \ + --epochs_ctr=3 --epos=6 --npath=examples/nets/resnet.py --gamma_y=7e5 \ + --npath_topic_distrib_img2topic=examples/nets/resnet.py \ + --npath_encoder_sandwich_layer_img2h4zd=examples/nets/resnet.py From 3230af0f9a0d256e0ae06a224ae8364e0c350ed4 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 16 Nov 2023 17:34:34 +0100 Subject: [PATCH 3/6] . --- test_match_duva.sh | 10 ++++++---- test_match_duva_vlcs.sh | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 test_match_duva_vlcs.sh diff --git a/test_match_duva.sh b/test_match_duva.sh index 86926284a..39816019e 100644 --- a/test_match_duva.sh +++ b/test_match_duva.sh @@ -1,4 +1,6 @@ -python main_out.py --te_d=caltech --task=mini_vlcs --debug --bs=2 --aname=matchhduva \ - --epochs_ctr=3 --epos=6 --npath=examples/nets/resnet.py --gamma_y=7e5 \ - --npath_topic_distrib_img2topic=examples/nets/resnet.py \ - --npath_encoder_sandwich_layer_img2h4zd=examples/nets/resnet.py +python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --debug --bs=2 --aname=matchhduva \ + --epochs_ctr=3 --epos=6 --nname=conv_bn_pool_2 --gamma_y=7e5 \ + --nname_topic_distrib_img2topic=conv_bn_pool_2 \ + --nname_encoder_sandwich_layer_img2h4zd=conv_bn_pool_2 + + diff --git a/test_match_duva_vlcs.sh b/test_match_duva_vlcs.sh new file mode 100644 index 000000000..86926284a --- /dev/null +++ b/test_match_duva_vlcs.sh @@ -0,0 +1,4 @@ +python main_out.py --te_d=caltech --task=mini_vlcs --debug --bs=2 --aname=matchhduva \ + --epochs_ctr=3 --epos=6 --npath=examples/nets/resnet.py --gamma_y=7e5 \ + --npath_topic_distrib_img2topic=examples/nets/resnet.py \ + --npath_encoder_sandwich_layer_img2h4zd=examples/nets/resnet.py From 8a4e4647c8283b36e71038701637fe242f2b5a84 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 17 Nov 2023 14:01:00 +0100 Subject: [PATCH 4/6] defautl extrac feat to cal logit --- domainlab/models/a_model_classif.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index ae1814b95..c9102ccb7 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -60,6 +60,13 @@ def evaluate(self, loader_te, device): logger = Logger.get_logger() logger.info(f"before training, model accuracy: {acc}") + def extract_semantic_feat(self, tensor_x): + """ + by default, use the logit as extracted feature if the current method + is not being overriden by child class + """ + return self.cal_logit_y(tensor_x) + @abc.abstractmethod def cal_logit_y(self, tensor_x): """ @@ -199,4 +206,4 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ device = tensor_x.device bsize = tensor_x.shape[0] - return [torch.zeros(bsize, 1).to(device)], [0.0] \ No newline at end of file + return [torch.zeros(bsize, 1).to(device)], [0.0] From 37aa8d5c8a80bcdb1a9aaabd94dad8168e924e80 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 17 Nov 2023 15:05:05 +0100 Subject: [PATCH 5/6] doc --- domainlab/algos/builder_matchdg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/domainlab/algos/builder_matchdg.py b/domainlab/algos/builder_matchdg.py index 42bd7deff..bf556ab1a 100644 --- a/domainlab/algos/builder_matchdg.py +++ b/domainlab/algos/builder_matchdg.py @@ -62,6 +62,8 @@ def init_business(self, exp): i_c=task.isize.i_c, i_h=task.isize.i_h, i_w=task.isize.i_w) + # different than model, ctr_model has no classification so it has + # different wrapper ctr_model = ModelWrapMatchDGNet(ctr_net, list_str_y=task.list_str_y) ctr_model = ctr_model.to(device) From 7f82bc9c26c3ff4f70b8f564a4923837728a3098 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 17 Nov 2023 16:40:03 +0100 Subject: [PATCH 6/6] remove redundancy --- domainlab/models/model_hduva.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index 7a522ec83..c8dc32430 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -205,7 +205,4 @@ def extract_semantic_feat(self, tensor_x): zy_q_loc = self.encoder.infer_zy_loc(tensor_x) return zy_q_loc - def extract_semantic_feat(self, tensor_x): - return self.extract_semantic_features(tensor_x) - return ModelHDUVA