diff --git a/domainlab/algos/builder_matchdg.py b/domainlab/algos/builder_matchdg.py index 42bd7deff..bf556ab1a 100644 --- a/domainlab/algos/builder_matchdg.py +++ b/domainlab/algos/builder_matchdg.py @@ -62,6 +62,8 @@ def init_business(self, exp): i_c=task.isize.i_c, i_h=task.isize.i_h, i_w=task.isize.i_w) + # different than model, ctr_model has no classification so it has + # different wrapper ctr_model = ModelWrapMatchDGNet(ctr_net, list_str_y=task.list_str_y) ctr_model = ctr_model.to(device) diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index ae1814b95..c9102ccb7 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -60,6 +60,13 @@ def evaluate(self, loader_te, device): logger = Logger.get_logger() logger.info(f"before training, model accuracy: {acc}") + def extract_semantic_feat(self, tensor_x): + """ + by default, use the logit as extracted feature if the current method + is not being overriden by child class + """ + return self.cal_logit_y(tensor_x) + @abc.abstractmethod def cal_logit_y(self, tensor_x): """ @@ -199,4 +206,4 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ device = tensor_x.device bsize = tensor_x.shape[0] - return [torch.zeros(bsize, 1).to(device)], [0.0] \ No newline at end of file + return [torch.zeros(bsize, 1).to(device)], [0.0] diff --git a/test_match_duva.sh b/test_match_duva.sh new file mode 100644 index 000000000..39816019e --- /dev/null +++ b/test_match_duva.sh @@ -0,0 +1,6 @@ +python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --debug --bs=2 --aname=matchhduva \ + --epochs_ctr=3 --epos=6 --nname=conv_bn_pool_2 --gamma_y=7e5 \ + --nname_topic_distrib_img2topic=conv_bn_pool_2 \ + --nname_encoder_sandwich_layer_img2h4zd=conv_bn_pool_2 + + diff --git a/test_match_duva_vlcs.sh b/test_match_duva_vlcs.sh new file mode 100644 index 000000000..86926284a --- /dev/null +++ b/test_match_duva_vlcs.sh @@ -0,0 +1,4 @@ +python main_out.py --te_d=caltech --task=mini_vlcs --debug --bs=2 --aname=matchhduva \ + --epochs_ctr=3 --epos=6 --npath=examples/nets/resnet.py --gamma_y=7e5 \ + --npath_topic_distrib_img2topic=examples/nets/resnet.py \ + --npath_encoder_sandwich_layer_img2h4zd=examples/nets/resnet.py