diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36cdc2044..0e825f216 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,4 +41,4 @@ jobs: - name: test if examples in markdown works run: bash -x -v ci_run_examples.sh - name: test if benchmark works - run: pip install snakemake && sed -i '1s/^/#!\/bin\/bash -x -v\n/' run_benchmark_standalone.sh && bash -x -v run_benchmark_standalone.sh examples/benchmark/demo_shared_hyper_grid.yaml && cat zoutput/benchmarks/mnist_benchmark_grid/hyperparameters.csv && cat zoutput/benchmarks/mnist_benchmark_grid/results.csv + run: pip install snakemake==7.32.0 && pip install pulp==2.7.0 && sed -i '1s/^/#!\/bin\/bash -x -v\n/' run_benchmark_standalone.sh && bash -x -v run_benchmark_standalone.sh examples/benchmark/demo_shared_hyper_grid.yaml && cat zoutput/benchmarks/mnist_benchmark_grid/hyperparameters.csv && cat zoutput/benchmarks/mnist_benchmark_grid/results.csv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..9b66a6665 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 # Use the specific version of the repo + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black diff --git a/README.md b/README.md index e7f149b31..f328ee905 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,10 @@ ## Distribution shifts, domain generalization and DomainLab -Neural networks trained using data from a specific distribution (domain) usually fails to generalize to novel distributions (domains). 
Domain generalization aims at learning domain invariant features by utilizing data from multiple domains (data sites, corhorts, batches, vendors) so the learned feature can generalize to new unseen domains (distributions). +Neural networks trained using data from a specific distribution (domain) usually fail to generalize to novel distributions (domains). Domain generalization aims at learning domain invariant features by utilizing data from multiple domains (data sites, cohorts, batches, vendors) so the learned feature can generalize to new unseen domains (distributions).
- +
DomainLab is a software platform with state-of-the-art domain generalization algorithms implemented, designed by maximal decoupling of different software components thus enhances maximal code reuse. @@ -19,7 +19,7 @@ DomainLab is a software platform with state-of-the-art domain generalization alg DomainLab decouples the following concepts or objects: - task $M$: a combination of datasets (e.g. from distribution $D_1$ and $D_2$) - neural network: a map $\phi$ from the input data to the feature space and a map $\varphi$ from feature space to output $\hat{y}$ (e.g. decision variable). -- model: structural risk in the form of $\ell() + \mu R()$ where +- model: structural risk in the form of $\ell() + \mu R()$ where - $\ell(Y, \hat{y}=\varphi(\phi(X)))$ is the task specific empirical loss (e.g. cross entropy for classification task). - $R(\phi(X))$ is the penalty loss to boost domain invariant feature extraction using $\phi$. - $\mu$ is the corresponding multiplier to each penalty function factor. @@ -35,7 +35,7 @@ DomainLab makes it possible to combine models with models, trainers with models, ### Installation For development version in Github, see [Installation and Dependencies handling](./docs/doc_install.md) -We also offer a PyPI version here https://pypi.org/project/domainlab/ which one could install via `pip install domainlab` and it is recommended to create a virtual environment for it. +We also offer a PyPI version here https://pypi.org/project/domainlab/ which one could install via `pip install domainlab` and it is recommended to create a virtual environment for it. ### Task specification In DomainLab, a task is a container for datasets from different domains. See detail in @@ -43,13 +43,13 @@ In DomainLab, a task is a container for datasets from different domains. 
See det ### Example and usage -#### Either clone this repo and use command line +#### Either clone this repo and use command line `python main_out.py -c ./examples/conf/vlcs_diva_mldg_dial.yaml` where the configuration file below can be downloaded [here](https://raw.githubusercontent.com/marrlab/DomainLab/master/examples/conf/vlcs_diva_mldg_dial.yaml) ``` te_d: caltech # domain name of test domain -tpath: examples/tasks/task_vlcs.py # python file path to specify the task +tpath: examples/tasks/task_vlcs.py # python file path to specify the task bs: 2 # batch size model: dann_diva # combine model DANN with DIVA epos: 1 # number of epochs @@ -67,16 +67,16 @@ See example here: [Transformer as feature extractor, decorate JIGEN with DANN, t ### Benchmark different methods -DomainLab provides a powerful benchmark functionality. +DomainLab provides a powerful benchmark functionality. To benchmark several algorithms(combination of neural networks, models, trainers and associated hyperparameters), a single line command along with a benchmark configuration files is sufficient. See details in [benchmarks documentation and tutorial](./docs/doc_benchmark.md) -One could simply run -`bash run_benchmark_slurm.sh your_benchmark_configuration.yaml` to launch different experiments with specified configuraiton. +One could simply run +`bash run_benchmark_slurm.sh your_benchmark_configuration.yaml` to launch different experiments with specified configuration. For example, the following result (without any augmentation like flip) is for PACS dataset.
- +
where each rectangle represent one model trainer combination, each bar inside the rectangle represent a unique hyperparameter index associated with that method combination, each dot represent a random seeds. diff --git a/ci.sh b/ci.sh index e95f51f1b..408dc2d95 100644 --- a/ci.sh +++ b/ci.sh @@ -21,5 +21,3 @@ endtime=`date +%s` runtime=$((endtime-starttime)) echo "total time used:" echo "$runtime" - - diff --git a/ci_pytest_cov.sh b/ci_pytest_cov.sh index 0b3f2b133..c0ebf6d70 100644 --- a/ci_pytest_cov.sh +++ b/ci_pytest_cov.sh @@ -1,5 +1,5 @@ #!/bin/bash -export CUDA_VISIBLE_DEVICES="" +export CUDA_VISIBLE_DEVICES="" # although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error # so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring # --cov-report term-missing to show in console file wise coverage and lines missing diff --git a/ci_run_examples.sh b/ci_run_examples.sh index 15cdc0fbc..9f6b4e041 100644 --- a/ci_run_examples.sh +++ b/ci_run_examples.sh @@ -6,8 +6,8 @@ set -e # exit upon first error # echo "#!/bin/bash -x -v" > sh_temp_example.sh sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_example.sh split -l 5 sh_temp_example.sh sh_example_split -for file in sh_example_split*; -do (echo "#!/bin/bash -x -v" > "$file"_exe && cat "$file" >> "$file"_exe && bash -x -v "$file"_exe && rm -r zoutput); +for file in sh_example_split*; +do (echo "#!/bin/bash -x -v" > "$file"_exe && cat "$file" >> "$file"_exe && bash -x -v "$file"_exe && rm -r zoutput); done # bash -x -v -e sh_temp_example.sh echo "general examples done" diff --git a/data/mixed_codec/caltech/auto/text.txt b/data/mixed_codec/caltech/auto/text.txt index 5e1c309da..557db03de 100644 --- a/data/mixed_codec/caltech/auto/text.txt +++ b/data/mixed_codec/caltech/auto/text.txt @@ -1 +1 @@ -Hello World \ No newline at end of file +Hello World diff --git a/data/script/download_pacs.py 
b/data/script/download_pacs.py index b05c7a4de..51c346f24 100644 --- a/data/script/download_pacs.py +++ b/data/script/download_pacs.py @@ -1,14 +1,16 @@ -'this script can be used to download the pacs dataset' +"this script can be used to download the pacs dataset" import os import tarfile from zipfile import ZipFile + import gdown + def stage_path(data_dir, name): - ''' + """ creates the path to data_dir/name if it does not exist already - ''' + """ full_path = os.path.join(data_dir, name) if not os.path.exists(full_path): @@ -16,11 +18,12 @@ def stage_path(data_dir, name): return full_path + def download_and_extract(url, dst, remove=True): - ''' + """ downloads and extracts the data behind the url and saves it at dst - ''' + """ gdown.download(url, dst, quiet=False) if dst.endswith(".tar.gz"): @@ -43,17 +46,19 @@ def download_and_extract(url, dst, remove=True): def download_pacs(data_dir): - ''' + """ download and extract dataset pacs. Dataset is saved at location data_dir - ''' + """ full_path = stage_path(data_dir, "PACS") - download_and_extract("https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd", - os.path.join(data_dir, "PACS.zip")) + download_and_extract( + "https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd", + os.path.join(data_dir, "PACS.zip"), + ) + + os.rename(os.path.join(data_dir, "kfold"), full_path) - os.rename(os.path.join(data_dir, "kfold"), - full_path) -if __name__ == '__main__': - download_pacs('../pacs') +if __name__ == "__main__": + download_pacs("../pacs") diff --git a/data/ztest_files/dummy_file.py b/data/ztest_files/dummy_file.py index ee817a687..e0c7faa27 100644 --- a/data/ztest_files/dummy_file.py +++ b/data/ztest_files/dummy_file.py @@ -1,4 +1,4 @@ -''' +""" I am a dummy file used in tests/test_git_tag.py to produce a file which is not commited -''' +""" diff --git a/docs/.nojekyll b/docs/.nojekyll index 8b1378917..e69de29bb 100644 --- a/docs/.nojekyll +++ b/docs/.nojekyll @@ -1 +0,0 @@ - diff --git 
a/docs/conf.py b/docs/conf.py index 0d5d79a76..bfa3e6eb3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,18 +15,16 @@ # Incase the project was not installed import os import sys -import sphinx_material from datetime import datetime +import sphinx_material + sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- project = "domainlab" # @FIXME -copyright = ( - f"2021-{datetime.now().year}, Marr Lab." - "" -) +copyright = f"2021-{datetime.now().year}, Marr Lab." "" author = "Xudong Sun, et.al." @@ -94,11 +92,11 @@ # '.md': 'recommonmark.parser.CommonMarkParser', # } -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] source_suffix = { - '.rst': 'restructuredtext', - '.txt': 'markdown', - '.md': 'markdown', + ".rst": "restructuredtext", + ".txt": "markdown", + ".md": "markdown", } # The master toctree document. @@ -114,11 +112,13 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = ["setup.py", - "_build", - "Thumbs.db", - ".DS_Store", - "**.ipynb_checkpoints"] +exclude_patterns = [ + "setup.py", + "_build", + "Thumbs.db", + ".DS_Store", + "**.ipynb_checkpoints", +] # The name of the Pygments (syntax highlighting) style to use. 
pygments_style = "default" @@ -129,10 +129,9 @@ # -- HTML theme settings ------------------------------------------------ html_short_title = "domainlab" # @FIXME html_show_sourcelink = False -html_sidebars = {"**": ["logo-text.html", - "globaltoc.html", - "localtoc.html", - "searchbox.html"]} +html_sidebars = { + "**": ["logo-text.html", "globaltoc.html", "localtoc.html", "searchbox.html"] +} html_theme_path = sphinx_material.html_theme_path() html_context = sphinx_material.get_html_context() @@ -157,40 +156,30 @@ "master_doc": False, "nav_title": "DomainLab", "nav_links": [ - { - "href": "readme_link", - "internal": True, - "title": "Introduction"}, - { - "href": "doc_tasks", - "internal": True, - "title": "Task Specification"}, + {"href": "readme_link", "internal": True, "title": "Introduction"}, + {"href": "doc_tasks", "internal": True, "title": "Task Specification"}, { "href": "doc_custom_nn", "internal": True, - "title": "Specify neural network in commandline"}, + "title": "Specify neural network in commandline", + }, { "href": "doc_MNIST_classification", "internal": True, - "title": "Examples with MNIST"}, + "title": "Examples with MNIST", + }, { "href": "doc_examples", "internal": True, - "title": "More commandline examples"}, - - { - "href": "doc_benchmark", - "internal": True, - "title": "Benchmarks tutorial"}, - - { - "href": "doc_output", - "internal": True, - "title": "Output Structure"}, + "title": "More commandline examples", + }, + {"href": "doc_benchmark", "internal": True, "title": "Benchmarks tutorial"}, + {"href": "doc_output", "internal": True, "title": "Output Structure"}, { "href": "doc_extend_contribute", "internal": True, - "title": "Specify custom model in commandline"}, + "title": "Specify custom model in commandline", + }, # { # "href": "https://squidfunk.github.io/mkdocs-material/", # "internal": False, @@ -251,7 +240,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
-man_pages = [(master_doc, "domainlab", "domainlab", [author], 1)] # @FIXME +man_pages = [(master_doc, "domainlab", "domainlab", [author], 1)] # @FIXME # -- Options for Texinfo output ---------------------------------------------- diff --git a/docs/conf0.py b/docs/conf0.py index 21f1ad1ce..c6138b9f6 100644 --- a/docs/conf0.py +++ b/docs/conf0.py @@ -12,17 +12,18 @@ # import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) sys.setrecursionlimit(1500) # -- Project information ----------------------------------------------------- -project = 'domainlab' -copyright = '2022, Xudong Sun' -author = 'Xudong Sun' +project = "domainlab" +copyright = "2022, Xudong Sun" +author = "Xudong Sun" # The full version, including alpha/beta/rc tags -release = '0.0.0' +release = "0.0.0" # -- General configuration --------------------------------------------------- @@ -46,7 +47,7 @@ ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -59,7 +60,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -67,9 +68,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/docs/docDIAL.md b/docs/docDIAL.md index 2765b5762..8b8111de8 100644 --- a/docs/docDIAL.md +++ b/docs/docDIAL.md @@ -7,22 +7,22 @@ The algorithm introduced in https://arxiv.org/pdf/2104.00322.pdf uses adversaria ## generating the adversarial domain The generation of adversary images is demonstrated in figure 1. -The task is to find an adversary image $x'$ to the natural image $x$ with $||x- x'||$ small, such that the output of a classification network $\phi$ fulfills $||\phi(x) - \phi(x')||$ big. In the example in figure 1 you can for example see, that the difference between the left and the right image of the panda is unobservable, but the classifier does still classify them differently. +The task is to find an adversary image $x'$ to the natural image $x$ with $||x- x'||$ small, such that the output of a classification network $\phi$ fulfills $||\phi(x) - \phi(x')||$ big. In the example in figure 1 you can for example see, that the difference between the left and the right image of the panda is unobservable, but the classifier does still classify them differently. In Domainlab the adversary images are created starting from a random perturbation of the natural image $x'_0 = x + \sigma \tilde{x}~$, $\tilde{x} \sim \mathcal{N}(0, 1)$ and using $n$ steps in a gradient descend with step size $\tau$ to maximize $||\phi(x) - \phi(x')||$. In general machine learning, the generation of adversary images is used during the training process to make networks more robust to adversarial attacks.
- +
Figure 1: adversarial domain (Image source: Figure 1 of Explaining and Harnessing Adversarial Examples https://arxiv.org/abs/1412.6572)
## network structure -The network consists of three parts. At first a feature extractor, which extracts the main characteristics of the images. This features are then used as the input to a label classifier and a domain classifier. +The network consists of three parts. At first a feature extractor, which extracts the main characteristics of the images. This features are then used as the input to a label classifier and a domain classifier. During training the network is optimized to a have low error on the classification task, while ensuring that the internal representation (output of the feature extractor) cannot discriminate between the natural and adversarial domain. This goal can be archived by using a special loss function in combination with a gradient reversal layer.
- +
Figure 2: network structure (Image source: Figure 1 of Domain Invariant Adversarial Learning https://arxiv.org/pdf/2104.00322.pdf)
@@ -42,7 +42,7 @@ During training the network is optimized to a have low error on the classificati [comment]: <> ($$) -[comment]: <> (DIAL_{CE} = CE_{nat} + \lambda ~ CE_{adv} - r / D_{nat} + D_{adv} / ) +[comment]: <> (DIAL_{CE} = CE_{nat} + \lambda ~ CE_{adv} - r / D_{nat} + D_{adv} / ) [comment]: <> ($$) diff --git a/docs/docFishr.md b/docs/docFishr.md index 3949c332f..08580d9fe 100644 --- a/docs/docFishr.md +++ b/docs/docFishr.md @@ -1,14 +1,14 @@ # Trainer Fishr ## Invariant Gradient Variances for Out-of-distribution Generalization -The goal of the Fishr regularization technique is locally aligning the domain-level loss landscapes +The goal of the Fishr regularization technique is locally aligning the domain-level loss landscapes around the final weights, finding a minimizer around which the inconsistencies between the domain-level loss landscapes are as small as possible. This is done by considering second order terms during training, matching the variances between the domain-level gradients.
- +
Figure 1: Fishr matches the domain-level gradient variances of the distributions across the training domains (Image source: Figure 1 of "Fishr: Invariant gradient variances for out-of-distribution generalization")
@@ -19,18 +19,18 @@ Invariant gradient variances for out-of-distribution generalization")
### Quantifying inconsistency between domains Intuitively, two domains are locally inconsistent around a minimizer, if a small perturbation of the minimizer highly affects its optimality in one domain, but only -minimally affects its optimality in the other domain. Under certain assumptions, most importantly +minimally affects its optimality in the other domain. Under certain assumptions, most importantly the Hessians being positive definite, it is possible to measure the inconsistency between two domains $A$ and $B$ with the following inconsistency score: $$ -\mathcal{I}^\epsilon ( \theta^* ) = \text{max}_ {(A,B)\in\mathcal{E}^2} \biggl( \mathcal{R}_ B (\theta^* ) - \mathcal{R}_ {A} ( \theta^* ) + \text{max}_ {\frac{1}{2} \theta^T H_A \theta\leq\epsilon}\frac{1}{2}\theta^T H_B \theta \biggl) +\mathcal{I}^\epsilon ( \theta^* ) = \text{max}_ {(A,B)\in\mathcal{E}^2} \biggl( \mathcal{R}_ B (\theta^* ) - \mathcal{R}_ {A} ( \theta^* ) + \text{max}_ {\frac{1}{2} \theta^T H_A \theta\leq\epsilon}\frac{1}{2}\theta^T H_B \theta \biggl) $$ , whereby $\theta^*$ denotes the minimizer, $\mathcal{E}$ denotes the set of training domains, $H_e$ denotes the Hessian for $e\in\mathcal{E}$, $\theta$ denote the network parameters and $\mathcal{R}_e$ for $e\in\mathcal{E}$ denotes the domain-level ERM objective. -The Fishr regularization method forces both terms on the right hand side +The Fishr regularization method forces both terms on the right hand side of the inconsistency score to become small. The first term represents the difference between the domain-level risks and is implicitly forced to be small by applying the Fishr regularization. For the second term it suffices to align diagonal approximations of the @@ -64,7 +64,7 @@ $v = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} v_e$. ### Implementation The variance of the gradients within each domain can be computed with the BACKPACK package (see: Dangel, Felix, Frederik Kunstner, and Philipp Hennig. 
-"Backpack: Packing more into backprop." https://arxiv.org/abs/1912.10985). +"Backpack: Packing more into backprop." https://arxiv.org/abs/1912.10985). Further on, we use $ \textnormal{Var}(G) \approx \textnormal{diag}(H) $. The Hessian is then approximated by the Fisher Information Matrix, which again is approximated by an empirical estimator for computational efficiency. diff --git a/docs/docHDUVA.md b/docs/docHDUVA.md index d70b4dcf9..4abcb71e9 100644 --- a/docs/docHDUVA.md +++ b/docs/docHDUVA.md @@ -1,26 +1,24 @@ # Model HDUVA ## HDUVA: HIERARCHICAL VARIATIONAL AUTO-ENCODING FOR UNSUPERVISED DOMAIN GENERALIZATION -HDUVA builds on a generative approach within the framework of variational autoencoders to facilitate generalization to new domains without supervision. HDUVA learns representations that disentangle domain-specific information from class-label specific information even in complex settings where domain structure is not observed during training. +HDUVA builds on a generative approach within the framework of variational autoencoders to facilitate generalization to new domains without supervision. HDUVA learns representations that disentangle domain-specific information from class-label specific information even in complex settings where domain structure is not observed during training. ## Model Overview -More specifically, HDUVA is based on three latent variables that are used to model distinct sources of variation and are denoted as $z_y$, $z_d$ and $z_x$. $z_y$ represents class specific information, $z_d$ represents domain specific information and $z_x$ models residual variance of the input. We introduce an additional hierarchical level and use a continuous latent representation s to model (potentially unobserved) domain structure. This means that we can encourage disentanglement of the latent variables through conditional priors without the need of conditioning on a one-hot-encoded, observed domain label. 
The model along with its parameters and hyperparameters is shown in Figure 1: +More specifically, HDUVA is based on three latent variables that are used to model distinct sources of variation and are denoted as $z_y$, $z_d$ and $z_x$. $z_y$ represents class specific information, $z_d$ represents domain specific information and $z_x$ models residual variance of the input. We introduce an additional hierarchical level and use a continuous latent representation $s$ to model (potentially unobserved) domain structure. This means that we can encourage disentanglement of the latent variables through conditional priors without the need of conditioning on a one-hot-encoded, observed domain label. The model along with its parameters and hyperparameters is shown in Figure 1:
- PGM for HDUVA + PGM for HDUVA
Figure 1: Probabilistic graphical model for HDUVA:Hierarchical Domain Unsupervised Variational Autoencoding.
- - Note that as part of the model a latent representation of $X$ is concatentated with $s$ and $z_d$ (dashed arrows), requiring respecive encoder networks. ## Evidence lower bound and overall loss -The ELBO of the model can be decomposed into 4 different terms: +The ELBO of the model can be decomposed into 4 different terms: -Likelihood: $E_{q(z_d, s|x), q(z_x|x), q(z_y|x)}\log p_{\theta}(x|s, z_d, z_x, z_y)$ +Likelihood: $E_{q(z_d, s|x), q(z_x|x), q(z_y|x)}\log p_{\theta}(x|s, z_d, z_x, z_y)$ -KL divergence weighted as in the Beta-VAE: $-\beta_x KL(q_{\phi_x}(z_x|x)||p_{\theta_x}(z_x)) - \beta_y KL(q_{\phi_y}(z_y|x)||p_{\theta_y}(z_y|y))$ +KL divergence weighted as in the Beta-VAE: $-\beta_x KL(q_{\phi_x}(z_x|x)||p_{\theta_x}(z_x)) - \beta_y KL(q_{\phi_y}(z_y|x)||p_{\theta_y}(z_y|y))$ Hierarchical KL loss (domain term): $- \beta_d E_{q_{\phi_s}(s|x), q_{\phi_d}(z_d|x, s)} \log \frac{q_{\phi_d}(z_d|x, s)}{p_{\theta_d}(z_d|s)}$ @@ -30,28 +28,28 @@ In addition, we construct the overall loss by adding an auxiliary classsifier, b ## Hyperparameters loss function -For fitting the model, we need to specify the 4 $\beta$-weights related to the the different terms of the ELBO ( $\beta_x$ , $\beta_y$, $\beta_d$, $\beta_t$) as well as $\gamma_y$. +For fitting the model, we need to specify the 4 $\beta$-weights related to the different terms of the ELBO ( $\beta_x$ , $\beta_y$, $\beta_d$, $\beta_t$) as well as $\gamma_y$.
## Model hyperparameters -In addition to these hyperparameters, the following model parameters can be specified: +In addition to these hyperparameters, the following model parameters can be specified: - `zd_dim`: size of latent space for domain-specific information - `zx_dim`: size of latent space for residual variance - `zy_dim`: size of latent space for class-specific information - `topic_dim`: size of dirichlet distribution for topics $s$ -The user need to specify at least two neural networks for the **encoder** part via +The user needs to specify at least two neural networks for the **encoder** part via -- `npath_encoder_x2topic_h`: the python file path of a neural network that maps the image (or other +- `npath_encoder_x2topic_h`: the python file path of a neural network that maps the image (or other modal of data to a one dimensional (`topic_dim`) hidden representation serving as input to Dirichlet encoder: `X->h_t(X)->alpha(h_t(X))` where `alpha` is the neural network to map a 1-d hidden layer to dirichlet concentration parameter. -- `npath_encoder_sandwich_x2h4zd`: the python file path of a neural network that maps the + +- `npath_encoder_sandwich_x2h4zd`: the python file path of a neural network that maps the image to a hidden representation (same size as `topic_dim`), which will be used to infere the posterior distribution of `z_d`: `topic(X), X -> [topic(X), h_d(X)] -> zd_mean, zd_scale` Alternatively, one could use an existing neural network in DomainLab using `nname` instead of `npath`: - `nname_encoder_x2topic_h` - `nname_encoder_sandwich_x2h4zd` - ## Hyperparameter for warmup Finally, the number of epochs for hyper-parameter warm-up can be specified via the argument `warmup`.
diff --git a/docs/docJiGen.md b/docs/docJiGen.md index cd17ac140..8830842ee 100644 --- a/docs/docJiGen.md +++ b/docs/docJiGen.md @@ -1,23 +1,23 @@ # Model JiGen The JiGen method extends the understanding of the concept of spatial correlation in the -neural network by training the network not only on a classification task, but also on solving jigsaw puzzles. +neural network by training the network not only on a classification task, but also on solving jigsaw puzzles. -To create a jigsaw puzzle, an image is split into $n \times n$ patches, which are then permuted. -The goal is training the model to predict the correct permutation, which results in the permuted image. +To create a jigsaw puzzle, an image is split into $n \times n$ patches, which are then permuted. +The goal is training the model to predict the correct permutation, which results in the permuted image. To solve the classification problem and the jigsaw puzzle in parallel, the permuted and the original images are first fed into a convolutional network for feature extraction and then given to two classifiers, one being the image classifier and the other the jigsaw classifier. -For the training of both classification networks, a cross-entropy loss is used. The total loss is then +For the training of both classification networks, a cross-entropy loss is used. The total loss is then given by the loss of the image classification task plus the loss of the jigsaw task, whereby the jigsaw loss is weighted by a hyperparameter. Another hyperparameter denotes the probability of shuffling the patches of one instance from the training data set, i.e. the relative ratio. The advantage of this method is that it does not require domain labels, as the jigsaw puzzle can be -solved despite missing domain labels. +solved despite missing domain labels. ### Model parameters The following hyperparameters can be specified: @@ -29,4 +29,3 @@ Furthermore, the user can specify a custom grid length via `grid_len`. 
_Reference_: Carlucci, Fabio M., et al. "Domain generalization by solving jigsaw puzzles." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. - diff --git a/docs/docMatchDG.md b/docs/docMatchDG.md index 8fb0495e7..96eea99f1 100644 --- a/docs/docMatchDG.md +++ b/docs/docMatchDG.md @@ -9,28 +9,28 @@ The authors of the paper motivate their approach by looking at the data-generati
- +
Figure 1: Structural causal model for the data-generating process. Observed variables are shaded; dashed arrows denote correlated nodes. Object may not be observed. (Image source: Figure 2 of Domain Generalization using Causal Matching https://arxiv.org/pdf/2006.07500.pdf)
## Network -Before defining the network, one needs to define three sets: -- $\mathcal{X}$: image space with $x \in \mathcal{X}$ +Before defining the network, one needs to define three sets: +- $\mathcal{X}$: image space with $x \in \mathcal{X}$ - $\mathcal{C}$: causal feature space with $x_C \in \mathcal{C}$ -- $\mathcal{Y}$: label space with $y \in \mathcal{Y}$ +- $\mathcal{Y}$: label space with $y \in \mathcal{Y}$ -For the classification the goal is to classify an object only based on its causal features $x_C$, hence we define a network $h: \mathcal{C} \rightarrow \mathcal{Y}$. Since $x_C$ for an image $x$ is unknown, one needs to learn a representation function $\phi: \mathcal{X} \rightarrow \mathcal{C}$. By assumption for two images $x_j^{(d)}$ and $x_k^{(d')}$ of the same class, but from different domains $\text{ dist}\left(\phi(x_j^{(d)}), \phi(x_k^{(d')})\right)$ is small to enforce that the features in $\phi(x) \in \mathcal{C}$ are affected by the associated object and not the domain. This motivates the definition of a match function $\Omega: \mathcal{X} \times \mathcal{X} \rightarrow \{0, 1\}$, +For the classification the goal is to classify an object only based on its causal features $x_C$, hence we define a network $h: \mathcal{C} \rightarrow \mathcal{Y}$. Since $x_C$ for an image $x$ is unknown, one needs to learn a representation function $\phi: \mathcal{X} \rightarrow \mathcal{C}$. By assumption for two images $x_j^{(d)}$ and $x_k^{(d')}$ of the same class, but from different domains $\text{ dist}\left(\phi(x_j^{(d)}), \phi(x_k^{(d')})\right)$ is small to enforce that the features in $\phi(x) \in \mathcal{C}$ are affected by the associated object and not the domain. 
This motivates the definition of a match function $\Omega: \mathcal{X} \times \mathcal{X} \rightarrow \{0, 1\}$, $$ \Omega(x_j, x_k) = \begin{cases} 1 \quad & \text{$x_j$ and $x_k$ correspond to the same object} \\ 0 & \text{otherwise} -\end{cases} +\end{cases} $$ -by using +by using $$ \sum_{\substack{\Omega(x_j, x_k) = 1,\\ d \neq d'}} \text{dist}\left(\phi(x_j^{(d)}), \phi(x_k^{(d')})\right) = 0. @@ -38,7 +38,7 @@ $$ Together the networks form the desired classifier $f = h \circ \phi : \mathcal{X} \rightarrow \mathcal{Y}$. - + ## Training **Initialisation:** first of all match pairs of same-class data points from different domains are constructed. Given a data point, another data point with the same label from a different domain is selected randomly. The matching across domains is done relative to a base domain, which is chosen as the domain with the highest number of samples for that class. This leads to a matched data matrix $\mathcal{M}$ of size $(N', K)$ with $N'$ sum of the size of base domains over all classes and $K$ number ob domains. @@ -60,7 +60,7 @@ $$ \underset{h, \phi}{\text{arg min}} ~ \sum_{d \in D} \sum_{i=1}^{n_d} ~ l\left(h(\phi(x_i^{(d)})), y_i^{(d)}\right) + \gamma_{\text{reg}} \sum_{\substack{\Omega(x_j, x_k) = 1,\\ d \neq d'}} \text{dist}\left(\phi(x_j^{(d)}), \phi(x_k^{(d')})\right). $$ -The training of $h$ and $\phi$ is performed from scratch. The trained network $\phi^*$ from phase 1 is only used to update the matched data matrix using yielding $\Omega$. +The training of $h$ and $\phi$ is performed from scratch. The trained network $\phi^*$ from phase 1 is only used to update the matched data matrix using yielding $\Omega$. 
--- diff --git a/docs/doc_MNIST_classification.md b/docs/doc_MNIST_classification.md index 7cf9d65b2..a979b66f3 100644 --- a/docs/doc_MNIST_classification.md +++ b/docs/doc_MNIST_classification.md @@ -1,10 +1,10 @@ # colored MNIST classification -We include in the DomainLab package colored verion of MNIST where the color corresponds to the domain and digit corresponds to the semantic concept that we want to classify. +We include in the DomainLab package a colored version of MNIST where the color corresponds to the domain and the digit corresponds to the semantic concept that we want to classify. ## colored MNIST dataset -We provide 10 different colored version of the MNIST dataset with numbers 0 to 9 as 10 different domains. The digit and background are colored differently, thus a domain correspond to a 2-color combination. -An extraction of digit 0 to 9 from domain 0 is shown in Figure 1. +We provide 10 different colored versions of the MNIST dataset with numbers 0 to 9 as 10 different domains. The digit and background are colored differently, thus a domain corresponds to a 2-color combination. +An extraction of digits 0 to 9 from domain 0 is shown in Figure 1.
digits 0 - 9: @@ -21,7 +21,7 @@ digits 0 - 9: > + <> # ... you may like to add more shared samples here like: # gamma_y, gamma_d, zy_dim, zd_dim @@ -109,13 +109,13 @@ Shared params: Task_Diva_Dial: # set the method to be used, if model is skipped the Task will not be executed - model: diva + model: diva # select a trainer to be used, if trainer is skipped adam is used # options: "dial" or "mldg" trainer: dial - - # Here we can also set task specific hyperparameters + + # Here we can also set task specific hyperparameters # which shall be fixed among all experiments. # f not set, the default values will be used. zd_dim: 32 @@ -133,11 +133,11 @@ Task_Diva_Dial: # define task specific hyperparameter sampling hyperparameters: <> - - # add constraints for your sampled hyperparameters, + + # add constraints for your sampled hyperparameters, # by using theire name in a python expression. - # You can use all hyperparameters defined in the hyperparameter section of - # the current task and the shared hyperparameters specified in the shared + # You can use all hyperparameters defined in the hyperparameter section of + # the current task and the shared hyperparameters specified in the shared # section of the current task constraints: - 'zx_dim <= zy_dim' @@ -161,14 +161,14 @@ For filling in the sampling description for the into the `Shared params` and the 1. uniform samples in the interval [min, max] ```yaml tau: # name of the hyperparameter - min: 0.01 + min: 0.01 max: 1 distribution: uniform # name of the distribution ##### for grid search ##### num: 3 # number of grid points created for this hyperparameter ``` -2. loguniform samples in the interval [min, max]. This is usefull if the interval spans over multiple magnitudes. +2. loguniform samples in the interval [min, max]. This is usefull if the interval spans over multiple magnitudes. ```yaml gamma_y: # name of the hyperparameter min: 1e4 @@ -182,14 +182,14 @@ gamma_y: # name of the hyperparameter 1. 
normal samples with mean and standard deviation ```yaml pperm: # name of the hyperparameter - mean: 0.5 + mean: 0.5 std: 0.2 distribution: normal # name of the distribution ##### for grid search ##### num: 3 # number of grid points created for this hyperparameter ``` -2. lognormal samples with mean and standard deviation. This is usefull if the interval spans over multiple magnitudes. +2. lognormal samples with mean and standard deviation. This is useful if the interval spans multiple orders of magnitude. ```yaml gamma_y: # name of the hyperparameter mean: 1e5 @@ -205,7 +205,7 @@ choose the values of the hyperparameter from a predefined list. If one uses grid ```yaml nperm: # name of the hyperparameter distribution: categorical # name of the distribution - datatype: int + datatype: int values: # concrete values to choose from - 30 - 31 @@ -250,15 +250,15 @@ it is possible to have all sorts of combinations: 1. a task which includes shared and task specific sampled hyperparameters ```yaml Task_Name: - model: ... + model: ... ... - # specify sections from the Shared params section + # specify sections from the Shared params section shared: - ... # specify task specific hyperparameter sampling hyperparameters: - ... + ... # add the constraints to the hperparameters section constraints: - '...' # constraints using params from the hyperparameters and the shared section @@ -267,7 +267,7 @@ Task_Name: 2. Only task specific sampled hyperparameters ```yaml Task_Name: - model: ... + model: ... ... # specify task specific hyperparameter sampling @@ -281,10 +281,10 @@ Task_Name: 3. Only shared sampled hyperparamters ```yaml Task_Name: - model: ... + model: ... ... - # specify sections from the Shared params section + # specify sections from the Shared params section shared: - ... # add the constraints as a standalone section to the task @@ -295,6 +295,6 @@ Task_Name: 4. No hyperparameter sampling.
All Hyperparameters are either fixed to a user defined value or to the default value. No hyperparameter samples indicates no constraints. ```yaml Task_Name: - model: ... + model: ... ... ``` diff --git a/docs/doc_diva.md b/docs/doc_diva.md index b02d0dc97..b083e7fca 100644 --- a/docs/doc_diva.md +++ b/docs/doc_diva.md @@ -2,21 +2,21 @@ ## Domain Invariant Variational Autoencoders DIVA addresses the domain generalization problem with a variational autoencoder -with three latent variables, using three independent encoders. +with three latent variables, using three independent encoders. By encouraging the network to store each the domain, class and residual features in one of the latent spaces, the class-specific information -is disentangled. +is disentangled. In order to obtain marginally independent latent variables, the densities of the domain and class latent spaces are conditioned on the domain and the class, respectively. These densities are then parameterized by learnable parameters. During training, all three latent variables are fed into a single decoder -reconstructing the input image. +reconstructing the input image. Additionally, two classifiers are trained, predicting the domain and class label from the respective latent variable. This leads to an overall large network. However, during inference only the class encoder and classifier -are used. +are used. DIVA can improve the classification accuracy also in a semi-supervised setting, where class labels are missing for some data or domains. This is an advantage, as prediction @@ -30,7 +30,7 @@ decreased performance. ### Model parameters The following hyperparameters can be specified: -- `zd_dim`: size of latent space for domain-specific information +- `zd_dim`: size of latent space for domain-specific information - `zx_dim`: size of latent space for residual variance - `zy_dim`: size of latent space for class-specific information - `gamma_y`: multiplier for y classifier ($\alpha_y$ of eq. 
(2) in paper below) diff --git a/docs/doc_examples.md b/docs/doc_examples.md index da5812c4a..21d0b2eb2 100755 --- a/docs/doc_examples.md +++ b/docs/doc_examples.md @@ -171,7 +171,7 @@ python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=diva --nn ### Set hyper-parameters for trainer as well ```shell python main_out.py --te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=diva --nname=conv_bn_pool_2 --nname_dom=conv_bn_pool_2 --gamma_y=7e5 --gamma_d=1e5 --trainer=dial --dial_steps_perturb=1 -``` +``` ## Meta Learning Domain Generalization ```shell diff --git a/docs/doc_install.md b/docs/doc_install.md index e8c077d68..c819f6e50 100644 --- a/docs/doc_install.md +++ b/docs/doc_install.md @@ -4,7 +4,7 @@ `conda create --name domainlab_py39 python=3.9` -then +then `conda activate domainlab_py39` @@ -15,13 +15,13 @@ Suppose you have cloned the repository and have changed directory to the cloned ```norun pip install -r requirements.txt ``` -then +then `python setup.py install` #### Dependencies management - [python-poetry](https://python-poetry.org/) and use the configuration file `pyproject.toml` in this repository. - + ### Install Release -It is strongly recommended to create a virtual environment first, then +It is strongly recommended to create a virtual environment first, then - Install via `pip install domainlab` diff --git a/docs/doc_output.md b/docs/doc_output.md index 61c67e17a..345bb79f0 100644 --- a/docs/doc_output.md +++ b/docs/doc_output.md @@ -4,16 +4,16 @@ By default, this package generates outputs into a folder `zoutput` relative to t The output structure is something similar to below. 
([] means the folder might or might not exist, texts inside () are comments) -``` +```text zoutput/ ├── aggrsts (aggregation of results) │ ├── task1_test_domain1_tagName.csv │ ├── task2_test_domain3_tagName.csv -│ -│ +│ +│ ├── [gen] (counterfactual image generation, only exist for generative models with "--gen" specified) │ ├── [task1_test_domain1] -│ +│ └── saved_models (persisted pytorch model) ├── task1_algo1_git-commit-hashtag1_seed_1_instance_wise_predictions.txt (instance wise prediction of the model) ├── [task1_algo1_git-commit-hashtag1_seed_1.model] (only exist if with command line argument "--keep_model") diff --git a/docs/doc_tasks.md b/docs/doc_tasks.md index 2eec7e460..3d3f0d73d 100644 --- a/docs/doc_tasks.md +++ b/docs/doc_tasks.md @@ -2,10 +2,10 @@ The package offers various ways to specify a domain generalization task (where to find the data, which domain to use as training, which to test) according to user's need. -For all thress ways covered below, the user has to prepare a python file to feed via argument `--tpath` (means task path) into DomainLab. We provide example python files in our repository [see all examples here for specifying domain generalization task](https://github.com/marrlab/DomainLab/tree/master/examples/tasks) so that the user could follow the example to create their own domain generalization task specification. We provide inline comment to explain what each line is doing, as well as below in this documentation. +For all three ways covered below, the user has to prepare a python file to feed via argument `--tpath` (means task path) into DomainLab. We provide example python files in our repository [see all examples here for specifying domain generalization task](https://github.com/marrlab/DomainLab/tree/master/examples/tasks) so that the user could follow the example to create their own domain generalization task specification. We provide inline comments explaining what each line does, in addition to the explanations below in this documentation.
## Possibility 1: Specify train and test domain dataset directly -The most straightforward way to specify a domain generalization task is, if you have already a [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class for each domain: you could define a dictionary with the key being name for domain, and the value being the PyTorch Dataset you created corresponding to that domain (train and validation or only training) +The most straightforward way to specify a domain generalization task is, if you have already a [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class for each domain: you could define a dictionary with the key being name for domain, and the value being the PyTorch Dataset you created corresponding to that domain (train and validation or only training) [See an example python file here](https://github.com/marrlab/DomainLab/blob/master/examples/tasks/task_dset_custom.py) To train a ERM (Emperical Risk Minimization) network on this task: @@ -28,24 +28,24 @@ python main_out.py --te_d=sketch --tpath=examples/tasks/demo_task_path_list_smal In this mode, we assume there are structured folders where each folder contains all data from one domain, and each domain folder contains subfolders corresponding to different classes. See examples below. ### Data organization -To give an example, suppose we have a classification task to classify between car, dog, human, chair and bird and there are 3 data sources (domains) with folder name "folder_a", "folder_b" and "folder_c" respectively as shown below. +To give an example, suppose we have a classification task to classify between car, dog, human, chair and bird and there are 3 data sources (domains) with folder name "folder_a", "folder_b" and "folder_c" respectively as shown below. In each folder, the images are organized in sub-folders by their class. 
For example, "/path/to/3rd_domain/folder_c/dog" folder contains all the images of class "dog" from the 3rd domain. -It might be the case that across the different data sources the same class is named differently. For example, in the 1st data source, the class dog is stored in sub-folder named +It might be the case that across the different data sources the same class is named differently. For example, in the 1st data source, the class dog is stored in sub-folder named "hund", in the 2nd data source, the dog is stored in sub-folder named "husky" and in the 3rd data source, the dog is stored in sub-folder named "dog". It might also be the case that some classes exist in one data source but does not exist in another data source. For example, folder "/path/to/2nd_domain/folder_b" does not have a sub-folder for class "human". Folder structure of the 1st domain: -``` +```text ── /path/to/1st_domain/folder_a ├── auto ├── hund ├── mensch ├── stuhl └── vogel - + ``` Folder structure of the 2nd domain: @@ -56,7 +56,7 @@ Folder structure of the 2nd domain: ├── sit └── husky ``` -Folder structure of the 3rd domain: +Folder structure of the 3rd domain: ``` ── /path/to/3rd_domain/folder_c @@ -146,7 +146,7 @@ of domain information so only a unique transform (composition) is allowed. isize: domainlab.tasks.ImSize(image channel, image height, image width) -dict_domain2imgroot: a python dictionary with keys as user specified domain names and values +dict_domain2imgroot: a python dictionary with keys as user specified domain names and values as the absolute path to each domain's data. 
taskna: user defined task name diff --git a/docs/doc_usage_cmd.md b/docs/doc_usage_cmd.md index 40250d7b5..b6320328b 100644 --- a/docs/doc_usage_cmd.md +++ b/docs/doc_usage_cmd.md @@ -6,7 +6,7 @@ Suppose you have cloned the repository and have the dependencies ready, change d To train a domain invariant model on the vlcs_mini task ```shell -python main_out.py --te_d=caltech --tpath=examples/tasks/task_vlcs.py --config=examples/yaml/demo_config_single_run_diva.yaml +python main_out.py --te_d=caltech --tpath=examples/tasks/task_vlcs.py --config=examples/yaml/demo_config_single_run_diva.yaml ``` where `--tpath` specifies the path of a user specified python file which defines the domain generalization task [see here](../examples/tasks/task_vlcs.py), `--te_d` specifies the test domain name (or index starting from 0), `--config` specifies the configurations of the domain generalization algorithms, [see here](../examples/yaml/demo_config_single_run_diva.yaml) diff --git a/docs/figs/tikz_hduva.svg b/docs/figs/tikz_hduva.svg index 159e7eb91..e1149b73c 100644 --- a/docs/figs/tikz_hduva.svg +++ b/docs/figs/tikz_hduva.svg @@ -531,4 +531,4 @@ transform="translate(-110.686)" id="g514" /> \ No newline at end of file + id="g516" /> diff --git a/docs/index.html b/docs/index.html index 3af30f6c1..fd99d5a8c 100644 --- a/docs/index.html +++ b/docs/index.html @@ -19,7 +19,7 @@ - + @@ -35,6 +35,6 @@ font-family: "Roboto Mono", "Courier New", Courier, monospace } - Welcome to domainlab’s documentation! — libdg documentation + Welcome to domainlab’s documentation! 
— libdg documentation diff --git a/docs/libDG.svg b/docs/libDG.svg index 30d8c50bc..bfe2b405c 100644 --- a/docs/libDG.svg +++ b/docs/libDG.svg @@ -1 +1 @@ -tasksalgostrainersobservermodel_selectiondatasetsmodelsexperimentTasktransformationsdatasets (training domains)datasets (test domains)TaskFolderTaskPathListAlgoBuilderbuild_model()build_trainer()build_observer()model_specific_task_processing()ConcreteAlgoBuilderbuild_model()build_trainer()build_observer()model_specific_task_processing()Trainermodelobservertaskupdate_regularization_weight()update_model_parameter()calculate_loss()Observertrainer.model.calculate_metric()ModelSelearly_stopDataset__get__item__()YMLConfigurationModelextract_feature()calculate_task_loss()calculate_regularization_loss()calculate_metric()ModelBuilderbuild_model_componet()ConcreteModelcalculate_regularization_loss()extract_feature()ConcreteModelBuildernetworksExperimenttasktrainerexecute()TaskHandlerAlgoHandlerUserInputtaskhyperparametersread_yml_for_configuration()command_line_arguments()benchmarkrun_experiment()plot_results()read_yml_for_configuration() \ No newline at end of file +tasksalgostrainersobservermodel_selectiondatasetsmodelsexperimentTasktransformationsdatasets (training domains)datasets (test 
domains)TaskFolderTaskPathListAlgoBuilderbuild_model()build_trainer()build_observer()model_specific_task_processing()ConcreteAlgoBuilderbuild_model()build_trainer()build_observer()model_specific_task_processing()Trainermodelobservertaskupdate_regularization_weight()update_model_parameter()calculate_loss()Observertrainer.model.calculate_metric()ModelSelearly_stopDataset__get__item__()YMLConfigurationModelextract_feature()calculate_task_loss()calculate_regularization_loss()calculate_metric()ModelBuilderbuild_model_componet()ConcreteModelcalculate_regularization_loss()extract_feature()ConcreteModelBuildernetworksExperimenttasktrainerexecute()TaskHandlerAlgoHandlerUserInputtaskhyperparametersread_yml_for_configuration()command_line_arguments()benchmarkrun_experiment()plot_results()read_yml_for_configuration() diff --git a/domainlab/__init__.py b/domainlab/__init__.py index 97011d380..3837e899b 100644 --- a/domainlab/__init__.py +++ b/domainlab/__init__.py @@ -4,17 +4,18 @@ __docformat__ = "restructuredtext" import torch - g_inst_component_loss_agg = torch.sum g_tensor_batch_agg = torch.sum g_list_loss_agg = sum + def g_list_model_penalized_reg_agg(list_penalized_reg): """ aggregate along the list, but do not diminish the batch structure of the tensor """ return torch.stack(list_penalized_reg, dim=0).sum(dim=0) + g_str_cross_entropy_agg = "none" # component loss refers to aggregation of pixel loss, digit of KL divergences loss # instance loss currently use torch.sum, which is the same effect as torch.mean, the diff --git a/domainlab/algos/a_algo_builder.py b/domainlab/algos/a_algo_builder.py index 504713d46..f51184f3b 100644 --- a/domainlab/algos/a_algo_builder.py +++ b/domainlab/algos/a_algo_builder.py @@ -2,6 +2,7 @@ parent class for combing model, trainer, task, observer """ import abc + from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler from domainlab.utils.logger import Logger @@ -10,6 +11,7 @@ class NodeAlgoBuilder(AbstractChainNodeHandler): 
""" Base class for Algorithm Builder """ + na_prefix = "NodeAlgoBuilder" @property @@ -22,8 +24,11 @@ def name(self): na_class = type(self).__name__ if na_class[:len_prefix] != na_prefix: raise RuntimeError( - "algorithm builder node class must start with ", na_prefix, - "the current class is named: ", na_class) + "algorithm builder node class must start with ", + na_prefix, + "the current class is named: ", + na_class, + ) return type(self).__name__[len_prefix:].lower() def is_myjob(self, request): @@ -39,6 +44,14 @@ def init_business(self, exp): """ def extend(self, node): + """ + Extends the current algorithm builder with a new node. + + This method updates the builder by setting the `next_model` attribute to the specified node. + + Args: + node: The node to be added to the algorithm builder. + """ self.next_model = node def init_next_model(self, model, exp): diff --git a/domainlab/algos/builder_api_model.py b/domainlab/algos/builder_api_model.py index 6b4cc5430..07c087e3b 100644 --- a/domainlab/algos/builder_api_model.py +++ b/domainlab/algos/builder_api_model.py @@ -2,10 +2,10 @@ build algorithm from API coded model with custom backbone """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter -from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor +from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter from domainlab.utils.utils_cuda import get_device @@ -13,6 +13,7 @@ class NodeAlgoBuilderAPIModel(NodeAlgoBuilder): """ build algorithm from API coded model with custom backbone """ + def init_business(self, exp): """ return trainer, model, observer @@ -20,6 +21,6 @@ def init_business(self, exp): args = exp.args device = get_device(args) model_sel = 
MSelOracleVisitor(MSelValPerf(max_es=args.es)) - observer = ObVisitor(model_sel) + observer = ObVisitor(model_sel) trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler") return trainer, None, observer, device diff --git a/domainlab/algos/builder_custom.py b/domainlab/algos/builder_custom.py index 84236e9b7..3d4d48e6e 100644 --- a/domainlab/algos/builder_custom.py +++ b/domainlab/algos/builder_custom.py @@ -1,6 +1,6 @@ from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter @@ -11,6 +11,7 @@ def make_basic_trainer(class_name_model): """make_basic_trainer. :param class_name_model: """ + class NodeAlgoBuilderCustom(NodeAlgoBuilder): """NodeAlgoBuilderCustom.""" @@ -43,20 +44,21 @@ def _set_args(self, args, val_arg_na, prefix, argname): :param argname: nname_argna2val or "npath_argna2val", hard coded """ if getattr(args, argname) is None: - setattr(args, prefix+val_arg_na, None) + setattr(args, prefix + val_arg_na, None) return list_args = getattr(args, argname) ind = list_args.index(val_arg_na) - if ind+1 >= len(list_args): # list of args always even length - raise RuntimeError("\n nname_argna2val or npath_argna2val should \ + if ind + 1 >= len(list_args): # list of args always even length + raise RuntimeError( + f"\n nname_argna2val or npath_argna2val should \ \n always be specified in pairs instead of \ odd number:\ - \n %s" % ( - str(list_args))) - val = list_args[ind+1] + \n {str(list_args)}" + ) + val = list_args[ind + 1] # add attributes to namespaces args, the attributes are provided by # user in the custom model file - setattr(args, prefix+val_arg_na, val) + 
setattr(args, prefix + val_arg_na, val) def set_nets_from_dictionary(self, args, task, model): """set_nets_from_dictionary. @@ -67,23 +69,28 @@ def set_nets_from_dictionary(self, args, task, model): –apath=examples/algos/demo_custom_model.py –model=custom –nname_argna2val net1 –nname_argna2val alexnet """ - for key_module_na, val_arg_na in \ - model.dict_net_module_na2arg_na.items(): + for key_module_na, val_arg_na in model.dict_net_module_na2arg_na.items(): # - if args.nname_argna2val is None and \ - args.npath_argna2val is None: - raise RuntimeError("either specify nname_argna2val or \ - npath_argna2val") + if args.nname_argna2val is None and args.npath_argna2val is None: + raise RuntimeError( + "either specify nname_argna2val or \ + npath_argna2val" + ) self._set_args(args, val_arg_na, "nname", "nname_argna2val") self._set_args(args, val_arg_na, "npath", "npath_argna2val") # builder = FeatExtractNNBuilderChainNodeGetter( - args, arg_name_of_net="nname"+val_arg_na, - arg_path_of_net="npath"+val_arg_na)() + args, + arg_name_of_net="nname" + val_arg_na, + arg_path_of_net="npath" + val_arg_na, + )() net = builder.init_business( - flag_pretrain=True, dim_out=task.dim_y, - remove_last_layer=False, args=args, - isize=(task.isize.i_c, task.isize.i_h, task.isize.i_w)) + flag_pretrain=True, + dim_out=task.dim_y, + remove_last_layer=False, + args=args, + isize=(task.isize.i_c, task.isize.i_h, task.isize.i_w), + ) model.add_module("%s" % (key_module_na), net) def init_business(self, exp): @@ -101,4 +108,5 @@ def init_business(self, exp): trainer = self.get_trainer(args) trainer.init_business(model, task, observer, device, args) return trainer, model, observer, device + return NodeAlgoBuilderCustom diff --git a/domainlab/algos/builder_dann.py b/domainlab/algos/builder_dann.py index bb6bf65b0..34da86926 100644 --- a/domainlab/algos/builder_dann.py +++ b/domainlab/algos/builder_dann.py @@ -2,8 +2,8 @@ builder for Domain Adversarial Neural Network: accept different training 
scheme """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential @@ -19,6 +19,7 @@ class NodeAlgoBuilderDANN(NodeAlgoBuilder): """ NodeAlgoBuilderDANN """ + def init_business(self, exp): """ return trainer, model, observer @@ -33,38 +34,44 @@ def init_business(self, exp): observer = ObVisitorCleanUp(observer) builder = FeatExtractNNBuilderChainNodeGetter( - args, arg_name_of_net="nname", - arg_path_of_net="npath")() # request, @FIXME, constant string + args, arg_name_of_net="nname", arg_path_of_net="npath" + )() # request, @FIXME, constant string net_encoder = builder.init_business( - flag_pretrain=True, dim_out=task.dim_y, - remove_last_layer=False, args=args, - isize=(task.isize.i_c, task.isize.i_w, task.isize.i_h)) + flag_pretrain=True, + dim_out=task.dim_y, + remove_last_layer=False, + args=args, + isize=(task.isize.i_c, task.isize.i_w, task.isize.i_h), + ) - dim_feat = get_flat_dim(net_encoder, - task.isize.i_c, - task.isize.i_h, - task.isize.i_w) + dim_feat = get_flat_dim( + net_encoder, task.isize.i_c, task.isize.i_h, task.isize.i_w + ) net_classifier = ClassifDropoutReluLinear(dim_feat, task.dim_y) net_discriminator = self.reset_aux_net(net_encoder) - model = mk_dann()(list_str_y=task.list_str_y, - list_d_tr=task.list_domain_tr, - alpha=args.gamma_reg, - net_encoder=net_encoder, - net_classifier=net_classifier, - net_discriminator=net_discriminator, - builder=self) + model = mk_dann()( + list_str_y=task.list_str_y, + list_d_tr=task.list_domain_tr, + alpha=args.gamma_reg, + net_encoder=net_encoder, + net_classifier=net_classifier, + 
net_discriminator=net_discriminator, + builder=self, + ) model = self.init_next_model(model, exp) trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler") trainer.init_business(model, task, observer, device, args) if trainer.name == "hyperscheduler": - trainer.set_scheduler(HyperSchedulerWarmupExponential, - total_steps=trainer.num_batches*args.warmup, - flag_update_epoch=False, - flag_update_batch=True) + trainer.set_scheduler( + HyperSchedulerWarmupExponential, + total_steps=trainer.num_batches * args.warmup, + flag_update_epoch=False, + flag_update_batch=True, + ) return trainer, model, observer, device def reset_aux_net(self, net_encoder): @@ -72,10 +79,13 @@ def reset_aux_net(self, net_encoder): reset auxilliary neural network from task note that net_encoder can also be a method like extract_semantic_feat """ - dim_feat = get_flat_dim(net_encoder, - self._task.isize.i_c, - self._task.isize.i_h, - self._task.isize.i_w) + dim_feat = get_flat_dim( + net_encoder, + self._task.isize.i_c, + self._task.isize.i_h, + self._task.isize.i_w, + ) net_discriminator = ClassifDropoutReluLinear( - dim_feat, len(self._task.list_domain_tr)) + dim_feat, len(self._task.list_domain_tr) + ) return net_discriminator diff --git a/domainlab/algos/builder_diva.py b/domainlab/algos/builder_diva.py index ca82c681d..73ede5902 100644 --- a/domainlab/algos/builder_diva.py +++ b/domainlab/algos/builder_diva.py @@ -2,13 +2,12 @@ Builder pattern to build different component for experiment with DIVA """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp from domainlab.algos.observers.c_obvisitor_gen import ObVisitorGen from domainlab.algos.trainers.zoo_trainer 
import TrainerChainNodeGetter - from domainlab.compos.pcr.request import RequestVAEBuilderCHW from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter from domainlab.models.model_diva import mk_diva @@ -19,6 +18,7 @@ class NodeAlgoBuilderDIVA(NodeAlgoBuilder): """ Builder pattern to build different component for experiment with DIVA """ + def get_trainer(self, args): """ chain of responsibility pattern for fetching trainer from dictionary @@ -32,21 +32,22 @@ def init_business(self, exp): """ task = exp.task args = exp.args - request = RequestVAEBuilderCHW( - task.isize.c, task.isize.h, task.isize.w, args) + request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args) node = VAEChainNodeGetter(request)() task.get_list_domains_tr_te(args.tr_d, args.te_d) - model = mk_diva()(node, - zd_dim=args.zd_dim, - zy_dim=args.zy_dim, - zx_dim=args.zx_dim, - list_str_y=task.list_str_y, - list_d_tr=task.list_domain_tr, - gamma_d=args.gamma_d, - gamma_y=args.gamma_y, - beta_x=args.beta_x, - beta_y=args.beta_y, - beta_d=args.beta_d) + model = mk_diva()( + node, + zd_dim=args.zd_dim, + zy_dim=args.zy_dim, + zx_dim=args.zx_dim, + list_str_y=task.list_str_y, + list_d_tr=task.list_domain_tr, + gamma_d=args.gamma_d, + gamma_y=args.gamma_y, + beta_x=args.beta_x, + beta_y=args.beta_y, + beta_d=args.beta_d, + ) device = get_device(args) model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es)) if not args.gen: diff --git a/domainlab/algos/builder_erm.py b/domainlab/algos/builder_erm.py index 253339cbb..b8e3d7c25 100644 --- a/domainlab/algos/builder_erm.py +++ b/domainlab/algos/builder_erm.py @@ -1,12 +1,12 @@ """ builder for erm """ -from domainlab.algos.utils import split_net_feat_last from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from 
domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter +from domainlab.algos.utils import split_net_feat_last from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter from domainlab.models.model_erm import mk_erm from domainlab.utils.utils_cuda import get_device @@ -16,6 +16,7 @@ class NodeAlgoBuilderERM(NodeAlgoBuilder): """ builder for erm """ + def init_business(self, exp): """ return trainer, model, observer @@ -27,19 +28,24 @@ def init_business(self, exp): observer = ObVisitor(model_sel) builder = FeatExtractNNBuilderChainNodeGetter( - args, arg_name_of_net="nname", - arg_path_of_net="npath")() # request, # @FIXME, constant string + args, arg_name_of_net="nname", arg_path_of_net="npath" + )() # request, # @FIXME, constant string - net = builder.init_business(flag_pretrain=True, dim_out=task.dim_y, - remove_last_layer=False, args=args, - isize=(task.isize.i_c, task.isize.i_h, task.isize.i_w)) + net = builder.init_business( + flag_pretrain=True, + dim_out=task.dim_y, + remove_last_layer=False, + args=args, + isize=(task.isize.i_c, task.isize.i_h, task.isize.i_w), + ) _, _ = split_net_feat_last(net) model = mk_erm()( - net=net, - # net_feat=net_invar_feat, net_classifier=net_classifier, - list_str_y=task.list_str_y) + net=net, + # net_feat=net_invar_feat, net_classifier=net_classifier, + list_str_y=task.list_str_y, + ) model = self.init_next_model(model, exp) trainer = TrainerChainNodeGetter(args.trainer)(default="basic") # trainer.init_business(model, task, observer, device, args) diff --git a/domainlab/algos/builder_hduva.py b/domainlab/algos/builder_hduva.py index 22d3804d7..d8ac07f9e 100644 --- a/domainlab/algos/builder_hduva.py +++ b/domainlab/algos/builder_hduva.py @@ -2,8 +2,8 @@ build hduva model, get trainer from cmd arguments """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.msels.c_msel_val import MSelValPerf from 
domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter @@ -17,6 +17,7 @@ class NodeAlgoBuilderHDUVA(NodeAlgoBuilder): """ NodeAlgoBuilderHDUVA """ + def init_business(self, exp): """ return trainer, model, observer @@ -24,27 +25,27 @@ def init_business(self, exp): task = exp.task args = exp.args task.get_list_domains_tr_te(args.tr_d, args.te_d) - request = RequestVAEBuilderCHW( - task.isize.c, task.isize.h, task.isize.w, args) + request = RequestVAEBuilderCHW(task.isize.c, task.isize.h, task.isize.w, args) device = get_device(args) node = VAEChainNodeGetter(request, args.topic_dim)() - model = mk_hduva()(node, - zd_dim=args.zd_dim, - zy_dim=args.zy_dim, - zx_dim=args.zx_dim, - device=device, - topic_dim=args.topic_dim, - list_str_y=task.list_str_y, - gamma_d=args.gamma_d, - gamma_y=args.gamma_y, - beta_t=args.beta_t, - beta_x=args.beta_x, - beta_y=args.beta_y, - beta_d=args.beta_d) + model = mk_hduva()( + node, + zd_dim=args.zd_dim, + zy_dim=args.zy_dim, + zx_dim=args.zx_dim, + device=device, + topic_dim=args.topic_dim, + list_str_y=task.list_str_y, + gamma_d=args.gamma_d, + gamma_y=args.gamma_y, + beta_t=args.beta_t, + beta_x=args.beta_x, + beta_y=args.beta_y, + beta_d=args.beta_d, + ) model = self.init_next_model(model, exp) model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es)) - observer = ObVisitorCleanUp( - ObVisitor(model_sel)) + observer = ObVisitorCleanUp(ObVisitor(model_sel)) trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler") trainer.init_business(model, task, observer, device, args) return trainer, model, observer, device diff --git a/domainlab/algos/builder_jigen1.py b/domainlab/algos/builder_jigen1.py index 6e41bd3cd..edd083bad 100644 --- 
a/domainlab/algos/builder_jigen1.py +++ b/domainlab/algos/builder_jigen1.py @@ -2,25 +2,26 @@ builder for JiGen """ from domainlab.algos.a_algo_builder import NodeAlgoBuilder -from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp -from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupExponential +from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler from domainlab.algos.trainers.zoo_trainer import TrainerChainNodeGetter from domainlab.compos.nn_zoo.net_classif import ClassifDropoutReluLinear from domainlab.compos.utils_conv_get_flat_dim import get_flat_dim from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter +from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches from domainlab.models.model_jigen import mk_jigen from domainlab.utils.utils_cuda import get_device -from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches class NodeAlgoBuilderJiGen(NodeAlgoBuilder): """ NodeAlgoBuilderJiGen """ + def init_business(self, exp): """ return trainer, model, observer @@ -33,37 +34,40 @@ def init_business(self, exp): observer = ObVisitorCleanUp(observer) builder = FeatExtractNNBuilderChainNodeGetter( - args, arg_name_of_net="nname", - arg_path_of_net="npath")() # request, @FIXME, constant string + args, arg_name_of_net="nname", arg_path_of_net="npath" + )() # request, @FIXME, constant string net_encoder = builder.init_business( - flag_pretrain=True, dim_out=task.dim_y, - remove_last_layer=False, args=args, - isize=(task.isize.i_c, - task.isize.i_w, - task.isize.i_h)) + flag_pretrain=True, + dim_out=task.dim_y, + remove_last_layer=False, + args=args, + 
isize=(task.isize.i_c, task.isize.i_w, task.isize.i_h), + ) - dim_feat = get_flat_dim(net_encoder, - task.isize.i_c, - task.isize.i_h, - task.isize.i_w) + dim_feat = get_flat_dim( + net_encoder, task.isize.i_c, task.isize.i_h, task.isize.i_w + ) net_classifier = ClassifDropoutReluLinear(dim_feat, task.dim_y) # @FIXME: this seems to be the only difference w.r.t. builder_dann - net_classifier_perm = ClassifDropoutReluLinear( - dim_feat, args.nperm+1) - model = mk_jigen()(list_str_y=task.list_str_y, - coeff_reg=args.gamma_reg, - net_encoder=net_encoder, - net_classifier_class=net_classifier, - net_classifier_permutation=net_classifier_perm) + net_classifier_perm = ClassifDropoutReluLinear(dim_feat, args.nperm + 1) + model = mk_jigen()( + list_str_y=task.list_str_y, + coeff_reg=args.gamma_reg, + net_encoder=net_encoder, + net_classifier_class=net_classifier, + net_classifier_permutation=net_classifier_perm, + ) model = self.init_next_model(model, exp) trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler") trainer.init_business(model, task, observer, device, args) if isinstance(trainer, TrainerHyperScheduler): - trainer.set_scheduler(HyperSchedulerWarmupExponential, - total_steps=trainer.num_batches*args.warmup, - flag_update_epoch=False, - flag_update_batch=True) + trainer.set_scheduler( + HyperSchedulerWarmupExponential, + total_steps=trainer.num_batches * args.warmup, + flag_update_epoch=False, + flag_update_batch=True, + ) return trainer, model, observer, device diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index 5b91c7dd1..e2e63c993 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -9,6 +9,7 @@ class AMSel(metaclass=abc.ABCMeta): """ Abstract Model Selection """ + def __init__(self): """ trainer and tr_observer diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index aa70d96da..e232b1e78 100644 --- 
a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -10,6 +10,7 @@ class MSelOracleVisitor(AMSel): save best out-of-domain test acc model, but do not affect how the final model is selected """ + def __init__(self, msel=None): """ Decorator pattern @@ -23,8 +24,9 @@ def oracle_last_setpoint_sel_te_acc(self): """ last setpoint acc """ - if self.msel is not None and \ - hasattr(self.msel, "oracle_last_setpoint_sel_te_acc"): + if self.msel is not None and hasattr( + self.msel, "oracle_last_setpoint_sel_te_acc" + ): return self.msel.oracle_last_setpoint_sel_te_acc return -1 diff --git a/domainlab/algos/msels/c_msel_tr_loss.py b/domainlab/algos/msels/c_msel_tr_loss.py index 27fb285db..7ed3f1168 100644 --- a/domainlab/algos/msels/c_msel_tr_loss.py +++ b/domainlab/algos/msels/c_msel_tr_loss.py @@ -2,6 +2,7 @@ Model Selection should be decoupled from """ import math + from domainlab.algos.msels.a_model_sel import AMSel from domainlab.utils.logger import Logger @@ -11,6 +12,7 @@ class MSelTrLoss(AMSel): 1. Model selection using sum of loss across training domains 2. Visitor pattern to trainer """ + def __init__(self, max_es): super().__init__() # NOTE: super() must come first otherwise it will overwrite existing @@ -30,7 +32,7 @@ def update(self, clear_counter=False): """ if the best model should be updated """ - loss = self.trainer.epo_loss_tr # @FIXME + loss = self.trainer.epo_loss_tr # @FIXME assert loss is not None assert not math.isnan(loss) flag = True diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index cea3394cd..c1f2f5561 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -10,6 +10,7 @@ class MSelValPerf(MSelTrLoss): 1. Model selection using validation performance 2. 
Visitor pattern to trainer """ + def __init__(self, max_es): super().__init__(max_es) # construct self.tr_obs (observer) self.reset() @@ -47,8 +48,7 @@ def update(self, clear_counter=False): return super().update(clear_counter) metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel] if self.tr_obs.metric_te is not None: - metric_te_current = \ - self.tr_obs.metric_te[self.tr_obs.str_metric4msel] + metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] self._best_te_metric = max(self._best_te_metric, metric_te_current) if metric > self._best_val_acc: # update hat{model} @@ -57,18 +57,19 @@ def update(self, clear_counter=False): self._best_val_acc = metric self.es_c = 0 # restore counter if self.tr_obs.metric_te is not None: - metric_te_current = \ - self.tr_obs.metric_te[self.tr_obs.str_metric4msel] + metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] self._sel_model_te_acc = metric_te_current else: self.es_c += 1 logger = Logger.get_logger() logger.info(f"early stop counter: {self.es_c}") - logger.info(f"val acc:{self.tr_obs.metric_val['acc']}, " + - f"best validation acc: {self.best_val_acc}, " + - f"corresponding to test acc: \ - {self.sel_model_te_acc} / {self.best_te_metric}") + logger.info( + f"val acc:{self.tr_obs.metric_val['acc']}, " + + f"best validation acc: {self.best_val_acc}, " + + f"corresponding to test acc: \ + {self.sel_model_te_acc} / {self.best_te_metric}" + ) flag = False # do not update best model if clear_counter: logger.info("clearing counter") diff --git a/domainlab/algos/observers/a_observer.py b/domainlab/algos/observers/a_observer.py index dcb86d651..bdad57360 100644 --- a/domainlab/algos/observers/a_observer.py +++ b/domainlab/algos/observers/a_observer.py @@ -8,6 +8,7 @@ class AObVisitor(metaclass=abc.ABCMeta): """ Observer + Visitor pattern for model selection """ + def __init__(self): self.task = None self.device = None diff --git a/domainlab/algos/observers/b_obvisitor.py 
b/domainlab/algos/observers/b_obvisitor.py index 4051bf7b7..04231a917 100644 --- a/domainlab/algos/observers/b_obvisitor.py +++ b/domainlab/algos/observers/b_obvisitor.py @@ -15,6 +15,7 @@ class ObVisitor(AObVisitor): """ Observer + Visitor pattern for model selection """ + def __init__(self, model_sel): """ observer trainer @@ -40,16 +41,17 @@ def update(self, epoch): self.epo = epoch if epoch % self.epo_te == 0: logger.info("---- Training Domain: ") - self.host_trainer.model.cal_perf_metric( - self.loader_tr, self.device) + self.host_trainer.model.cal_perf_metric(self.loader_tr, self.device) if self.loader_val is not None: logger.info("---- Validation: ") self.metric_val = self.host_trainer.model.cal_perf_metric( - self.loader_val, self.device) + self.loader_val, self.device + ) if self.loader_te is not None: logger.info("---- Test Domain (oracle): ") metric_te = self.host_trainer.model.cal_perf_metric( - self.loader_te, self.device) + self.loader_te, self.device + ) self.metric_te = metric_te if self.model_sel.update(): logger.info("better model found") @@ -65,7 +67,7 @@ def accept(self, trainer): """ self.host_trainer = trainer self.model_sel.accept(trainer, self) - self.set_task(trainer.task, args=trainer.aconf, device=trainer.device) + self.set_task(trainer.task, args=trainer.aconf, device=trainer.device) self.perf_metric = self.host_trainer.model.create_perf_obj(self.task) def after_all(self): @@ -81,10 +83,12 @@ def after_all(self): # this can happen if loss is increasing, model never get selected logger = Logger.get_logger() logger.warning(err) - logger.warning("this error can occur if model selection criteria \ + logger.warning( + "this error can occur if model selection criteria \ is worsening, " - "model never get persisted, \ - no performance metric is reported") + "model never get persisted, \ + no performance metric is reported" + ) return model_ld = model_ld.to(self.device) model_ld.eval() @@ -128,22 +132,22 @@ def dump_prediction(self, model_ld, 
metric_te): model_ld to predict each instance """ flag_task_folder = isinstance( - self.host_trainer.task, NodeTaskFolderClassNaMismatch) - flag_task_path_list = isinstance( - self.host_trainer.task, NodeTaskPathListDummy) + self.host_trainer.task, NodeTaskFolderClassNaMismatch + ) + flag_task_path_list = isinstance(self.host_trainer.task, NodeTaskPathListDummy) if flag_task_folder or flag_task_path_list: - fname4model = self.host_trainer.model.visitor.model_path # pylint: disable=E1101 + fname4model = ( + self.host_trainer.model.visitor.model_path + ) # pylint: disable=E1101 file_prefix = os.path.splitext(fname4model)[0] # remove ".model" dir4preds = os.path.join(self.host_trainer.aconf.out, "saved_predicts") if not os.path.exists(dir4preds): os.mkdir(dir4preds) - file_prefix = os.path.join(dir4preds, - os.path.basename(file_prefix)) + file_prefix = os.path.join(dir4preds, os.path.basename(file_prefix)) file_name = file_prefix + "_instance_wise_predictions.txt" model_ld.pred2file( - self.loader_te, self.device, - filename=file_name, - metric_te=metric_te) + self.loader_te, self.device, filename=file_name, metric_te=metric_te + ) def clean_up(self): """ diff --git a/domainlab/algos/observers/c_obvisitor_cleanup.py b/domainlab/algos/observers/c_obvisitor_cleanup.py index cabb834e9..91ac53216 100644 --- a/domainlab/algos/observers/c_obvisitor_cleanup.py +++ b/domainlab/algos/observers/c_obvisitor_cleanup.py @@ -5,6 +5,7 @@ class ObVisitorCleanUp(AObVisitor): """ decorator of observer, instead of using if and else to decide clean up or not, we use decorator """ + def __init__(self, observer): super().__init__() self.observer = observer diff --git a/domainlab/algos/observers/c_obvisitor_gen.py b/domainlab/algos/observers/c_obvisitor_gen.py index a38c6a59d..45571b1b2 100644 --- a/domainlab/algos/observers/c_obvisitor_gen.py +++ b/domainlab/algos/observers/c_obvisitor_gen.py @@ -7,23 +7,40 @@ class ObVisitorGen(ObVisitor): """ For Generative Models """ + def 
after_all(self): super().after_all() logger = Logger.get_logger() logger.info("generating images for final model at last epoch") - fun_gen(subfolder_na=self.host_trainer.model.visitor.model_name+"final", - args=self.host_trainer.aconf, node=self.host_trainer.task, model=self.host_trainer.model, - device=self.device) + fun_gen( + subfolder_na=self.host_trainer.model.visitor.model_name + "final", + args=self.host_trainer.aconf, + node=self.host_trainer.task, + model=self.host_trainer.model, + device=self.device, + ) logger.info("generating images for oracle model") - model_or = self.host_trainer.model.load("oracle") # @FIXME: name "oracle is a strong dependency + model_or = self.host_trainer.model.load( + "oracle" + ) # @FIXME: name "oracle is a strong dependency model_or = model_or.to(self.device) model_or.eval() - fun_gen(subfolder_na=self.host_trainer.model.visitor.model_name+"oracle", - args=self.host_trainer.aconf, node=self.host_trainer.task, model=model_or, device=self.device) + fun_gen( + subfolder_na=self.host_trainer.model.visitor.model_name + "oracle", + args=self.host_trainer.aconf, + node=self.host_trainer.task, + model=model_or, + device=self.device, + ) logger.info("generating images for selected model") model_ld = self.host_trainer.model.load() model_ld = model_ld.to(self.device) model_ld.eval() - fun_gen(subfolder_na=self.host_trainer.model.visitor.model_name+"selected", - args=self.host_trainer.aconf, node=self.host_trainer.task, model=model_ld, device=self.device) + fun_gen( + subfolder_na=self.host_trainer.model.visitor.model_name + "selected", + args=self.host_trainer.aconf, + node=self.host_trainer.task, + model=model_ld, + device=self.device, + ) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 6b31d8c6d..d416238c8 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -2,7 +2,9 @@ Base Class for trainer """ import abc + from torch import optim + from 
domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler @@ -18,11 +20,11 @@ def mk_opt(model, aconf): set_param = set(list(var1) + list(var2)) list_par = list(set_param) # optimizer = optim.Adam([var1, var2], lr=aconf.lr) - #optimizer = optim.Adam([ + # optimizer = optim.Adam([ # {'params': model.parameters()}, # {'params': model._decoratee.parameters()} - #], lr=aconf.lr) - optimizer = optim.Adam(list_par, lr= aconf.lr) + # ], lr=aconf.lr) + optimizer = optim.Adam(list_par, lr=aconf.lr) return optimizer @@ -30,6 +32,7 @@ class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta): """ Algorithm director that controls the data flow """ + @property def p_na_prefix(self): """ @@ -120,7 +123,9 @@ def init_business(self, model, task, observer, device, aconf, flag_accept=True): # Note self.decoratee can be both model and trainer, # but self._decoratee can only be trainer! if self._decoratee is not None: - self._decoratee.init_business(model, task, observer, device, aconf, flag_accept) + self._decoratee.init_business( + model, task, observer, device, aconf, flag_accept + ) self.model = model self.task = task self.task.init_business(trainer=self, args=aconf) @@ -193,8 +198,11 @@ def name(self): na_class = type(self).__name__ if na_class[:len_prefix] != na_prefix: raise RuntimeError( - "Trainer builder node class must start with ", na_prefix, - "the current class is named: ", na_class) + "Trainer builder node class must start with ", + na_prefix, + "the current class is named: ", + na_class, + ) return type(self).__name__[len_prefix:].lower() def is_myjob(self, request): @@ -218,11 +226,14 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): can be either a trainer or a model """ list_reg_model, list_mu_model = self.decoratee.cal_reg_loss( - tensor_x, tensor_y, tensor_d, others) + tensor_x, tensor_y, tensor_d, others + ) assert len(list_reg_model) == len(list_mu_model) - list_reg_trainer, list_mu_trainer = 
self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others) - assert len(list_reg_trainer) == len(list_mu_trainer) + list_reg_trainer, list_mu_trainer = self._cal_reg_loss( + tensor_x, tensor_y, tensor_d, others + ) + assert len(list_reg_trainer) == len(list_mu_trainer) list_loss = list_reg_model + list_reg_trainer list_mu = list_mu_model + list_mu_trainer diff --git a/domainlab/algos/trainers/args_dial.py b/domainlab/algos/trainers/args_dial.py index 118715f35..0c0445915 100644 --- a/domainlab/algos/trainers/args_dial.py +++ b/domainlab/algos/trainers/args_dial.py @@ -7,12 +7,28 @@ def add_args2parser_dial(parser): """ append hyper-parameters to the main argparser """ - parser.add_argument('--dial_steps_perturb', type=int, default=3, - help='how many gradient step to go to find an image as adversarials') - parser.add_argument('--dial_noise_scale', type=float, default=0.001, - help='variance of gaussian noise to inject on pure image') - parser.add_argument('--dial_lr', type=float, default=0.003, - help='learning rate to generate adversarial images') - parser.add_argument('--dial_epsilon', type=float, default=0.031, - help='pixel wise threshold to perturb images') + parser.add_argument( + "--dial_steps_perturb", + type=int, + default=3, + help="how many gradient step to go to find an image as adversarials", + ) + parser.add_argument( + "--dial_noise_scale", + type=float, + default=0.001, + help="variance of gaussian noise to inject on pure image", + ) + parser.add_argument( + "--dial_lr", + type=float, + default=0.003, + help="learning rate to generate adversarial images", + ) + parser.add_argument( + "--dial_epsilon", + type=float, + default=0.031, + help="pixel wise threshold to perturb images", + ) return parser diff --git a/domainlab/algos/trainers/compos/matchdg_args.py b/domainlab/algos/trainers/compos/matchdg_args.py index cce939e7a..ce649ac0e 100644 --- a/domainlab/algos/trainers/compos/matchdg_args.py +++ b/domainlab/algos/trainers/compos/matchdg_args.py @@ 
-8,12 +8,18 @@ def add_args2parser_matchdg(parser): args for matchdg """ # parser = argparse.ArgumentParser() - parser.add_argument('--tau', type=float, default=0.05, - help='factor to magnify cosine similarity') - parser.add_argument('--epos_per_match_update', type=int, default=5, - help='Number of epochs before updating the match tensor') - parser.add_argument('--epochs_ctr', type=int, default=1, - help='Total number of epochs for ctr') + parser.add_argument( + "--tau", type=float, default=0.05, help="factor to magnify cosine similarity" + ) + parser.add_argument( + "--epos_per_match_update", + type=int, + default=5, + help="Number of epochs before updating the match tensor", + ) + parser.add_argument( + "--epochs_ctr", type=int, default=1, help="Total number of epochs for ctr" + ) # args = parser.parse_args("") # return args return parser diff --git a/domainlab/algos/trainers/compos/matchdg_match.py b/domainlab/algos/trainers/compos/matchdg_match.py index a92f31008..78e67abde 100644 --- a/domainlab/algos/trainers/compos/matchdg_match.py +++ b/domainlab/algos/trainers/compos/matchdg_match.py @@ -1,26 +1,33 @@ import warnings -from domainlab.utils.logger import Logger import numpy as np import torch from domainlab.algos.trainers.compos.matchdg_utils import ( - MatchDictNumDomain2SizeDomain, MatchDictVirtualRefDset2EachDomain) + MatchDictNumDomain2SizeDomain, + MatchDictVirtualRefDset2EachDomain, +) from domainlab.tasks.utils_task import mk_loader +from domainlab.utils.logger import Logger from domainlab.utils.utils_class import store_args -class MatchPair(): +class MatchPair: + """ + match different input + """ @store_args - def __init__(self, - dim_y, - i_c, - i_h, - i_w, - bs_match, - virtual_ref_dset_size, - num_domains_tr, - list_tr_domain_size): + def __init__( + self, + dim_y, + i_c, + i_h, + i_w, + bs_match, + virtual_ref_dset_size, + num_domains_tr, + list_tr_domain_size, + ): """ :param virtual_ref_dset_size: sum of biggest class sizes :param 
num_domains_tr: @@ -35,12 +42,18 @@ def __init__(self, self.dict_virtual_dset2each_domain = MatchDictVirtualRefDset2EachDomain( virtual_ref_dset_size=virtual_ref_dset_size, num_domains_tr=num_domains_tr, - i_c=i_c, i_h=i_h, i_w=i_w)() + i_c=i_c, + i_h=i_h, + i_w=i_w, + )() self.dict_domain_data = MatchDictNumDomain2SizeDomain( num_domains_tr=num_domains_tr, list_tr_domain_size=list_tr_domain_size, - i_c=i_c, i_h=i_h, i_w=i_w)() + i_c=i_c, + i_h=i_h, + i_w=i_w, + )() self.indices_matched = {} for key in range(virtual_ref_dset_size): @@ -59,9 +72,8 @@ def _fill_data(self, loader): # NOTE: loader contains data from several dataset list_idx_several_ds = [] loader_full_data = mk_loader( - loader.dataset, - bsize=loader.batch_size, - drop_last=False) + loader.dataset, bsize=loader.batch_size, drop_last=False + ) # @FIXME: training loader will always drop the last incomplete batch for _, (x_e, y_e, d_e, idx_e) in enumerate(loader_full_data): # traverse mixed domain data from loader @@ -73,7 +85,7 @@ def _fill_data(self, loader): unique_domains = np.unique(d_e) for domain_idx in unique_domains: # select all instances belong to one domain - flag_curr_domain = (d_e == domain_idx) + flag_curr_domain = d_e == domain_idx # flag_curr_domain is subset indicator of # True of False for selection of data from the mini-batch @@ -90,15 +102,18 @@ def _fill_data(self, loader): # tensor.item get the scalar global_ind = global_indices[local_ind].item() - self.dict_domain_data[domain_idx]['data'][global_ind] = \ - x_e[flag_curr_domain][local_ind] + self.dict_domain_data[domain_idx]["data"][global_ind] = x_e[ + flag_curr_domain + ][local_ind] # flag_curr_domain are subset indicator # for selection of domain - self.dict_domain_data[domain_idx]['label'][global_ind] = \ - y_e[flag_curr_domain][local_ind] + self.dict_domain_data[domain_idx]["label"][global_ind] = y_e[ + flag_curr_domain + ][local_ind] # copy trainining batch to dict_domain_data - 
self.dict_domain_data[domain_idx]['idx'][global_ind] = \ - idx_e[flag_curr_domain][local_ind] + self.dict_domain_data[domain_idx]["idx"][global_ind] = idx_e[ + flag_curr_domain + ][local_ind] self.domain_count[domain_idx] += 1 # if all data has been re-organized(filled) into the current tensor @@ -109,10 +124,8 @@ def _fill_data(self, loader): for domain in range(self.num_domains_tr): if self.domain_count[domain] != self.list_tr_domain_size[domain]: logger = Logger.get_logger() - logger.warning("domain_count show matching " - "dictionary missing data!") - warnings.warn("domain_count show matching " - "dictionary missing data!") + logger.warning("domain_count show matching: dictionary missing data!") + warnings.warn("domain_count show matching: dictionary missing data!") def _cal_base_domain(self): """ @@ -124,10 +137,12 @@ def _cal_base_domain(self): base_domain_size = 0 base_domain_idx = -1 for domain_idx in range(self.num_domains_tr): - flag_curr_class = \ - (self.dict_domain_data[domain_idx]['label'] == y_c) # tensor of True/False - curr_size = \ - self.dict_domain_data[domain_idx]['label'][flag_curr_class].shape[0] + flag_curr_class = ( + self.dict_domain_data[domain_idx]["label"] == y_c + ) # tensor of True/False + curr_size = self.dict_domain_data[domain_idx]["label"][ + flag_curr_class + ].shape[0] # flag_curr_class are subset indicator if base_domain_size < curr_size: base_domain_size = curr_size @@ -152,70 +167,88 @@ def __call__(self, device, loader, fun_extract_semantic_feat, flag_match_min_dis base_domain_idx = self.dict_cls_ind_base_domain_ind[y_c] # subset indicator - flags_base_domain_curr_cls = \ - (self.dict_domain_data[base_domain_idx]['label'] == y_c) + flags_base_domain_curr_cls = ( + self.dict_domain_data[base_domain_idx]["label"] == y_c + ) flags_base_domain_curr_cls = flags_base_domain_curr_cls[:, 0] - global_inds_base_domain_curr_cls = \ - self.dict_domain_data[base_domain_idx]['idx'][flags_base_domain_curr_cls] + 
global_inds_base_domain_curr_cls = self.dict_domain_data[ + base_domain_idx + ]["idx"][flags_base_domain_curr_cls] # pick out base domain class label y_c images # the difference of this block is "curr_domain_ind" # in iteration is # used instead of base_domain_idx for current class # pick out current domain y_c class images - flag_curr_domain_curr_cls = (self.dict_domain_data[curr_domain_ind]['label'] == y_c) + flag_curr_domain_curr_cls = ( + self.dict_domain_data[curr_domain_ind]["label"] == y_c + ) # NO label matches y_c flag_curr_domain_curr_cls = flag_curr_domain_curr_cls[:, 0] - global_inds_curr_domain_curr_cls = \ - self.dict_domain_data[curr_domain_ind]['idx'][flag_curr_domain_curr_cls] + global_inds_curr_domain_curr_cls = self.dict_domain_data[ + curr_domain_ind + ]["idx"][flag_curr_domain_curr_cls] size_curr_domain_curr_cls = global_inds_curr_domain_curr_cls.shape[0] - if size_curr_domain_curr_cls == 0: # there is no class y_c in current domain + if ( + size_curr_domain_curr_cls == 0 + ): # there is no class y_c in current domain raise RuntimeError( - f"current domain {curr_domain_ind} does not contain class {y_c}") + f"current domain {curr_domain_ind} does not contain class {y_c}" + ) # compute base domain features for class label y_c - x_base_domain_curr_cls = \ - self.dict_domain_data[base_domain_idx]['data'][flags_base_domain_curr_cls] + x_base_domain_curr_cls = self.dict_domain_data[base_domain_idx]["data"][ + flags_base_domain_curr_cls + ] # pick out base domain class label y_c images # split data into chunks - tuple_batch_x_base_domain_curr_cls = \ - torch.split(x_base_domain_curr_cls, self.bs_match, dim=0) + tuple_batch_x_base_domain_curr_cls = torch.split( + x_base_domain_curr_cls, self.bs_match, dim=0 + ) # @FIXME. 
when x_base_domain_curr_cls is smaller # than the self.bs_match, then there is only one batch list_base_feat = [] for batch_x_base_domain_curr_cls in tuple_batch_x_base_domain_curr_cls: with torch.no_grad(): - batch_x_base_domain_curr_cls = batch_x_base_domain_curr_cls.to(device) + batch_x_base_domain_curr_cls = batch_x_base_domain_curr_cls.to( + device + ) feat = fun_extract_semantic_feat(batch_x_base_domain_curr_cls) list_base_feat.append(feat.cpu()) tensor_feat_base_domain_curr_cls = torch.cat(list_base_feat) # base domain features if flag_match_min_dist: # if epoch > 0:flag_match_min_dist=True - x_curr_domain_curr_cls = \ - self.dict_domain_data[curr_domain_ind]['data'][flag_curr_domain_curr_cls] + x_curr_domain_curr_cls = self.dict_domain_data[curr_domain_ind][ + "data" + ][flag_curr_domain_curr_cls] # indices_curr pick out current domain y_c class images - tuple_x_batch_curr_domain_curr_cls = \ - torch.split(x_curr_domain_curr_cls, self.bs_match, dim=0) + tuple_x_batch_curr_domain_curr_cls = torch.split( + x_curr_domain_curr_cls, self.bs_match, dim=0 + ) list_feat_x_curr_domain_curr_cls = [] for batch_feat in tuple_x_batch_curr_domain_curr_cls: with torch.no_grad(): batch_feat = batch_feat.to(device) out = fun_extract_semantic_feat(batch_feat) list_feat_x_curr_domain_curr_cls.append(out.cpu()) - tensor_feat_curr_domain_curr_cls = torch.cat(list_feat_x_curr_domain_curr_cls) + tensor_feat_curr_domain_curr_cls = torch.cat( + list_feat_x_curr_domain_curr_cls + ) # feature through inference network for the current domain of class y_c - tensor_feat_base_domain_curr_cls = tensor_feat_base_domain_curr_cls.unsqueeze(1) - tuple_feat_base_domain_curr_cls = \ - torch.split(tensor_feat_base_domain_curr_cls, self.bs_match, dim=0) + tensor_feat_base_domain_curr_cls = ( + tensor_feat_base_domain_curr_cls.unsqueeze(1) + ) + tuple_feat_base_domain_curr_cls = torch.split( + tensor_feat_base_domain_curr_cls, self.bs_match, dim=0 + ) counter_curr_cls_base_domain = 0 # 
tuple_feat_base_domain_curr_cls is a tuple of splitted part for feat_base_domain_curr_cls in tuple_feat_base_domain_curr_cls: - - if flag_match_min_dist: # if epoch > 0:flag_match_min_dist=True + if flag_match_min_dist: # if epoch > 0:flag_match_min_dist=True # Need to compute over batches of # feature due to device Memory out errors # Else no need for loop over @@ -223,10 +256,14 @@ def __call__(self, device, loader, fun_extract_semantic_feat, flag_match_min_dis # could have simply computed # tensor_feat_curr_domain_curr_cls - # tensor_feat_base_domain_curr_cls - dist_same_class_base_domain_curr_domain = \ - torch.sum( - (tensor_feat_curr_domain_curr_cls - feat_base_domain_curr_cls)**2, - dim=2) + dist_same_class_base_domain_curr_domain = torch.sum( + ( + tensor_feat_curr_domain_curr_cls + - feat_base_domain_curr_cls + ) + ** 2, + dim=2, + ) # tensor_feat_curr_domain_curr_cls.shape torch.Size([184, 512]) # feat_base_domain_curr_cls.shape torch.Size([64, 1, 512]) # (tensor_feat_curr_domain_curr_cls - feat_base_domain_curr_cls).shape: @@ -235,8 +272,9 @@ def __call__(self, device, loader, fun_extract_semantic_feat, flag_match_min_dis # torch.Size([64, 184]) is the per element distance of # the cartesian product of feat_base_domain_curr_cls vs # tensor_feat_curr_domain_curr_cls - match_ind_base_domain_curr_domain = \ - torch.argmin(dist_same_class_base_domain_curr_domain, dim=1) + match_ind_base_domain_curr_domain = torch.argmin( + dist_same_class_base_domain_curr_domain, dim=1 + ) # the batch index of the neareast neighbors # len(match_ind_base_domain_curr_domain)=64 # theoretically match_ind_base_domain_curr_domain can @@ -251,24 +289,44 @@ def __call__(self, device, loader, fun_extract_semantic_feat, flag_match_min_dis # ## global_inds_base_domain_curr_cls pick out base # domain class label y_c images - global_pos_base_domain_curr_cls = \ - global_inds_base_domain_curr_cls[counter_curr_cls_base_domain].item() + global_pos_base_domain_curr_cls = ( + 
global_inds_base_domain_curr_cls[ + counter_curr_cls_base_domain + ].item() + ) if curr_domain_ind == base_domain_idx: - ind_match_global_curr_domain_curr_cls = global_pos_base_domain_curr_cls + ind_match_global_curr_domain_curr_cls = ( + global_pos_base_domain_curr_cls + ) else: if flag_match_min_dist: # if epoch > 0:match_min_dist=True - ind_match_global_curr_domain_curr_cls = \ + ind_match_global_curr_domain_curr_cls = ( global_inds_curr_domain_curr_cls[ - match_ind_base_domain_curr_domain[idx]].item() + match_ind_base_domain_curr_domain[idx] + ].item() + ) else: # if epoch == 0 - ind_match_global_curr_domain_curr_cls = \ + ind_match_global_curr_domain_curr_cls = ( global_inds_curr_domain_curr_cls[ - counter_curr_cls_base_domain%size_curr_domain_curr_cls].item() - - self.dict_virtual_dset2each_domain[counter_ref_dset_size]['data'][curr_domain_ind] = \ - self.dict_domain_data[curr_domain_ind]['data'][ind_match_global_curr_domain_curr_cls] - self.dict_virtual_dset2each_domain[counter_ref_dset_size]['label'][curr_domain_ind] = \ - self.dict_domain_data[curr_domain_ind]['label'][ind_match_global_curr_domain_curr_cls] + counter_curr_cls_base_domain + % size_curr_domain_curr_cls + ].item() + ) + + self.dict_virtual_dset2each_domain[counter_ref_dset_size][ + "data" + ][curr_domain_ind] = self.dict_domain_data[curr_domain_ind][ + "data" + ][ + ind_match_global_curr_domain_curr_cls + ] + self.dict_virtual_dset2each_domain[counter_ref_dset_size][ + "label" + ][curr_domain_ind] = self.dict_domain_data[curr_domain_ind][ + "label" + ][ + ind_match_global_curr_domain_curr_cls + ] # @FIXME: label initially were set to random continuous # value, which is a technique to check if # every data has been filled @@ -279,42 +337,64 @@ def __call__(self, device, loader, fun_extract_semantic_feat, flag_match_min_dis logger = Logger.get_logger() logger.info(f"counter_ref_dset_size {counter_ref_dset_size}") logger.info(f"self.virtual_ref_dset_size {self.virtual_ref_dset_size}") - 
logger.warning("counter_ref_dset_size not equal to self.virtual_ref_dset_size") - raise RuntimeError("counter_ref_dset_size not equal to self.virtual_ref_dset_size") - + logger.warning( + "counter_ref_dset_size not equal to self.virtual_ref_dset_size" + ) + raise RuntimeError( + "counter_ref_dset_size not equal to self.virtual_ref_dset_size" + ) for key in self.dict_virtual_dset2each_domain.keys(): - if self.dict_virtual_dset2each_domain[key]['label'].shape[0] != self.num_domains_tr: - raise RuntimeError("self.dict_virtual_dset2each_domain, \ + if ( + self.dict_virtual_dset2each_domain[key]["label"].shape[0] + != self.num_domains_tr + ): + raise RuntimeError( + "self.dict_virtual_dset2each_domain, \ one key correspond to value tensor not \ - equal to number of training domains") + equal to number of training domains" + ) # Sanity Check: Ensure paired points have the same class label wrong_case = 0 for key in self.dict_virtual_dset2each_domain.keys(): - for d_i in range(self.dict_virtual_dset2each_domain[key]['label'].shape[0]): - for d_j in range(self.dict_virtual_dset2each_domain[key]['label'].shape[0]): + for d_i in range(self.dict_virtual_dset2each_domain[key]["label"].shape[0]): + for d_j in range( + self.dict_virtual_dset2each_domain[key]["label"].shape[0] + ): if d_j > d_i: - if self.dict_virtual_dset2each_domain[key]['label'][d_i] != self.dict_virtual_dset2each_domain[key]['label'][d_j]: + if ( + self.dict_virtual_dset2each_domain[key]["label"][d_i] + != self.dict_virtual_dset2each_domain[key]["label"][d_j] + ): # raise RuntimeError("the reference domain has 'rows' with inconsistent class labels") wrong_case += 1 logger = Logger.get_logger() - logger.info(f'Total Label MisMatch across pairs: {wrong_case}') + logger.info(f"Total Label MisMatch across pairs: {wrong_case}") if wrong_case != 0: - raise RuntimeError("the reference domain " - "has 'rows' with inconsistent class labels") + raise RuntimeError( + "the reference domain has 'rows' with inconsistent 
class labels" + ) list_ref_domain_each_domain = [] list_ref_domain_each_domain_label = [] for ind_ref_domain_key in self.dict_virtual_dset2each_domain.keys(): - list_ref_domain_each_domain.append(self.dict_virtual_dset2each_domain[ind_ref_domain_key]['data']) - list_ref_domain_each_domain_label.append(self.dict_virtual_dset2each_domain[ind_ref_domain_key]['label']) + list_ref_domain_each_domain.append( + self.dict_virtual_dset2each_domain[ind_ref_domain_key]["data"] + ) + list_ref_domain_each_domain_label.append( + self.dict_virtual_dset2each_domain[ind_ref_domain_key]["label"] + ) tensor_ref_domain_each_domain_x = torch.stack(list_ref_domain_each_domain) - tensor_ref_domain_each_domain_label = torch.stack(list_ref_domain_each_domain_label) - - logger.info(f"{tensor_ref_domain_each_domain_x.shape} " - f"{tensor_ref_domain_each_domain_label.shape}") + tensor_ref_domain_each_domain_label = torch.stack( + list_ref_domain_each_domain_label + ) + + logger.info( + f"{tensor_ref_domain_each_domain_x.shape} " + f"{tensor_ref_domain_each_domain_label.shape}" + ) del self.dict_domain_data del self.dict_virtual_dset2each_domain diff --git a/domainlab/algos/trainers/compos/matchdg_utils.py b/domainlab/algos/trainers/compos/matchdg_utils.py index 04061e68c..62785e6a6 100644 --- a/domainlab/algos/trainers/compos/matchdg_utils.py +++ b/domainlab/algos/trainers/compos/matchdg_utils.py @@ -2,13 +2,15 @@ create dictionary for matching """ import torch + from domainlab.utils.logger import Logger -class MatchDictInit(): +class MatchDictInit: """ base class for matching dictionary creator """ + def __init__(self, keys, vals, i_c, i_h, i_w): self.keys = keys self.vals = vals @@ -24,41 +26,55 @@ def __call__(self): for key in self.keys: dict_data[key] = {} num_rows = self.get_num_rows(key) - dict_data[key]['data'] = torch.rand((num_rows, self.i_c, self.i_w, self.i_h)) + dict_data[key]["data"] = torch.rand( + (num_rows, self.i_c, self.i_w, self.i_h) + ) # @FIXME: some labels won't be 
filled at all, when using training loader since the incomplete batch is dropped - dict_data[key]['label'] = torch.rand((num_rows, 1)) # scalar label - dict_data[key]['idx'] = torch.randint(low=0, high=1, size=(num_rows, 1)) + dict_data[key]["label"] = torch.rand((num_rows, 1)) # scalar label + dict_data[key]["idx"] = torch.randint(low=0, high=1, size=(num_rows, 1)) return dict_data -class MatchDictVirtualRefDset2EachDomain(MatchDictInit): +class MatchDictVirtualRefDset2EachDomain(MatchDictInit): """ dict[0:virtual_ref_dset_size] has tensor dimension: (num_domains_tr, i_c, i_h, i_w) """ + def __init__(self, virtual_ref_dset_size, num_domains_tr, i_c, i_h, i_w): """ virtual_ref_dset_size is a virtual dataset, len(virtual_ref_dset_size) = sum of all popular domains """ - super().__init__(keys=range(virtual_ref_dset_size), vals=num_domains_tr, - i_c=i_c, i_h=i_h, i_w=i_w) + super().__init__( + keys=range(virtual_ref_dset_size), + vals=num_domains_tr, + i_c=i_c, + i_h=i_h, + i_w=i_w, + ) def get_num_rows(self, key=None): """ key is 0:virtual_ref_dset_size """ - return self.vals # total_domains + return self.vals # total_domains class MatchDictNumDomain2SizeDomain(MatchDictInit): """ tensor dimension for the kth domain: [num_domains_tr, (size_domain_k, i_c, i_h, i_w)] """ + def __init__(self, num_domains_tr, list_tr_domain_size, i_c, i_h, i_w): - super().__init__(keys=range(num_domains_tr), vals=list_tr_domain_size, - i_c=i_c, i_h=i_h, i_w=i_w) + super().__init__( + keys=range(num_domains_tr), + vals=list_tr_domain_size, + i_c=i_c, + i_h=i_h, + i_w=i_w, + ) def get_num_rows(self, key): - return self.vals[key] # list_tr_domain_size[domain_index] + return self.vals[key] # list_tr_domain_size[domain_index] def dist_cosine_agg(x1, x2): @@ -68,10 +84,15 @@ def dist_cosine_agg(x1, x2): fun_cos = torch.nn.CosineSimilarity(dim=1, eps=1e-08) return 1.0 - fun_cos(x1, x2) + def fun_tensor_normalize(tensor_batch_x): eps = 1e-8 - batch_norm_x = tensor_batch_x.norm(dim=1) # Frobenius 
norm or Euclidean Norm long the embedding direction, len(norm) should be batch_size - batch_norm_x = batch_norm_x.view(batch_norm_x.shape[0], 1) # add dimension to tensor + batch_norm_x = tensor_batch_x.norm( + dim=1 + ) # Frobenius norm or Euclidean Norm long the embedding direction, len(norm) should be batch_size + batch_norm_x = batch_norm_x.view( + batch_norm_x.shape[0], 1 + ) # add dimension to tensor tensor_eps = eps * torch.ones_like(batch_norm_x) tensor_batch_x = tensor_batch_x / torch.max(batch_norm_x, tensor_eps) assert not torch.sum(torch.isnan(tensor_batch_x)) @@ -89,12 +110,14 @@ def dist_pairwise_cosine(x1, x2, tau=0.05): x1 = fun_tensor_normalize(x1) x2 = fun_tensor_normalize(x2) - x1_extended_dim = x1.unsqueeze(1) # Returns a new tensor with a dimension of size one inserted at the specified position. + x1_extended_dim = x1.unsqueeze( + 1 + ) # Returns a new tensor with a dimension of size one inserted at the specified position. # extend the order of by insering a new dimension so that cartesion product of pairwise distance can be calculated # since the batch size of x1 and x2 won't be the same, directly calculting elementwise product will cause an error # with order 3 multiply order 2 tensor, the feature dimension will be matched then the rest dimensions form cartesian product - cos_sim = torch.sum(x1_extended_dim*x2, dim=2) # elementwise product + cos_sim = torch.sum(x1_extended_dim * x2, dim=2) # elementwise product cos_sim = cos_sim / tau # make cosine similarity bigger than 1 assert not torch.sum(torch.isnan(cos_sim)) loss = torch.sum(torch.exp(cos_sim), dim=1) @@ -121,7 +144,9 @@ def get_base_domain_size4match_dg(task): ref_domain = domain_key num = task.dict_domain_class_count[domain_key][mclass] logger = Logger.get_logger() - logger.info(f"for class {mclass} bigest sample size is {num} " - f"ref domain is {ref_domain}") + logger.info( + f"for class {mclass} bigest sample size is {num} " + f"ref domain is {ref_domain}" + ) base_domain_size += 
num - return base_domain_size \ No newline at end of file + return base_domain_size diff --git a/domainlab/algos/trainers/hyper_scheduler.py b/domainlab/algos/trainers/hyper_scheduler.py index 5bf13c447..ee61504e1 100644 --- a/domainlab/algos/trainers/hyper_scheduler.py +++ b/domainlab/algos/trainers/hyper_scheduler.py @@ -4,10 +4,11 @@ import numpy as np -class HyperSchedulerWarmupLinear(): +class HyperSchedulerWarmupLinear: """ HyperSchedulerWarmupLinear """ + def __init__(self, trainer, **kwargs): """ kwargs is a dictionary with key the hyper-parameter name and its value @@ -28,7 +29,7 @@ def warmup(self, par_setpoint, epoch): # total_steps :param epoch: """ - ratio = ((epoch+1) * 1.) / self.total_steps + ratio = ((epoch + 1) * 1.0) / self.total_steps list_par = [par_setpoint, par_setpoint * ratio] par = min(list_par) return par @@ -44,15 +45,16 @@ class HyperSchedulerWarmupExponential(HyperSchedulerWarmupLinear): """ HyperScheduler Exponential """ + def warmup(self, par_setpoint, epoch): """ start from a small value of par to ramp up the steady state value using number of total_steps :param epoch: """ - percent_steps = ((epoch+1) * 1.) / self.total_steps - denominator = 1. + np.exp(-10 * percent_steps) - ratio = (2. 
/ denominator - 1) + percent_steps = ((epoch + 1) * 1.0) / self.total_steps + denominator = 1.0 + np.exp(-10 * percent_steps) + ratio = 2.0 / denominator - 1 # percent_steps is 0, denom is 2, 2/denom is 1, ratio is 0 # percent_steps is 1, denom is 1+exp(-10), 2/denom is 2/(1+exp(-10))=2, ratio is 1 # exp(-10)=4.5e-5 is approximately 0 diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 54f0a449a..7e6d7cac7 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -3,11 +3,11 @@ """ import math from operator import add + import torch from domainlab import g_tensor_batch_agg -from domainlab.algos.trainers.a_trainer import AbstractTrainer -from domainlab.algos.trainers.a_trainer import mk_opt +from domainlab.algos.trainers.a_trainer import AbstractTrainer, mk_opt def list_divide(list_val, scalar): @@ -18,6 +18,7 @@ class TrainerBasic(AbstractTrainer): """ basic trainer """ + def before_tr(self): """ check the performance of randomly initialized weight @@ -37,10 +38,10 @@ def before_epoch(self): def tr_epoch(self, epoch): self.before_epoch() - for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in \ - enumerate(self.loader_tr): - self.tr_batch(tensor_x, tensor_y, tensor_d, others, - ind_batch, epoch) + for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( + self.loader_tr + ): + self.tr_batch(tensor_x, tensor_y, tensor_d, others, ind_batch, epoch) return self.after_epoch(epoch) def after_epoch(self, epoch): @@ -49,8 +50,7 @@ def after_epoch(self, epoch): """ self.epo_loss_tr /= self.counter_batch self.epo_task_loss_tr /= self.counter_batch - self.epo_reg_loss_tr = list_divide(self.epo_reg_loss_tr, - self.counter_batch) + self.epo_reg_loss_tr = list_divide(self.epo_reg_loss_tr, self.counter_batch) assert self.epo_loss_tr is not None assert not math.isnan(self.epo_loss_tr) flag_stop = self.observer.update(epoch) # notify observer @@ -63,10 +63,10 @@ def 
log_loss(self, list_b_reg_loss, loss_task, loss): """ self.epo_task_loss_tr += loss_task.sum().detach().item() # - list_b_reg_loss_sumed = [ele.sum().detach().item() - for ele in list_b_reg_loss] - self.epo_reg_loss_tr = list(map(add, self.epo_reg_loss_tr, - list_b_reg_loss_sumed)) + list_b_reg_loss_sumed = [ele.sum().detach().item() for ele in list_b_reg_loss] + self.epo_reg_loss_tr = list( + map(add, self.epo_reg_loss_tr, list_b_reg_loss_sumed) + ) self.epo_loss_tr += loss.detach().item() def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): @@ -74,9 +74,11 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch): optimize neural network one step upon a mini-batch of data """ self.before_batch(epoch, ind_batch) - tensor_x, tensor_y, tensor_d = \ - tensor_x.to(self.device), tensor_y.to(self.device), \ - tensor_d.to(self.device) + tensor_x, tensor_y, tensor_d = ( + tensor_x.to(self.device), + tensor_y.to(self.device), + tensor_d.to(self.device), + ) self.optimizer.zero_grad() loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others) loss.backward() @@ -90,13 +92,17 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): """ loss_task = self.model.cal_task_loss(tensor_x, tensor_y) - list_reg_tr_batch, list_mu_tr = self.cal_reg_loss(tensor_x, tensor_y, - tensor_d, others) + list_reg_tr_batch, list_mu_tr = self.cal_reg_loss( + tensor_x, tensor_y, tensor_d, others + ) tensor_batch_reg_loss_penalized = self.model.list_inner_product( - list_reg_tr_batch, list_mu_tr) + list_reg_tr_batch, list_mu_tr + ) assert len(tensor_batch_reg_loss_penalized.shape) == 1 loss_erm_agg = g_tensor_batch_agg(loss_task) loss_reg_penalized_agg = g_tensor_batch_agg(tensor_batch_reg_loss_penalized) - loss_penalized = self.model.multiplier4task_loss * loss_erm_agg + loss_reg_penalized_agg + loss_penalized = ( + self.model.multiplier4task_loss * loss_erm_agg + loss_reg_penalized_agg + ) self.log_loss(list_reg_tr_batch, loss_task, loss_penalized) 
return loss_penalized diff --git a/domainlab/algos/trainers/train_dial.py b/domainlab/algos/trainers/train_dial.py index 438ac326e..75a5e34f0 100644 --- a/domainlab/algos/trainers/train_dial.py +++ b/domainlab/algos/trainers/train_dial.py @@ -11,6 +11,7 @@ class TrainerDIAL(TrainerBasic): """ Trainer Domain Invariant Adversarial Learning """ + def gen_adversarial(self, device, img_natural, vec_y): """ use naive trimming to find optimize img in the direction of adversarial gradient, @@ -23,7 +24,9 @@ def gen_adversarial(self, device, img_natural, vec_y): step_size = self.aconf.dial_lr epsilon = self.aconf.dial_epsilon img_adv_ini = img_natural.detach() - img_adv_ini = img_adv_ini + scale * torch.randn(img_natural.shape).to(device).detach() + img_adv_ini = ( + img_adv_ini + scale * torch.randn(img_natural.shape).to(device).detach() + ) img_adv = img_adv_ini for _ in range(steps_perturb): img_adv.requires_grad_() @@ -31,7 +34,9 @@ def gen_adversarial(self, device, img_natural, vec_y): grad = torch.autograd.grad(loss_gen_adv, [img_adv])[0] # instead of gradient descent, we gradient ascent here img_adv = img_adv_ini.detach() + step_size * torch.sign(grad.detach()) - img_adv = torch.min(torch.max(img_adv, img_natural - epsilon), img_natural + epsilon) + img_adv = torch.min( + torch.max(img_adv, img_natural - epsilon), img_natural + epsilon + ) img_adv = torch.clamp(img_adv, 0.0, 1.0) return img_adv diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index b9eddec53..2e60bf5e8 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -1,8 +1,8 @@ """ update hyper-parameters during training """ -from domainlab.algos.trainers.train_basic import TrainerBasic from domainlab.algos.trainers.hyper_scheduler import HyperSchedulerWarmupLinear +from domainlab.algos.trainers.train_basic import TrainerBasic from domainlab.utils.logger import Logger @@ -10,9 
+10,10 @@ class TrainerHyperScheduler(TrainerBasic): """ TrainerHyperScheduler """ - def set_scheduler(self, scheduler, total_steps, - flag_update_epoch=False, - flag_update_batch=False): + + def set_scheduler( + self, scheduler, total_steps, flag_update_epoch=False, flag_update_batch=False + ): """ set the warmup strategy from objective scheduler set wheter the hyper-parameter scheduling happens per epoch or per batch @@ -36,17 +37,23 @@ def before_batch(self, epoch, ind_batch): should be set to epoch*self.num_batches + ind_batch """ if self.flag_update_hyper_per_batch: - self.model.hyper_update(epoch*self.num_batches + ind_batch, self.hyper_scheduler) + self.model.hyper_update( + epoch * self.num_batches + ind_batch, self.hyper_scheduler + ) return super().before_batch(epoch, ind_batch) def before_tr(self): if self.hyper_scheduler is None: logger = Logger.get_logger() - logger.warning("hyper-parameter scheduler not set," - "going to use default Warmpup and epoch update") - self.set_scheduler(HyperSchedulerWarmupLinear, - total_steps=self.aconf.warmup, - flag_update_epoch=True) + logger.warning( + "hyper-parameter scheduler not set," + "going to use default Warmpup and epoch update" + ) + self.set_scheduler( + HyperSchedulerWarmupLinear, + total_steps=self.aconf.warmup, + flag_update_epoch=True, + ) def tr_epoch(self, epoch): """ diff --git a/domainlab/algos/trainers/train_matchdg.py b/domainlab/algos/trainers/train_matchdg.py index 98afc3126..6a3edd996 100644 --- a/domainlab/algos/trainers/train_matchdg.py +++ b/domainlab/algos/trainers/train_matchdg.py @@ -5,24 +5,28 @@ from domainlab import g_inst_component_loss_agg, g_list_loss_agg from domainlab.algos.trainers.a_trainer import AbstractTrainer -from domainlab.algos.trainers.compos.matchdg_utils import \ -get_base_domain_size4match_dg from domainlab.algos.trainers.compos.matchdg_match import MatchPair -from domainlab.algos.trainers.compos.matchdg_utils import (dist_cosine_agg, - dist_pairwise_cosine) -from 
domainlab.utils.logger import Logger +from domainlab.algos.trainers.compos.matchdg_utils import ( + dist_cosine_agg, + dist_pairwise_cosine, + get_base_domain_size4match_dg, +) from domainlab.tasks.utils_task_dset import DsetIndDecorator4XYD +from domainlab.utils.logger import Logger class TrainerMatchDG(AbstractTrainer): """ Contrastive Learning """ + def dset_decoration_args_algo(self, args, ddset): ddset = DsetIndDecorator4XYD(ddset) return ddset - def init_business(self, model, task, observer, device, aconf, flag_accept=True, flag_erm=False): + def init_business( + self, model, task, observer, device, aconf, flag_accept=True, flag_erm=False + ): """ initialize member objects """ @@ -55,16 +59,20 @@ def tr_epoch(self, epoch): # self.tensor_ref_domain2each_domain_x[inds_shuffle] # shuffles the match tensor at the first dimension self.tuple_tensor_refdomain2each = torch.split( - self.tensor_ref_domain2each_domain_x[inds_shuffle], - self.aconf.bs, dim=0) + self.tensor_ref_domain2each_domain_x[inds_shuffle], self.aconf.bs, dim=0 + ) # Splits the tensor into chunks. 
# Each chunk is a view of the original tensor of batch size self.aconf.bs # return is a tuple of the splited chunks self.tuple_tensor_ref_domain2each_y = torch.split( - self.tensor_ref_domain2each_domain_y[inds_shuffle], - self.aconf.bs, dim=0) - logger.info(f"number of batches in match tensor: {len(self.tuple_tensor_refdomain2each)}") - logger.info(f"single batch match tensor size: {self.tuple_tensor_refdomain2each[0].shape}") + self.tensor_ref_domain2each_domain_y[inds_shuffle], self.aconf.bs, dim=0 + ) + logger.info( + f"number of batches in match tensor: {len(self.tuple_tensor_refdomain2each)}" + ) + logger.info( + f"single batch match tensor size: {self.tuple_tensor_refdomain2each[0].shape}" + ) for batch_idx, (x_e, y_e, d_e, *others) in enumerate(self.loader_tr): # random loader with same batch size as the match tensor loader @@ -72,8 +80,10 @@ def tr_epoch(self, epoch): # is only used for creating the match tensor self.tr_batch(epoch, batch_idx, x_e, y_e, d_e, others) if self.flag_match_tensor_sweep_over is True: - logger.info("ref/base domain vs each domain match \ - traversed one sweep, starting new epoch") + logger.info( + "ref/base domain vs each domain match \ + traversed one sweep, starting new epoch" + ) self.flag_match_tensor_sweep_over = False break @@ -109,11 +119,15 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): if self.flag_erm: # decoratee can be both trainer or model - list_loss_reg_rand, list_mu_reg = self.decoratee.cal_reg_loss(x_e, y_e, d_e, others) + list_loss_reg_rand, list_mu_reg = self.decoratee.cal_reg_loss( + x_e, y_e, d_e, others + ) loss_reg = self.model.list_inner_product(list_loss_reg_rand, list_mu_reg) loss_task_rand = self.model.cal_task_loss(x_e, y_e) # loss_erm_rnd_loader, *_ = self.model.cal_loss(x_e, y_e, d_e, others) - loss_erm_rnd_loader = loss_reg + loss_task_rand * self.model.multiplier4task_loss + loss_erm_rnd_loader = ( + loss_reg + loss_task_rand * self.model.multiplier4task_loss + ) 
num_batches_match_tensor = len(self.tuple_tensor_refdomain2each) @@ -123,17 +137,22 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): curr_batch_size = self.tuple_tensor_refdomain2each[batch_idx].shape[0] - batch_tensor_ref_domain2each = self.tuple_tensor_refdomain2each[batch_idx].to(self.device) + batch_tensor_ref_domain2each = self.tuple_tensor_refdomain2each[batch_idx].to( + self.device + ) # make order 5 tensor: (ref_domain, domain, channel, img_h, img_w) # with first dimension as batch size # clamp the first two dimensions so the model network could map image to feature - batch_tensor_ref_domain2each = match_tensor_reshape(batch_tensor_ref_domain2each) + batch_tensor_ref_domain2each = match_tensor_reshape( + batch_tensor_ref_domain2each + ) # now batch_tensor_ref_domain2each first dim will not be batch_size! # batch_tensor_ref_domain2each.shape torch.Size([40, channel, 224, 224]) batch_feat_ref_domain2each = self.model.extract_semantic_feat( - batch_tensor_ref_domain2each) + batch_tensor_ref_domain2each + ) # batch_feat_ref_domain2each.shape torch.Size[40, 512] # torch.sum(torch.isnan(batch_tensor_ref_domain2each)) # assert not torch.sum(torch.isnan(batch_feat_ref_domain2each)) @@ -141,22 +160,28 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): if flag_isnan: logger = Logger.get_logger() logger.info(batch_tensor_ref_domain2each) - raise RuntimeError("batch_feat_ref_domain2each NAN! is learning rate too big or" - "hyper-parameter tau not set appropriately?") + raise RuntimeError( + "batch_feat_ref_domain2each NAN! is learning rate too big or" + "hyper-parameter tau not set appropriately?" 
+ ) # for contrastive training phase, # the last layer of the model is replaced with identity - batch_ref_domain2each_y = self.tuple_tensor_ref_domain2each_y[batch_idx].to(self.device) + batch_ref_domain2each_y = self.tuple_tensor_ref_domain2each_y[batch_idx].to( + self.device + ) batch_ref_domain2each_y = batch_ref_domain2each_y.view( - batch_ref_domain2each_y.shape[0]*batch_ref_domain2each_y.shape[1]) + batch_ref_domain2each_y.shape[0] * batch_ref_domain2each_y.shape[1] + ) if self.flag_erm: # @FIXME: check if batch_ref_domain2each_y is # continuous number which means it is at its initial value, # not yet filled loss_erm_match_tensor, *_ = self.model.cal_task_loss( - batch_tensor_ref_domain2each, batch_ref_domain2each_y.long()) + batch_tensor_ref_domain2each, batch_ref_domain2each_y.long() + ) # Creating tensor of shape (domain size, total domains, feat size ) # The match tensor's first two dimension @@ -172,61 +197,68 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): dim_feat = batch_feat_ref_domain2each.shape[1] num_domain_tr = len(self.task.list_domain_tr) batch_feat_ref_domain2each = batch_feat_ref_domain2each.view( - curr_batch_size, num_domain_tr, dim_feat) + curr_batch_size, num_domain_tr, dim_feat + ) batch_ref_domain2each_y = batch_ref_domain2each_y.view( - curr_batch_size, num_domain_tr) + curr_batch_size, num_domain_tr + ) # The match tensor's first two dimension # [(Ref domain size) * (# train domains)] has been clamped # together to get features extracted through self.model - batch_tensor_ref_domain2each = \ - batch_tensor_ref_domain2each.view(curr_batch_size, - num_domain_tr, - batch_tensor_ref_domain2each.shape[1], # channel - batch_tensor_ref_domain2each.shape[2], # img_h - batch_tensor_ref_domain2each.shape[3]) # img_w + batch_tensor_ref_domain2each = batch_tensor_ref_domain2each.view( + curr_batch_size, + num_domain_tr, + batch_tensor_ref_domain2each.shape[1], # channel + batch_tensor_ref_domain2each.shape[2], # img_h + 
batch_tensor_ref_domain2each.shape[3], + ) # img_w # Contrastive Loss: class \times domain \times domain counter_same_cls_diff_domain = 1 logger = Logger.get_logger() for y_c in range(self.task.dim_y): - - subset_same_cls = (batch_ref_domain2each_y[:, 0] == y_c) - subset_diff_cls = (batch_ref_domain2each_y[:, 0] != y_c) + subset_same_cls = batch_ref_domain2each_y[:, 0] == y_c + subset_diff_cls = batch_ref_domain2each_y[:, 0] != y_c feat_same_cls = batch_feat_ref_domain2each[subset_same_cls] feat_diff_cls = batch_feat_ref_domain2each[subset_diff_cls] - logger.debug(f'class {y_c} with same class and different class: ' + - f'{feat_same_cls.shape[0]} {feat_diff_cls.shape[0]}') + logger.debug( + f"class {y_c} with same class and different class: " + + f"{feat_same_cls.shape[0]} {feat_diff_cls.shape[0]}" + ) if feat_same_cls.shape[0] == 0 or feat_diff_cls.shape[0] == 0: - logger.debug(f"no instances of label {y_c}" - f"in the current batch, continue") + logger.debug( + f"no instances of label {y_c}" f"in the current batch, continue" + ) continue if torch.sum(torch.isnan(feat_diff_cls)): - raise RuntimeError('feat_diff_cls has nan entrie(s)') + raise RuntimeError("feat_diff_cls has nan entrie(s)") feat_diff_cls = feat_diff_cls.view( - feat_diff_cls.shape[0]*feat_diff_cls.shape[1], - feat_diff_cls.shape[2]) + feat_diff_cls.shape[0] * feat_diff_cls.shape[1], feat_diff_cls.shape[2] + ) for d_i in range(feat_same_cls.shape[1]): dist_diff_cls_same_domain = dist_pairwise_cosine( - feat_same_cls[:, d_i, :], feat_diff_cls[:, :]) + feat_same_cls[:, d_i, :], feat_diff_cls[:, :] + ) if torch.sum(torch.isnan(dist_diff_cls_same_domain)): - raise RuntimeError('dist_diff_cls_same_domain NAN') + raise RuntimeError("dist_diff_cls_same_domain NAN") # iterate other domains for d_j in range(feat_same_cls.shape[1]): if d_i >= d_j: continue - dist_same_cls_diff_domain = dist_cosine_agg(feat_same_cls[:, d_i, :], - feat_same_cls[:, d_j, :]) + dist_same_cls_diff_domain = dist_cosine_agg( + 
feat_same_cls[:, d_i, :], feat_same_cls[:, d_j, :] + ) if torch.sum(torch.isnan(dist_same_cls_diff_domain)): - raise RuntimeError('dist_same_cls_diff_domain NAN') + raise RuntimeError("dist_same_cls_diff_domain NAN") # CTR (contrastive) loss is exclusive for # CTR phase and ERM phase @@ -235,12 +267,16 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): list_batch_loss_ctr.append(torch.sum(dist_same_cls_diff_domain)) else: i_dist_same_cls_diff_domain = 1.0 - dist_same_cls_diff_domain - i_dist_same_cls_diff_domain = \ + i_dist_same_cls_diff_domain = ( i_dist_same_cls_diff_domain / self.aconf.tau - partition = torch.log(torch.exp(i_dist_same_cls_diff_domain) + - dist_diff_cls_same_domain) + ) + partition = torch.log( + torch.exp(i_dist_same_cls_diff_domain) + + dist_diff_cls_same_domain + ) list_batch_loss_ctr.append( - -1 * torch.sum(i_dist_same_cls_diff_domain - partition)) + -1 * torch.sum(i_dist_same_cls_diff_domain - partition) + ) counter_same_cls_diff_domain += dist_same_cls_diff_domain.shape[0] @@ -250,7 +286,7 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): epos = self.aconf.epos else: epos = self.aconf.epochs_ctr - percentage_finished_epochs = (epoch + 1)/(epos + 1) + percentage_finished_epochs = (epoch + 1) / (epos + 1) # loss aggregation is over different domain # combinations of the same batch # https://discuss.pytorch.org/t/leaf-variable-was-used-in-an-inplace-operation/308 @@ -264,13 +300,18 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None): # erm loss comes from two different data loaders, # one is rnd (random) data loader # the other one is the data loader from the match tensor - loss_e = torch.tensor(0.0, requires_grad=True) + \ - g_inst_component_loss_agg(loss_erm_rnd_loader) + \ - g_inst_component_loss_agg(loss_erm_match_tensor) * self.model.multiplier4task_loss + \ - self.lambda_ctr * percentage_finished_epochs * loss_ctr + loss_e = ( + torch.tensor(0.0, requires_grad=True) + + 
g_inst_component_loss_agg(loss_erm_rnd_loader) + + g_inst_component_loss_agg(loss_erm_match_tensor) + * self.model.multiplier4task_loss + + self.lambda_ctr * percentage_finished_epochs * loss_ctr + ) else: - loss_e = torch.tensor(0.0, requires_grad=True) + \ - self.lambda_ctr * percentage_finished_epochs * loss_ctr + loss_e = ( + torch.tensor(0.0, requires_grad=True) + + self.lambda_ctr * percentage_finished_epochs * loss_ctr + ) # @FIXME: without torch.tensor(0.0), after a few epochs, # error "'float' object has no attribute 'backward'" @@ -284,22 +325,27 @@ def mk_match_tensor(self, epoch): """ initialize or update match tensor """ - obj_match = MatchPair(self.task.dim_y, - self.task.isize.i_c, - self.task.isize.i_h, - self.task.isize.i_w, - self.aconf.bs, - virtual_ref_dset_size=self.base_domain_size, - num_domains_tr=len(self.task.list_domain_tr), - list_tr_domain_size=self.list_tr_domain_size) + obj_match = MatchPair( + self.task.dim_y, + self.task.isize.i_c, + self.task.isize.i_h, + self.task.isize.i_w, + self.aconf.bs, + virtual_ref_dset_size=self.base_domain_size, + num_domains_tr=len(self.task.list_domain_tr), + list_tr_domain_size=self.list_tr_domain_size, + ) # @FIXME: what is the usefulness of (epoch > 0) as argument - self.tensor_ref_domain2each_domain_x, self.tensor_ref_domain2each_domain_y = \ - obj_match( + ( + self.tensor_ref_domain2each_domain_x, + self.tensor_ref_domain2each_domain_y, + ) = obj_match( self.device, self.task.loader_tr, self.model.extract_semantic_feat, - (epoch > 0)) + (epoch > 0), + ) def before_tr(self): """ @@ -311,14 +357,15 @@ def before_tr(self): # different than phase 2, ctr_model has no classification loss -def match_tensor_reshape(batch_tensor_ref_domain2each): +def match_tensor_reshape(batch_tensor_ref_domain2each): """ # original dimension is (ref_domain, domain, (channel, img_h, img_w)) # use a function so it is easier to accomodate other data mode (not image) """ batch_tensor_refdomain_other_domain_chw = 
batch_tensor_ref_domain2each.view( - batch_tensor_ref_domain2each.shape[0]*batch_tensor_ref_domain2each.shape[1], - batch_tensor_ref_domain2each.shape[2], # channel - batch_tensor_ref_domain2each.shape[3], # img_h - batch_tensor_ref_domain2each.shape[4]) # img_w + batch_tensor_ref_domain2each.shape[0] * batch_tensor_ref_domain2each.shape[1], + batch_tensor_ref_domain2each.shape[2], # channel + batch_tensor_ref_domain2each.shape[3], # img_h + batch_tensor_ref_domain2each.shape[4], + ) # img_w return batch_tensor_refdomain_other_domain_chw diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py index 178561c94..90318286c 100644 --- a/domainlab/algos/trainers/train_mldg.py +++ b/domainlab/algos/trainers/train_mldg.py @@ -3,6 +3,7 @@ """ import copy import random + from torch.utils.data.dataset import ConcatDataset from domainlab.algos.trainers.a_trainer import AbstractTrainer @@ -15,6 +16,7 @@ class TrainerMLDG(AbstractTrainer): """ basic trainer """ + def before_tr(self): """ check the performance of randomly initialized weight @@ -24,8 +26,13 @@ def before_tr(self): self.inner_trainer.extend(self._decoratee) inner_model = copy.deepcopy(self.model) self.inner_trainer.init_business( - inner_model, copy.deepcopy(self.task), self.observer, self.device, self.aconf, - flag_accept=False) + inner_model, + copy.deepcopy(self.task), + self.observer, + self.device, + self.aconf, + flag_accept=False, + ) self.prepare_ziped_loader() def prepare_ziped_loader(self): @@ -36,7 +43,8 @@ def prepare_ziped_loader(self): num_domains = len(list_dsets) ind_target_domain = random.randrange(num_domains) tuple_dsets_source = tuple( - list_dsets[ind] for ind in range(num_domains) if ind != ind_target_domain) + list_dsets[ind] for ind in range(num_domains) if ind != ind_target_domain + ) ddset_source = ConcatDataset(tuple_dsets_source) ddset_target = list_dsets[ind_target_domain] ddset_mix = DsetZip(ddset_source, ddset_target) @@ -47,15 +55,27 @@ def 
tr_epoch(self, epoch): self.epo_loss_tr = 0 self.prepare_ziped_loader() # s means source, t means target - for ind_batch, (tensor_x_s, vec_y_s, vec_d_s, others_s, - tensor_x_t, vec_y_t, vec_d_t, *_) \ - in enumerate(self.loader_tr_source_target): - - tensor_x_s, vec_y_s, vec_d_s = \ - tensor_x_s.to(self.device), vec_y_s.to(self.device), vec_d_s.to(self.device) + for ind_batch, ( + tensor_x_s, + vec_y_s, + vec_d_s, + others_s, + tensor_x_t, + vec_y_t, + vec_d_t, + *_, + ) in enumerate(self.loader_tr_source_target): + tensor_x_s, vec_y_s, vec_d_s = ( + tensor_x_s.to(self.device), + vec_y_s.to(self.device), + vec_d_s.to(self.device), + ) - tensor_x_t, vec_y_t, vec_d_t = \ - tensor_x_t.to(self.device), vec_y_t.to(self.device), vec_d_t.to(self.device) + tensor_x_t, vec_y_t, vec_d_t = ( + tensor_x_t.to(self.device), + vec_y_t.to(self.device), + vec_d_t.to(self.device), + ) self.optimizer.zero_grad() @@ -64,22 +84,32 @@ def tr_epoch(self, epoch): self.inner_trainer.before_epoch() # set model to train mode self.inner_trainer.reset() # force optimizer to re-initialize self.inner_trainer.tr_batch( - tensor_x_s, vec_y_s, vec_d_s, others_s, ind_batch, epoch) + tensor_x_s, vec_y_s, vec_d_s, others_s, ind_batch, epoch + ) # inner_model has now accumulated gradients Gi # with parameters theta_i - lr * G_i where i index batch - loss_look_forward = self.inner_trainer.model.cal_task_loss(tensor_x_t, vec_y_t) + loss_look_forward = self.inner_trainer.model.cal_task_loss( + tensor_x_t, vec_y_t + ) loss_source_task = self.model.cal_task_loss(tensor_x_s, vec_y_s) - list_source_reg_tr, list_source_mu_tr = self.cal_reg_loss(tensor_x_s, vec_y_s, vec_d_s, others_s) + list_source_reg_tr, list_source_mu_tr = self.cal_reg_loss( + tensor_x_s, vec_y_s, vec_d_s, others_s + ) # call cal_reg_loss from decoratee # super()._cal_reg_loss returns [],[], # since mldg's reg loss is on target domain, # no other trainer except hyperscheduler could decorate it unless we use state pattern # in the future to 
control source and target domain loader behavior - source_reg_tr = self.model.list_inner_product(list_source_reg_tr, list_source_mu_tr) + source_reg_tr = self.model.list_inner_product( + list_source_reg_tr, list_source_mu_tr + ) # self.aconf.gamma_reg * loss_look_forward.sum() - loss = loss_source_task.sum() + source_reg_tr.sum() +\ - self.aconf.gamma_reg * loss_look_forward.sum() + loss = ( + loss_source_task.sum() + + source_reg_tr.sum() + + self.aconf.gamma_reg * loss_look_forward.sum() + ) # loss.backward() # optimizer only optimize parameters of self.model, not inner_model diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py index abeae9520..980e9148c 100644 --- a/domainlab/algos/trainers/zoo_trainer.py +++ b/domainlab/algos/trainers/zoo_trainer.py @@ -3,9 +3,9 @@ """ from domainlab.algos.trainers.train_basic import TrainerBasic from domainlab.algos.trainers.train_dial import TrainerDIAL +from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler from domainlab.algos.trainers.train_matchdg import TrainerMatchDG from domainlab.algos.trainers.train_mldg import TrainerMLDG -from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler class TrainerChainNodeGetter(object): @@ -13,13 +13,14 @@ class TrainerChainNodeGetter(object): Chain of Responsibility: node is named in pattern Trainer[XXX] where the string after 'Trainer' is the name to be passed to args.trainer. """ + def __init__(self, str_trainer): """__init__. :param args: command line arguments """ self._list_str_trainer = None if str_trainer is not None: - self._list_str_trainer = str_trainer.split('_') + self._list_str_trainer = str_trainer.split("_") self.request = self._list_str_trainer.pop(0) else: self.request = str_trainer @@ -31,12 +32,16 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None): 2. 
hard code seems to be the best solution """ if lst_candidates is not None and self.request not in lst_candidates: - raise RuntimeError(f"desired {self.request} is not supported \ - among {lst_candidates}") + raise RuntimeError( + f"desired {self.request} is not supported \ + among {lst_candidates}" + ) if default is not None and self.request is None: self.request = default if lst_excludes is not None and self.request in lst_excludes: - raise RuntimeError(f"desired {self.request} is not supported among {lst_excludes}") + raise RuntimeError( + f"desired {self.request} is not supported among {lst_excludes}" + ) chain = TrainerBasic(None) chain = TrainerDIAL(chain) diff --git a/domainlab/algos/zoo_algos.py b/domainlab/algos/zoo_algos.py index c2a1bfe2a..e067c1326 100644 --- a/domainlab/algos/zoo_algos.py +++ b/domainlab/algos/zoo_algos.py @@ -1,30 +1,28 @@ """ chain of responsibility pattern for algorithm selection """ +from domainlab.algos.builder_api_model import NodeAlgoBuilderAPIModel from domainlab.algos.builder_dann import NodeAlgoBuilderDANN -from domainlab.algos.builder_jigen1 import NodeAlgoBuilderJiGen -from domainlab.algos.builder_erm import NodeAlgoBuilderERM from domainlab.algos.builder_diva import NodeAlgoBuilderDIVA +from domainlab.algos.builder_erm import NodeAlgoBuilderERM from domainlab.algos.builder_hduva import NodeAlgoBuilderHDUVA -from domainlab.algos.builder_api_model import NodeAlgoBuilderAPIModel - +from domainlab.algos.builder_jigen1 import NodeAlgoBuilderJiGen from domainlab.utils.u_import import import_path -class AlgoBuilderChainNodeGetter(): +class AlgoBuilderChainNodeGetter: """ 1. Hardcoded chain 3. 
Return selected node """ + def __init__(self, model, apath): self.model = model self.apath = apath - # - self._list_str_model = model.split('_') + # + self._list_str_model = model.split("_") self.model = self._list_str_model.pop(0) - - def register_external_node(self, chain): """ if the user specify an external python file to implement the algorithm diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 69b21c087..f8f78c28b 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -6,8 +6,8 @@ import yaml -from domainlab.algos.trainers.compos.matchdg_args import add_args2parser_matchdg from domainlab.algos.trainers.args_dial import add_args2parser_dial +from domainlab.algos.trainers.compos.matchdg_args import add_args2parser_matchdg from domainlab.models.args_jigen import add_args2parser_jigen from domainlab.models.args_vae import add_args2parser_vae from domainlab.utils.logger import Logger @@ -17,185 +17,292 @@ def mk_parser_main(): """ Args for command line definition """ - parser = argparse.ArgumentParser(description='DomainLab') - - parser.add_argument('-c', "--config", default=None, - help="load YAML configuration", dest="config_file", - type=argparse.FileType(mode='r')) - - parser.add_argument('--lr', type=float, default=1e-4, - help='learning rate') - - parser.add_argument('--gamma_reg', type=float, default=0.1, - help='weight of regularization loss') - - parser.add_argument('--es', type=int, default=1, - help='early stop steps') - - parser.add_argument('--seed', type=int, default=0, - help='random seed (default: 0)') - - parser.add_argument('--nocu', action='store_true', default=False, - help='disables CUDA') - - parser.add_argument('--device', type=str, default=None, - help='device name default None') - - parser.add_argument('--gen', action='store_true', default=False, - help='save generated images') - - parser.add_argument('--keep_model', action='store_true', default=False, - help='do not delete model at the end of training') - 
- parser.add_argument('--epos', default=2, type=int, - help='maximum number of epochs') - - parser.add_argument('--epos_min', default=0, type=int, - help='maximum number of epochs') - - parser.add_argument('--epo_te', default=1, type=int, - help='test performance per {} epochs') - - parser.add_argument('-w', '--warmup', type=int, default=100, - help='number of epochs for hyper-parameter warm-up. \ - Set to 0 to turn warmup off.') - - parser.add_argument('--debug', action='store_true', default=False) - parser.add_argument('--dmem', action='store_true', default=False) - parser.add_argument('--no_dump', action='store_true', default=False, - help='suppress saving the confusion matrix') - - parser.add_argument('--trainer', type=str, default=None, - help='specify which trainer to use') - - parser.add_argument('--out', type=str, default="zoutput", - help='absolute directory to store outputs') - - parser.add_argument('--dpath', type=str, default="zdpath", - help="path for storing downloaded dataset") - - parser.add_argument('--tpath', type=str, default=None, - help="path for custom task, should implement \ - get_task function") - - parser.add_argument('--npath', type=str, default=None, - help="path of custom neural network for feature \ - extraction") - - parser.add_argument('--npath_dom', type=str, default=None, - help="path of custom neural network for feature \ - extraction") - - parser.add_argument('--npath_argna2val', action='append', - help="specify new arguments and their value \ + parser = argparse.ArgumentParser(description="DomainLab") + + parser.add_argument( + "-c", + "--config", + default=None, + help="load YAML configuration", + dest="config_file", + type=argparse.FileType(mode="r"), + ) + + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") + + parser.add_argument( + "--gamma_reg", type=float, default=0.1, help="weight of regularization loss" + ) + + parser.add_argument("--es", type=int, default=1, help="early stop steps") + + 
parser.add_argument("--seed", type=int, default=0, help="random seed (default: 0)") + + parser.add_argument( + "--nocu", action="store_true", default=False, help="disables CUDA" + ) + + parser.add_argument( + "--device", type=str, default=None, help="device name default None" + ) + + parser.add_argument( + "--gen", action="store_true", default=False, help="save generated images" + ) + + parser.add_argument( + "--keep_model", + action="store_true", + default=False, + help="do not delete model at the end of training", + ) + + parser.add_argument("--epos", default=2, type=int, help="maximum number of epochs") + + parser.add_argument( + "--epos_min", default=0, type=int, help="maximum number of epochs" + ) + + parser.add_argument( + "--epo_te", default=1, type=int, help="test performance per {} epochs" + ) + + parser.add_argument( + "-w", + "--warmup", + type=int, + default=100, + help="number of epochs for hyper-parameter warm-up. \ + Set to 0 to turn warmup off.", + ) + + parser.add_argument("--debug", action="store_true", default=False) + parser.add_argument("--dmem", action="store_true", default=False) + parser.add_argument( + "--no_dump", + action="store_true", + default=False, + help="suppress saving the confusion matrix", + ) + + parser.add_argument( + "--trainer", type=str, default=None, help="specify which trainer to use" + ) + + parser.add_argument( + "--out", type=str, default="zoutput", help="absolute directory to store outputs" + ) + + parser.add_argument( + "--dpath", + type=str, + default="zdpath", + help="path for storing downloaded dataset", + ) + + parser.add_argument( + "--tpath", + type=str, + default=None, + help="path for custom task, should implement \ + get_task function", + ) + + parser.add_argument( + "--npath", + type=str, + default=None, + help="path of custom neural network for feature \ + extraction", + ) + + parser.add_argument( + "--npath_dom", + type=str, + default=None, + help="path of custom neural network for feature \ + extraction", 
+ ) + + parser.add_argument( + "--npath_argna2val", + action="append", + help="specify new arguments and their value \ e.g. '--npath_argna2val my_custom_arg_na \ --npath_argna2val xx/yy/zz.py', additional \ - pairs can be appended") + pairs can be appended", + ) - parser.add_argument('--nname_argna2val', action='append', - help="specify new arguments and their values \ + parser.add_argument( + "--nname_argna2val", + action="append", + help="specify new arguments and their values \ e.g. '--nname_argna2val my_custom_network_arg_na \ --nname_argna2val alexnet', additional pairs \ - can be appended") - - parser.add_argument('--nname', type=str, default=None, - help="name of custom neural network for feature \ - extraction of classification") - - parser.add_argument('--nname_dom', type=str, default=None, - help="name of custom neural network for feature \ - extraction of domain") - - parser.add_argument('--apath', type=str, default=None, - help="path for custom AlgorithmBuilder") - - parser.add_argument('--exptag', type=str, default="exptag", - help='tag as prefix of result aggregation file name \ - e.g. 
git hash for reproducibility') - - parser.add_argument('--aggtag', type=str, default="aggtag", - help='tag in each line of result aggregation file \ - e.g., to specify potential different configurations') - - parser.add_argument('--agg_partial_bm', type=str, - default=None, dest="bm_dir", - help="Aggregates and plots partial data of a snakemake \ + can be appended", + ) + + parser.add_argument( + "--nname", + type=str, + default=None, + help="name of custom neural network for feature \ + extraction of classification", + ) + + parser.add_argument( + "--nname_dom", + type=str, + default=None, + help="name of custom neural network for feature \ + extraction of domain", + ) + + parser.add_argument( + "--apath", type=str, default=None, help="path for custom AlgorithmBuilder" + ) + + parser.add_argument( + "--exptag", + type=str, + default="exptag", + help="tag as prefix of result aggregation file name \ + e.g. git hash for reproducibility", + ) + + parser.add_argument( + "--aggtag", + type=str, + default="aggtag", + help="tag in each line of result aggregation file \ + e.g., to specify potential different configurations", + ) + + parser.add_argument( + "--agg_partial_bm", + type=str, + default=None, + dest="bm_dir", + help="Aggregates and plots partial data of a snakemake \ benchmark. Requires the benchmark config file. \ - Other arguments will be ignored.") - - parser.add_argument('--gen_plots', type=str, - default=None, dest="plot_data", - help="plots the data of a snakemake benchmark. " - "Requires the results.csv file" - "and an output file (specify by --outp_file," - "default is zoutput/benchmarks/shell_benchmark). " - "Other arguments will be ignored.") - - parser.add_argument('--outp_dir', type=str, - default='zoutput/benchmarks/shell_benchmark', dest="outp_dir", - help="outpus file for the plots when creating them" - "using --gen_plots. 
" - "Default is zoutput/benchmarks/shell_benchmark") - - parser.add_argument('--param_idx', type=bool, - default=True, dest="param_idx", - help="True: parameter index is used in the " - "pots generated with --gen_plots." - "False: parameter name is used." - "Default is True.") - - parser.add_argument('--msel', choices=['val', 'loss_tr'], default="val", - help='model selection for early stop: val, loss_tr, recon, the \ + Other arguments will be ignored.", + ) + + parser.add_argument( + "--gen_plots", + type=str, + default=None, + dest="plot_data", + help="plots the data of a snakemake benchmark. " + "Requires the results.csv file" + "and an output file (specify by --outp_file," + "default is zoutput/benchmarks/shell_benchmark). " + "Other arguments will be ignored.", + ) + + parser.add_argument( + "--outp_dir", + type=str, + default="zoutput/benchmarks/shell_benchmark", + dest="outp_dir", + help="outpus file for the plots when creating them" + "using --gen_plots. " + "Default is zoutput/benchmarks/shell_benchmark", + ) + + parser.add_argument( + "--param_idx", + type=bool, + default=True, + dest="param_idx", + help="True: parameter index is used in the " + "pots generated with --gen_plots." + "False: parameter name is used." 
+ "Default is True.", + ) + + parser.add_argument( + "--msel", + choices=["val", "loss_tr"], + default="val", + help="model selection for early stop: val, loss_tr, recon, the \ elbo and recon only make sense for vae models,\ - will be ignored by other methods') - - parser.add_argument('--model', metavar="an", type=str, - default=None, - help='algorithm name') - - parser.add_argument('--acon', metavar="ac", type=str, default=None, - help='algorithm configuration name, (default None)') - - parser.add_argument('--task', metavar="ta", type=str, - help='task name') - - arg_group_task = parser.add_argument_group('task args') - - arg_group_task.add_argument('--bs', type=int, default=100, - help='loader batch size for mixed domains') - - arg_group_task.add_argument('--split', type=float, default=0, - help='proportion of training, a value between \ - 0 and 1, 0 means no train-validation split') - - arg_group_task.add_argument('--te_d', nargs='*', default=None, - help='test domain names separated by single space, \ - will be parsed to be list of strings') - - arg_group_task.add_argument('--tr_d', nargs='*', default=None, - help='training domain names separated by \ + will be ignored by other methods", + ) + + parser.add_argument( + "--model", metavar="an", type=str, default=None, help="algorithm name" + ) + + parser.add_argument( + "--acon", + metavar="ac", + type=str, + default=None, + help="algorithm configuration name, (default None)", + ) + + parser.add_argument("--task", metavar="ta", type=str, help="task name") + + arg_group_task = parser.add_argument_group("task args") + + arg_group_task.add_argument( + "--bs", type=int, default=100, help="loader batch size for mixed domains" + ) + + arg_group_task.add_argument( + "--split", + type=float, + default=0, + help="proportion of training, a value between \ + 0 and 1, 0 means no train-validation split", + ) + + arg_group_task.add_argument( + "--te_d", + nargs="*", + default=None, + help="test domain names separated by single 
space, \ + will be parsed to be list of strings", + ) + + arg_group_task.add_argument( + "--tr_d", + nargs="*", + default=None, + help="training domain names separated by \ single space, will be parsed to be list of \ strings; if not provided then all available \ domains that are not assigned to \ - the test set will be used as training domains') - - arg_group_task.add_argument('--san_check', action='store_true', default=False, - help='save images from the dataset as a sanity check') - - arg_group_task.add_argument('--san_num', type=int, default=8, - help='number of images to be dumped for the sanity check') - - arg_group_task.add_argument('--loglevel', type=str, default='DEBUG', - help='sets the loglevel of the logger') + the test set will be used as training domains", + ) + + arg_group_task.add_argument( + "--san_check", + action="store_true", + default=False, + help="save images from the dataset as a sanity check", + ) + + arg_group_task.add_argument( + "--san_num", + type=int, + default=8, + help="number of images to be dumped for the sanity check", + ) + + arg_group_task.add_argument( + "--loglevel", type=str, default="DEBUG", help="sets the loglevel of the logger" + ) # args for variational auto encoder - arg_group_vae = parser.add_argument_group('vae') + arg_group_vae = parser.add_argument_group("vae") arg_group_vae = add_args2parser_vae(arg_group_vae) - arg_group_matchdg = parser.add_argument_group('matchdg') + arg_group_matchdg = parser.add_argument_group("matchdg") arg_group_matchdg = add_args2parser_matchdg(arg_group_matchdg) - arg_group_jigen = parser.add_argument_group('jigen') + arg_group_jigen = parser.add_argument_group("jigen") arg_group_jigen = add_args2parser_jigen(arg_group_jigen) - args_group_dial = parser.add_argument_group('dial') + args_group_dial = parser.add_argument_group("dial") args_group_dial = add_args2parser_dial(args_group_dial) return parser @@ -213,9 +320,11 @@ def apply_dict_to_args(args, data: dict, extend=False): cur_val = 
arg_dict.get(key, None) if not isinstance(cur_val, list): if cur_val is not None: - raise RuntimeError(f"input dictionary value is list, \ + raise RuntimeError( + f"input dictionary value is list, \ however, in DomainLab args, we have {cur_val}, \ - going to overrite to list") + going to overrite to list" + ) arg_dict[key] = [] # if args_dict[key] is None, cast it into a list # domainlab will take care of it if this argument can not be a list arg_dict[key].extend(value) # args_dict[key] is already a list @@ -233,10 +342,10 @@ def parse_cmd_args(): """ parser = mk_parser_main() args = parser.parse_args() - logger = Logger.get_logger(logger_name='main_out_logger', loglevel=args.loglevel) + logger = Logger.get_logger(logger_name="main_out_logger", loglevel=args.loglevel) if args.config_file: data = yaml.safe_load(args.config_file) - delattr(args, 'config_file') + delattr(args, "config_file") apply_dict_to_args(args, data) if args.acon is None and args.bm_dir is None: diff --git a/domainlab/compos/a_nn_builder.py b/domainlab/compos/a_nn_builder.py index 59c3589fe..1784c3384 100644 --- a/domainlab/compos/a_nn_builder.py +++ b/domainlab/compos/a_nn_builder.py @@ -13,6 +13,7 @@ class AbstractFeatExtractNNBuilderChainNode(AbstractChainNodeHandler): avoid override the initializer so that node construction is always light weight. """ + def __init__(self, successor_node): """__init__. 
@@ -22,8 +23,16 @@ def __init__(self, successor_node): super().__init__(successor_node) @store_args - def init_business(self, dim_out, args, i_c=None, i_h=None, i_w=None, - flag_pretrain=None, remove_last_layer=False): + def init_business( + self, + dim_out, + args, + i_c=None, + i_h=None, + i_w=None, + flag_pretrain=None, + remove_last_layer=False, + ): """ initialize **and** return the heavy weight business object for doing the real job diff --git a/domainlab/compos/builder_nn_alex.py b/domainlab/compos/builder_nn_alex.py index bc2263287..a030febf1 100644 --- a/domainlab/compos/builder_nn_alex.py +++ b/domainlab/compos/builder_nn_alex.py @@ -7,8 +7,10 @@ class NodeFeatExtractNNBuilderAlex(AbstractFeatExtractNNBuilderChainNode): """NodeFeatExtractNNBuilderAlex. Uniform interface to return AlexNet and other neural network as feature extractor from torchvision or external python file""" - def init_business(self, dim_out, args, isize=None, - remove_last_layer=False, flag_pretrain=True): + + def init_business( + self, dim_out, args, isize=None, remove_last_layer=False, flag_pretrain=True + ): """ initialize **and** return the heavy weight business object for doing the real job @@ -30,4 +32,5 @@ def is_myjob(self, args): """ arg_name = getattr(args, arg_name4net) return arg_name == arg_val + return NodeFeatExtractNNBuilderAlex diff --git a/domainlab/compos/builder_nn_conv_bn_relu_2.py b/domainlab/compos/builder_nn_conv_bn_relu_2.py index 6e199857a..0238595a2 100644 --- a/domainlab/compos/builder_nn_conv_bn_relu_2.py +++ b/domainlab/compos/builder_nn_conv_bn_relu_2.py @@ -2,8 +2,7 @@ from domainlab.compos.nn_zoo.net_conv_conv_bn_pool_2 import NetConvBnReluPool2L -def mkNodeFeatExtractNNBuilderNameConvBnRelu2(arg_name4net, arg_val, - conv_stride): +def mkNodeFeatExtractNNBuilderNameConvBnRelu2(arg_name4net, arg_val, conv_stride): """mkNodeFeatExtractNNBuilderNameConvBnRelu2. 
In chain of responsibility selection of neural network, reuse code to add more possibilities of neural network of the same family. @@ -14,18 +13,19 @@ def mkNodeFeatExtractNNBuilderNameConvBnRelu2(arg_name4net, arg_val, :param i_h: :param i_w: """ - class _NodeFeatExtractNNBuilderConvBnRelu2L( - AbstractFeatExtractNNBuilderChainNode): + + class _NodeFeatExtractNNBuilderConvBnRelu2L(AbstractFeatExtractNNBuilderChainNode): """NodeFeatExtractNNBuilderConvBnRelu2L.""" - def init_business(self, dim_out, args, isize, - flag_pretrain=None, remove_last_layer=False): + def init_business( + self, dim_out, args, isize, flag_pretrain=None, remove_last_layer=False + ): """ :param flag_pretrain """ self.net_feat_extract = NetConvBnReluPool2L( - isize=isize, - conv_stride=conv_stride, dim_out_h=dim_out) + isize=isize, conv_stride=conv_stride, dim_out_h=dim_out + ) return self.net_feat_extract def is_myjob(self, args): @@ -35,4 +35,5 @@ def is_myjob(self, args): """ arg_name = getattr(args, arg_name4net) return arg_name == arg_val + return _NodeFeatExtractNNBuilderConvBnRelu2L diff --git a/domainlab/compos/builder_nn_external_from_file.py b/domainlab/compos/builder_nn_external_from_file.py index 7ab244164..0d0161d19 100644 --- a/domainlab/compos/builder_nn_external_from_file.py +++ b/domainlab/compos/builder_nn_external_from_file.py @@ -1,6 +1,7 @@ from domainlab.compos.a_nn_builder import AbstractFeatExtractNNBuilderChainNode -from domainlab.utils.u_import_net_module import \ - build_external_obj_net_module_feat_extract +from domainlab.utils.u_import_net_module import ( + build_external_obj_net_module_feat_extract, +) def mkNodeFeatExtractNNBuilderExternFromFile(arg_name_net_path): @@ -9,13 +10,17 @@ def mkNodeFeatExtractNNBuilderExternFromFile(arg_name_net_path): for diva, there can be class feature extractor and domain feature extractor """ + class _LNodeFeatExtractNNBuilderExternFromFile( - AbstractFeatExtractNNBuilderChainNode): + AbstractFeatExtractNNBuilderChainNode + ): 
"""LNodeFeatExtractNNBuilderExternFromFile. Local class to return """ - def init_business(self, dim_out, args, flag_pretrain, - remove_last_layer, isize=None): + + def init_business( + self, dim_out, args, flag_pretrain, remove_last_layer, isize=None + ): """ initialize **and** return the heavy weight business object for doing the real job @@ -25,12 +30,13 @@ def init_business(self, dim_out, args, flag_pretrain, """ pyfile4net = getattr(args, arg_name_net_path) net = build_external_obj_net_module_feat_extract( - pyfile4net, dim_out, remove_last_layer) + pyfile4net, dim_out, remove_last_layer + ) return net def is_myjob(self, args): - """is_myjob. - """ + """is_myjob.""" pyfile4net = getattr(args, arg_name_net_path) return pyfile4net is not None + return _LNodeFeatExtractNNBuilderExternFromFile diff --git a/domainlab/compos/nn_zoo/net_adversarial.py b/domainlab/compos/nn_zoo/net_adversarial.py index f90821ec9..5ed11c882 100644 --- a/domainlab/compos/nn_zoo/net_adversarial.py +++ b/domainlab/compos/nn_zoo/net_adversarial.py @@ -13,6 +13,7 @@ class AutoGradFunReverseMultiply(Function): https://pytorch.org/docs/stable/autograd.html https://pytorch.org/docs/stable/notes/extending.html#extending-autograd """ + @staticmethod def forward(ctx, x, alpha): ctx.alpha = alpha diff --git a/domainlab/compos/nn_zoo/net_classif.py b/domainlab/compos/nn_zoo/net_classif.py index da406490c..7ab1aee22 100644 --- a/domainlab/compos/nn_zoo/net_classif.py +++ b/domainlab/compos/nn_zoo/net_classif.py @@ -6,6 +6,7 @@ class ClassifDropoutReluLinear(nn.Module): """first apply dropout, then relu, then linearly fully connected, without activation""" + def __init__(self, z_dim, target_dim): """ :param z_dim: diff --git a/domainlab/compos/nn_zoo/net_conv_conv_bn_pool_2.py b/domainlab/compos/nn_zoo/net_conv_conv_bn_pool_2.py index e37137c52..c03cf6135 100644 --- a/domainlab/compos/nn_zoo/net_conv_conv_bn_pool_2.py +++ b/domainlab/compos/nn_zoo/net_conv_conv_bn_pool_2.py @@ -17,8 +17,7 @@ def 
mk_conv_bn_relu_pool(i_channel, conv_stride=1, max_pool_stride=2): :param max_pool_stride: """ conv_net = nn.Sequential( - nn.Conv2d(i_channel, 32, kernel_size=5, - stride=conv_stride, bias=False), + nn.Conv2d(i_channel, 32, kernel_size=5, stride=conv_stride, bias=False), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, stride=max_pool_stride), @@ -47,8 +46,10 @@ def forward(self, tensor_x): """ :param tensor_x: image """ - conv_out = self.conv_net(tensor_x) # conv-bn-relu-pool-conv-bn-relu-pool(no activation) - flat = conv_out.view(-1, self.hdim) # 1024 = 64 * (4*4) + conv_out = self.conv_net( + tensor_x + ) # conv-bn-relu-pool-conv-bn-relu-pool(no activation) + flat = conv_out.view(-1, self.hdim) # 1024 = 64 * (4*4) hidden = self.layer_last(flat) return hidden @@ -60,6 +61,7 @@ class NetConvDense(nn.Module): until classifier. note in encoder, there is extra layer of hidden to mean and scale, in this component, it is replaced with another hidden layer. """ + def __init__(self, isize, conv_stride, dim_out_h, args, dense_layer=None): """ :param dim_out_h: @@ -78,7 +80,9 @@ def forward(self, tensor_x): """ :param tensor_x: image """ - conv_out = self.conv_net(tensor_x) # conv-bn-relu-pool-conv-bn-relu-pool(no activation) - flat = conv_out.view(-1, self.hdim) # 1024 = 64 * (4*4) + conv_out = self.conv_net( + tensor_x + ) # conv-bn-relu-pool-conv-bn-relu-pool(no activation) + flat = conv_out.view(-1, self.hdim) # 1024 = 64 * (4*4) hidden = self.dense_layers(flat) return hidden diff --git a/domainlab/compos/nn_zoo/net_gated.py b/domainlab/compos/nn_zoo/net_gated.py index b0f9c5823..dd9e0999a 100644 --- a/domainlab/compos/nn_zoo/net_gated.py +++ b/domainlab/compos/nn_zoo/net_gated.py @@ -19,34 +19,67 @@ def forward(self, x): return h * g -#========================================================================== + +# ========================================================================== class GatedConv2d(nn.Module): - def __init__(self, input_channels, 
output_channels, kernel_size, stride, padding, dilation=1, activation=None): + def __init__( + self, + input_channels, + output_channels, + kernel_size, + stride, + padding, + dilation=1, + activation=None, + ): super(GatedConv2d, self).__init__() self.activation = activation self.sigmoid = nn.Sigmoid() - self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation) - self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation) + self.h = nn.Conv2d( + input_channels, output_channels, kernel_size, stride, padding, dilation + ) + self.g = nn.Conv2d( + input_channels, output_channels, kernel_size, stride, padding, dilation + ) def forward(self, x): if self.activation is None: h = self.h(x) else: - h = self.activation( self.h( x ) ) + h = self.activation(self.h(x)) - g = self.sigmoid( self.g( x ) ) + g = self.sigmoid(self.g(x)) return h * g -#============================================================================== + +# ============================================================================== class Conv2d(nn.Module): - def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None, bias=True): + def __init__( + self, + input_channels, + output_channels, + kernel_size, + stride, + padding, + dilation=1, + activation=None, + bias=True, + ): super(Conv2d, self).__init__() self.activation = activation - self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation, bias=bias) + self.conv = nn.Conv2d( + input_channels, + output_channels, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) def forward(self, x): h = self.conv(x) diff --git a/domainlab/compos/nn_zoo/nn.py b/domainlab/compos/nn_zoo/nn.py index c1d338e2d..85dd6849b 100644 --- a/domainlab/compos/nn_zoo/nn.py +++ b/domainlab/compos/nn_zoo/nn.py @@ -5,6 +5,7 @@ class LayerId(nn.Module): """ used to delete layers """ + def __init__(self): 
super().__init__() @@ -19,7 +20,10 @@ class DenseNet(nn.Module): """ (input)-dropout-linear-relu-dropout-linear-relu(output) """ - def __init__(self, input_flat_size, out_hidden_size=1024, last_feat_dim=4096, p_dropout=0.5): + + def __init__( + self, input_flat_size, out_hidden_size=1024, last_feat_dim=4096, p_dropout=0.5 + ): """ :param input_flat_size: :param out_hidden_size: diff --git a/domainlab/compos/nn_zoo/nn_alex.py b/domainlab/compos/nn_zoo/nn_alex.py index 5d9fded43..8a1e6339f 100644 --- a/domainlab/compos/nn_zoo/nn_alex.py +++ b/domainlab/compos/nn_zoo/nn_alex.py @@ -38,21 +38,24 @@ class AlexNetBase(NetTorchVisionBase): ) ) """ + def fetch_net(self, flag_pretrain): - self.net_torchvision = torchvisionmodels.alexnet( - pretrained=flag_pretrain) + self.net_torchvision = torchvisionmodels.alexnet(pretrained=flag_pretrain) class Alex4DeepAll(AlexNetBase): """ change the last layer output of AlexNet to the dimension of the """ + def __init__(self, flag_pretrain, dim_y): super().__init__(flag_pretrain) if self.net_torchvision.classifier[6].out_features != dim_y: logger = Logger.get_logger() - logger.info(f"original alex net out dim " - f"{self.net_torchvision.classifier[6].out_features}") + logger.info( + f"original alex net out dim " + f"{self.net_torchvision.classifier[6].out_features}" + ) num_ftrs = self.net_torchvision.classifier[6].in_features self.net_torchvision.classifier[6] = nn.Linear(num_ftrs, dim_y) logger.info(f"re-initialized to {dim_y}") @@ -64,6 +67,7 @@ class AlexNetNoLastLayer(AlexNetBase): the classifier from VAE can then have the same layer depth as erm model so it is fair for comparison """ + def __init__(self, flag_pretrain): super().__init__(flag_pretrain) self.net_torchvision.classifier[6] = LayerId() diff --git a/domainlab/compos/nn_zoo/nn_torchvision.py b/domainlab/compos/nn_zoo/nn_torchvision.py index 05d0143a2..1cca73a43 100644 --- a/domainlab/compos/nn_zoo/nn_torchvision.py +++ b/domainlab/compos/nn_zoo/nn_torchvision.py @@ -1,4 
+1,5 @@ import torch.nn as nn + from domainlab.utils.logger import Logger @@ -6,6 +7,7 @@ class NetTorchVisionBase(nn.Module): """ fetch model from torchvision """ + def __init__(self, flag_pretrain): super().__init__() self.net_torchvision = None diff --git a/domainlab/compos/pcr/p_chain_handler.py b/domainlab/compos/pcr/p_chain_handler.py index b8ffaba07..31c428897 100644 --- a/domainlab/compos/pcr/p_chain_handler.py +++ b/domainlab/compos/pcr/p_chain_handler.py @@ -3,6 +3,7 @@ __author__ = "Xudong Sun" import abc + from domainlab.utils.logger import Logger @@ -10,6 +11,7 @@ class Request4Chain(metaclass=abc.ABCMeta): """ define all available fields of request to ensure operation safety """ + @abc.abstractmethod def convert(self, obj): """ @@ -99,13 +101,12 @@ def print_options(self): self._parent_node.print_options() -class DummyBusiness(): +class DummyBusiness: message = "dummy business" class DummyChainNodeHandlerBeaver(AbstractChainNodeHandler): - """Dummy class to show how to inherit from Chain of Responsibility - """ + """Dummy class to show how to inherit from Chain of Responsibility""" def init_business(self, *kargs, **kwargs): return DummyBusiness() @@ -119,8 +120,7 @@ def is_myjob(self, request): class DummyChainNodeHandlerLazy(AbstractChainNodeHandler): - """Dummy class to show how to inherit from Chain of Responsibility - """ + """Dummy class to show how to inherit from Chain of Responsibility""" def init_business(self, *kargs, **kwargs): return DummyBusiness() diff --git a/domainlab/compos/pcr/request.py b/domainlab/compos/pcr/request.py index 9f097f4f1..e3fe5ceee 100644 --- a/domainlab/compos/pcr/request.py +++ b/domainlab/compos/pcr/request.py @@ -1,21 +1,25 @@ from domainlab.utils.utils_class import store_args -class RequestVAEBuilderCHW(): +class RequestVAEBuilderCHW: @store_args def __init__(self, i_c, i_h, i_w, args): pass -class RequestVAEBuilderNN(): + +class RequestVAEBuilderNN: """creates request when input does not come from command-line 
(args) but from test_exp file""" + @store_args def __init__(self, net_class_d, net_x, net_class_y, i_c, i_h, i_w): """net_class_d, net_x and net_class_y are neural networks defined by the user""" -class RequestTask(): + +class RequestTask: """ Isolate args from Request object of chain of responsibility node for task """ + def __init__(self, args): self.args = args @@ -23,12 +27,13 @@ def __call__(self): return self.args.task -class RequestArgs2ExpCmd(): +class RequestArgs2ExpCmd: """ Isolate args from Request object of chain of responsibility node for experiment For example, args has field names which will couple with experiment class, this request class also serves as isolation class or adaptation class """ + @store_args def __init__(self, args): self.args = args diff --git a/domainlab/compos/vae/a_vae_builder.py b/domainlab/compos/vae/a_vae_builder.py index 4563268cd..7baef60a5 100644 --- a/domainlab/compos/vae/a_vae_builder.py +++ b/domainlab/compos/vae/a_vae_builder.py @@ -16,6 +16,7 @@ class AbstractVAEBuilderChainNode(AbstractChainNodeHandler): avoid override the initializer so that node construction is always light weight. """ + def __init__(self, successor_node): self.args = None self.zd_dim = None diff --git a/domainlab/compos/vae/c_vae_adaptor_model_recon.py b/domainlab/compos/vae/c_vae_adaptor_model_recon.py index f2d8d14df..b6f38cd16 100644 --- a/domainlab/compos/vae/c_vae_adaptor_model_recon.py +++ b/domainlab/compos/vae/c_vae_adaptor_model_recon.py @@ -5,12 +5,13 @@ """ -class AdaptorReconVAEXYD(): +class AdaptorReconVAEXYD: """ This adaptor couples intensively with the heavy-weight model class The model class can be refactored, we do want to use the trained old-version model, which we only need to change this adaptor class. """ + def __init__(self, model): self.model = model @@ -23,8 +24,7 @@ def cal_latent(self, x): we only need to change this method. 
:param x: """ - q_zd, _, q_zx, _, q_zy, _ = \ - self.model.encoder(x) + q_zd, _, q_zx, _, q_zy, _ = self.model.encoder(x) return q_zd, q_zx, q_zy def recon_ydx(self, zy, zd, zx, x): @@ -41,13 +41,11 @@ def recon_ydx(self, zy, zd, zx, x): return x_mean def cal_prior_zy(self, vec_y): - """ - """ + """ """ p_zy = self.model.net_p_zy(vec_y) return p_zy def cal_prior_zd(self, vec_d): - """ - """ + """ """ p_zd = self.model.net_p_zd(vec_d) return p_zd diff --git a/domainlab/compos/vae/c_vae_builder_classif.py b/domainlab/compos/vae/c_vae_builder_classif.py index 1616dd40f..3478193ef 100644 --- a/domainlab/compos/vae/c_vae_builder_classif.py +++ b/domainlab/compos/vae/c_vae_builder_classif.py @@ -5,8 +5,9 @@ """ from domainlab.compos.nn_zoo.net_classif import ClassifDropoutReluLinear from domainlab.compos.vae.a_vae_builder import AbstractVAEBuilderChainNode -from domainlab.compos.vae.compos.decoder_cond_prior import \ - LSCondPriorLinearBnReluLinearSoftPlus +from domainlab.compos.vae.compos.decoder_cond_prior import ( + LSCondPriorLinearBnReluLinearSoftPlus, +) class ChainNodeVAEBuilderClassifCondPrior(AbstractVAEBuilderChainNode): @@ -16,6 +17,7 @@ class ChainNodeVAEBuilderClassifCondPrior(AbstractVAEBuilderChainNode): - conditional prior 2. Bridge pattern: separate abstraction (vae model) and implementation) """ + def construct_classifier(self, input_dim, output_dim): """ classifier can be used to both classify class-label and domain-label diff --git a/domainlab/compos/vae/c_vae_recon.py b/domainlab/compos/vae/c_vae_recon.py index 0e65cbad3..50e2518db 100644 --- a/domainlab/compos/vae/c_vae_recon.py +++ b/domainlab/compos/vae/c_vae_recon.py @@ -8,19 +8,26 @@ from domainlab.compos.vae.c_vae_adaptor_model_recon import AdaptorReconVAEXYD -class ReconVAEXYD(): +class ReconVAEXYD: """ Adaptor is vital for data generation so this class can be decoupled from model class. 
The model class can be refactored, we do want to use the trained old-version model, which we only need to change adaptor class. """ + def __init__(self, model, na_adaptor=AdaptorReconVAEXYD): self.model = model self.adaptor = na_adaptor(self.model) - def recon(self, x, vec_y=None, vec_d=None, - sample_p_zy=False, sample_p_zd=False, - scalar_zx2fill=None): + def recon( + self, + x, + vec_y=None, + vec_d=None, + sample_p_zy=False, + sample_p_zd=False, + scalar_zx2fill=None, + ): """ common function """ @@ -44,7 +51,9 @@ def recon(self, x, vec_y=None, vec_d=None, if scalar_zx2fill is not None: recon_zx = torch.zeros_like(zx_loc_q) recon_zx = recon_zx.fill_(scalar_zx2fill) - str_type = "_".join([str_type, "__fill_zx_", str(scalar_zx2fill), "___"]) + str_type = "_".join( + [str_type, "__fill_zx_", str(scalar_zx2fill), "___"] + ) else: recon_zx = zx_loc_q str_type = "_".join([str_type, "__zx_q__"]) @@ -62,10 +71,18 @@ def recon(self, x, vec_y=None, vec_d=None, img_recon = self.adaptor.recon_ydx(recon_zy, recon_zd, recon_zx, x) return img_recon, str_type - def recon_cf(self, x, na_cf, dim_cf, device, - vec_y=None, vec_d=None, - zx2fill=None, - sample_p_zy=False, sample_p_zd=False): + def recon_cf( + self, + x, + na_cf, + dim_cf, + device, + vec_y=None, + vec_d=None, + zx2fill=None, + sample_p_zy=False, + sample_p_zd=False, + ): """ Countefactual reconstruction: :param na_cf: name of counterfactual, 'y' or 'd' @@ -79,10 +96,16 @@ def recon_cf(self, x, na_cf, dim_cf, device, label_cf = torch.zeros(batch_size, dim_cf).to(device) label_cf[:, i] = 1 if na_cf == "y": - img_recon_cf, str_type = self.recon(x, label_cf, vec_d, sample_p_zy, sample_p_zd, zx2fill) + img_recon_cf, str_type = self.recon( + x, label_cf, vec_d, sample_p_zy, sample_p_zd, zx2fill + ) elif na_cf == "d": - img_recon_cf, str_type = self.recon(x, vec_y, label_cf, sample_p_zy, sample_p_zd, zx2fill) + img_recon_cf, str_type = self.recon( + x, vec_y, label_cf, sample_p_zy, sample_p_zd, zx2fill + ) else: - raise 
RuntimeError("counterfactual image generation can only be 'y' or 'd'") + raise RuntimeError( + "counterfactual image generation can only be 'y' or 'd'" + ) list_recon_cf.append(img_recon_cf) return list_recon_cf, str_type diff --git a/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv.py b/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv.py index 2065d88a5..2b0ff5421 100644 --- a/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv.py +++ b/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv.py @@ -10,12 +10,19 @@ class DecoderConcatLatentFcReshapeConv(nn.Module): Latent vector re-arranged to image-size directly, then convolute to get the textures of the original image """ - def __init__(self, z_dim, i_c, i_h, i_w, - cls_fun_nll_p_x, - net_fc_z2flat_img, - net_conv, - net_p_x_mean, - net_p_x_log_var): + + def __init__( + self, + z_dim, + i_c, + i_h, + i_w, + cls_fun_nll_p_x, + net_fc_z2flat_img, + net_conv, + net_p_x_mean, + net_p_x_log_var, + ): """ :param z_dim: :param list_im_chw: [channel, height, width] @@ -35,14 +42,18 @@ def cal_p_x_pars_loc_scale(self, vec_z): """ h_flat = self.net_fc_z2flat_img(vec_z) # reshape to image - h_img = h_flat.view(-1, self.list_im_chw[0], self.list_im_chw[1], self.list_im_chw[2]) + h_img = h_flat.view( + -1, self.list_im_chw[0], self.list_im_chw[1], self.list_im_chw[2] + ) h_img_conv = self.net_conv(h_img) # pixel must be positive: enforced by sigmoid activation x_mean = self.net_p_x_mean(h_img_conv) # .view(-1, np.prod(self.list_im_chw)) # remove the saturated part of sigmoid - x_mean = torch.clamp(x_mean, min=0.+1./512., max=1.-1./512.) 
+ x_mean = torch.clamp(x_mean, min=0.0 + 1.0 / 512.0, max=1.0 - 1.0 / 512.0) # negative values - x_logvar = self.net_p_x_log_var(h_img_conv) # .view(-1, np.prod(self.list_im_chw)) + x_logvar = self.net_p_x_log_var( + h_img_conv + ) # .view(-1, np.prod(self.list_im_chw)) return x_mean, x_logvar def concat_ydx(self, zy, zd, zx): diff --git a/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv_gated_conv.py b/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv_gated_conv.py index 95561f40e..00fd57b5c 100644 --- a/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv_gated_conv.py +++ b/domainlab/compos/vae/compos/decoder_concat_vec_reshape_conv_gated_conv.py @@ -6,8 +6,9 @@ import torch.nn as nn from domainlab.compos.nn_zoo.net_gated import Conv2d, GatedConv2d, GatedDense -from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv import \ - DecoderConcatLatentFcReshapeConv +from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv import ( + DecoderConcatLatentFcReshapeConv, +) from domainlab.compos.vae.compos.decoder_losses import NLLPixelLogistic256 @@ -18,6 +19,7 @@ class DecoderConcatLatentFCReshapeConvGatedConv(DecoderConcatLatentFcReshapeConv Latent vector re-arranged to image-size directly, then convolute to get the textures of the original image """ + def __init__(self, z_dim, i_c, i_h, i_w): """ :param z_dim: @@ -26,9 +28,7 @@ def __init__(self, z_dim, i_c, i_h, i_w): """ list_im_chw = [i_c, i_h, i_w] cls_fun_nll_p_x = NLLPixelLogistic256 - net_fc_z2flat_img = nn.Sequential( - GatedDense(z_dim, np.prod(list_im_chw)) - ) + net_fc_z2flat_img = nn.Sequential(GatedDense(z_dim, np.prod(list_im_chw))) net_conv = nn.Sequential( # GatedConv2d @@ -43,13 +43,23 @@ def __init__(self, z_dim, i_c, i_h, i_w): # # hidden image to mean and variance of each pixel # stride(1) and kernel size 1, pad 0 - net_p_x_mean = Conv2d(64, list_im_chw[0], 1, 1, 0, - activation=nn.Sigmoid()) - net_p_x_log_var = Conv2d(64, list_im_chw[0], 1, 1, 0, - 
activation=nn.Hardtanh(min_val=-4.5, max_val=0.)) - super().__init__(z_dim, i_c, i_h, i_w, - cls_fun_nll_p_x, - net_fc_z2flat_img, - net_conv, - net_p_x_mean, - net_p_x_log_var) + net_p_x_mean = Conv2d(64, list_im_chw[0], 1, 1, 0, activation=nn.Sigmoid()) + net_p_x_log_var = Conv2d( + 64, + list_im_chw[0], + 1, + 1, + 0, + activation=nn.Hardtanh(min_val=-4.5, max_val=0.0), + ) + super().__init__( + z_dim, + i_c, + i_h, + i_w, + cls_fun_nll_p_x, + net_fc_z2flat_img, + net_conv, + net_p_x_mean, + net_p_x_log_var, + ) diff --git a/domainlab/compos/vae/compos/decoder_cond_prior.py b/domainlab/compos/vae/compos/decoder_cond_prior.py index bdb0301b5..d54d1d860 100644 --- a/domainlab/compos/vae/compos/decoder_cond_prior.py +++ b/domainlab/compos/vae/compos/decoder_cond_prior.py @@ -7,6 +7,7 @@ class LSCondPriorLinearBnReluLinearSoftPlus(nn.Module): """ Location-Scale: from hyper-prior to current layer prior distribution """ + def __init__(self, hyper_prior_dim, z_dim, hidden_dim=None): super().__init__() if hidden_dim is None: @@ -14,7 +15,8 @@ def __init__(self, hyper_prior_dim, z_dim, hidden_dim=None): self.net_linear_bn_relu = nn.Sequential( nn.Linear(hyper_prior_dim, self.hidden_dim, bias=False), nn.BatchNorm1d(self.hidden_dim), - nn.ReLU()) + nn.ReLU(), + ) self.fc_loc = nn.Sequential(nn.Linear(self.hidden_dim, z_dim)) # No activation, because latent code z variable can take both negative and positive value self.fc_scale = nn.Sequential(nn.Linear(self.hidden_dim, z_dim), nn.Softplus()) @@ -22,9 +24,9 @@ def __init__(self, hyper_prior_dim, z_dim, hidden_dim=None): # initialization torch.nn.init.xavier_uniform_(self.net_linear_bn_relu[0].weight) torch.nn.init.xavier_uniform_(self.fc_loc[0].weight) - self.fc_loc[0].bias.data.zero_() # No Bias + self.fc_loc[0].bias.data.zero_() # No Bias torch.nn.init.xavier_uniform_(self.fc_scale[0].weight) - self.fc_scale[0].bias.data.zero_() # No Bias + self.fc_scale[0].bias.data.zero_() # No Bias def forward(self, hyper_prior): """ 
diff --git a/domainlab/compos/vae/compos/decoder_losses.py b/domainlab/compos/vae/compos/decoder_losses.py index 29acbc39c..7eccbc098 100644 --- a/domainlab/compos/vae/compos/decoder_losses.py +++ b/domainlab/compos/vae/compos/decoder_losses.py @@ -2,6 +2,7 @@ Upon pixel wise mean and variance """ import torch + from domainlab import g_inst_component_loss_agg @@ -14,7 +15,8 @@ class NLLPixelLogistic256(object): c.d.f.(x_{i,j}+bin_size/scale) - c.d.f.(x_{i,j}) # https://github.com/openai/iaf/blob/master/tf_utils/distributions.py#L29 """ - def __init__(self, reduce_dims=(1, 2, 3), bin_size=1. / 256.): + + def __init__(self, reduce_dims=(1, 2, 3), bin_size=1.0 / 256.0): """ :param reduce_dims: """ @@ -31,11 +33,11 @@ def __call__(self, tensor, mean, logvar): """ scale = torch.exp(logvar) tensor = (torch.floor(tensor / self.bin_size) * self.bin_size - mean) / scale - cdf_plus = torch.sigmoid(tensor + self.bin_size/scale) + cdf_plus = torch.sigmoid(tensor + self.bin_size / scale) cdf_minus = torch.sigmoid(tensor) # negative log-likelihood for each pixel - log_logist_256 = - torch.log(cdf_plus - cdf_minus + 1.e-7) + log_logist_256 = -torch.log(cdf_plus - cdf_minus + 1.0e-7) # torch.Size([100, 3, 28, 28]) nll = g_inst_component_loss_agg(log_logist_256, dim=self.reduce_dims) # NOTE: pixel NLL should always be summed diff --git a/domainlab/compos/vae/compos/encoder.py b/domainlab/compos/vae/compos/encoder.py index de69a0dc3..b8d475324 100644 --- a/domainlab/compos/vae/compos/encoder.py +++ b/domainlab/compos/vae/compos/encoder.py @@ -6,8 +6,7 @@ import torch.distributions as dist import torch.nn as nn -from domainlab.compos.nn_zoo.net_conv_conv_bn_pool_2 import \ - mk_conv_bn_relu_pool +from domainlab.compos.nn_zoo.net_conv_conv_bn_pool_2 import mk_conv_bn_relu_pool from domainlab.compos.nn_zoo.nn import DenseNet from domainlab.compos.utils_conv_get_flat_dim import get_flat_dim @@ -17,6 +16,7 @@ class LSEncoderConvBnReluPool(nn.Module): Batch Normalization, Relu and 
Pooling. Softplus for scale """ + def __init__(self, z_dim: int, i_channel, i_h, i_w, conv_stride): """ :param z_dim: @@ -32,13 +32,13 @@ def __init__(self, z_dim: int, i_channel, i_h, i_w, conv_stride): self.i_h = i_h self.i_w = i_w - self.conv = mk_conv_bn_relu_pool(self.i_channel, - conv_stride=conv_stride) + self.conv = mk_conv_bn_relu_pool(self.i_channel, conv_stride=conv_stride) # conv-bn-relu-pool-conv-bn-relu-pool(no activation) self.flat_dim = get_flat_dim(self.conv, i_channel, i_h, i_w) self.fc_loc = nn.Sequential(nn.Linear(self.flat_dim, z_dim)) - self.fc_scale = nn.Sequential(nn.Linear(self.flat_dim, z_dim), - nn.Softplus()) # for scale calculation + self.fc_scale = nn.Sequential( + nn.Linear(self.flat_dim, z_dim), nn.Softplus() + ) # for scale calculation # initialization torch.nn.init.xavier_uniform_(self.fc_loc[0].weight) @@ -64,6 +64,7 @@ class LSEncoderLinear(nn.Module): Location-Scale Encoder with DenseNet as feature extractor Softplus for scale """ + def __init__(self, z_dim, dim_input): """ :param z_dim: @@ -72,8 +73,9 @@ def __init__(self, z_dim, dim_input): """ super().__init__() self.fc_loc = nn.Sequential(nn.Linear(dim_input, z_dim)) - self.fc_scale = nn.Sequential(nn.Linear(dim_input, z_dim), - nn.Softplus()) # for scale calculation + self.fc_scale = nn.Sequential( + nn.Linear(dim_input, z_dim), nn.Softplus() + ) # for scale calculation # initialization torch.nn.init.xavier_uniform_(self.fc_loc[0].weight) diff --git a/domainlab/compos/vae/compos/encoder_dirichlet.py b/domainlab/compos/vae/compos/encoder_dirichlet.py index 6ff232c5b..9862eea01 100644 --- a/domainlab/compos/vae/compos/encoder_dirichlet.py +++ b/domainlab/compos/vae/compos/encoder_dirichlet.py @@ -9,8 +9,7 @@ class EncoderH2Dirichlet(nn.Module): """ def __init__(self, dim_topic, device): - """ - """ + """ """ super().__init__() self.layer_bn = nn.BatchNorm1d(dim_topic) self.layer_concentration = nn.Softplus() diff --git a/domainlab/compos/vae/compos/encoder_domain_topic.py 
b/domainlab/compos/vae/compos/encoder_domain_topic.py index 2c00edd53..57674871e 100644 --- a/domainlab/compos/vae/compos/encoder_domain_topic.py +++ b/domainlab/compos/vae/compos/encoder_domain_topic.py @@ -1,18 +1,17 @@ import torch.nn as nn -from domainlab.compos.vae.compos.encoder_domain_topic_img2topic import \ - EncoderImg2TopicDistri -from domainlab.compos.vae.compos.encoder_domain_topic_img_topic2zd import \ - EncoderSandwichTopicImg2Zd +from domainlab.compos.vae.compos.encoder_domain_topic_img2topic import ( + EncoderImg2TopicDistri, +) +from domainlab.compos.vae.compos.encoder_domain_topic_img_topic2zd import ( + EncoderSandwichTopicImg2Zd, +) class EncoderImg2TopicDirZd(nn.Module): - """ - """ - def __init__(self, i_c, i_h, i_w, num_topics, - device, - zd_dim, - args): + """ """ + + def __init__(self, i_c, i_h, i_w, num_topics, device, zd_dim, args): """__init__. :param i_c: @@ -21,8 +20,8 @@ def __init__(self, i_c, i_h, i_w, num_topics, :param num_topics: :param device: :param zd_dim: - :param img_h_dim: - - (img->h_img, topic->h_topic)-> q_zd, + :param img_h_dim: + - (img->h_img, topic->h_topic)-> q_zd, the dimension to concatenate with topic vector to infer z_d - img->img_h_dim->topic distribution """ @@ -30,19 +29,22 @@ def __init__(self, i_c, i_h, i_w, num_topics, self.device = device self.zd_dim = zd_dim - self.add_module("net_img2topicdistri", - EncoderImg2TopicDistri( - (i_c, i_h, i_w), num_topics, - device, - args)) + self.add_module( + "net_img2topicdistri", + EncoderImg2TopicDistri((i_c, i_h, i_w), num_topics, device, args), + ) # [topic, image] -> [h(topic), h(image)] -> [zd_mean, zd_scale] self.add_module( - "imgtopic2zd", EncoderSandwichTopicImg2Zd( - self.zd_dim, (i_c, i_h, i_w), + "imgtopic2zd", + EncoderSandwichTopicImg2Zd( + self.zd_dim, + (i_c, i_h, i_w), num_topics, img_h_dim=num_topics, - args=args)) + args=args, + ), + ) def forward(self, img): """forward. 
diff --git a/domainlab/compos/vae/compos/encoder_domain_topic_img2topic.py b/domainlab/compos/vae/compos/encoder_domain_topic_img2topic.py index 90576c641..f6d1aaa85 100644 --- a/domainlab/compos/vae/compos/encoder_domain_topic_img2topic.py +++ b/domainlab/compos/vae/compos/encoder_domain_topic_img2topic.py @@ -9,9 +9,8 @@ class EncoderImg2TopicDistri(nn.Module): image to topic distribution (not image to topic hidden representation used by another path) """ - def __init__(self, isize, num_topics, - device, - args): + + def __init__(self, isize, num_topics, device, args): """__init__. :param isize: @@ -26,21 +25,25 @@ def __init__(self, isize, num_topics, net_builder = FeatExtractNNBuilderChainNodeGetter( args=args, arg_name_of_net="nname_encoder_x2topic_h", - arg_path_of_net="npath_encoder_x2topic_h")() - - self.add_module("layer_img2hidden", - net_builder.init_business( - flag_pretrain=True, - isize=isize, - remove_last_layer=False, - dim_out=num_topics, - args=args)) + arg_path_of_net="npath_encoder_x2topic_h", + )() + + self.add_module( + "layer_img2hidden", + net_builder.init_business( + flag_pretrain=True, + isize=isize, + remove_last_layer=False, + dim_out=num_topics, + args=args, + ), + ) # h_image->[alpha,topic] - self.add_module("layer_hidden2dirichlet", - EncoderH2Dirichlet( - dim_topic=num_topics, - device=self.device)) + self.add_module( + "layer_hidden2dirichlet", + EncoderH2Dirichlet(dim_topic=num_topics, device=self.device), + ) def forward(self, x): """forward. 
diff --git a/domainlab/compos/vae/compos/encoder_domain_topic_img_topic2zd.py b/domainlab/compos/vae/compos/encoder_domain_topic_img_topic2zd.py index d3babfd1b..c712c4ce0 100644 --- a/domainlab/compos/vae/compos/encoder_domain_topic_img_topic2zd.py +++ b/domainlab/compos/vae/compos/encoder_domain_topic_img_topic2zd.py @@ -1,16 +1,16 @@ import torch import torch.nn as nn -from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter from domainlab.compos.vae.compos.encoder import LSEncoderLinear +from domainlab.compos.zoo_nn import FeatExtractNNBuilderChainNodeGetter class EncoderSandwichTopicImg2Zd(nn.Module): """ sandwich encoder: (img, s)->zd """ - def __init__(self, zd_dim, isize, num_topics, - img_h_dim, args): + + def __init__(self, zd_dim, isize, num_topics, img_h_dim, args): """ num_topics, img_h_dim: (img->h_img, topic->h_topic)-> q_zd :param img_h_dim: (img->h_img, topic->h_topic)-> q_zd @@ -23,20 +23,26 @@ def __init__(self, zd_dim, isize, num_topics, net_builder = FeatExtractNNBuilderChainNodeGetter( args=args, arg_name_of_net="nname_encoder_sandwich_x2h4zd", - arg_path_of_net="npath_encoder_sandwich_x2h4zd")() + arg_path_of_net="npath_encoder_sandwich_x2h4zd", + )() # image->h_img - self.add_module("layer_img2h4zd", net_builder.init_business( - dim_out=self.img_h_dim, - flag_pretrain=True, - remove_last_layer=False, - isize=isize, args=args)) + self.add_module( + "layer_img2h4zd", + net_builder.init_business( + dim_out=self.img_h_dim, + flag_pretrain=True, + remove_last_layer=False, + isize=isize, + args=args, + ), + ) # [h_img, h_topic] -> zd - self.add_module("encoder_cat_topic_img_h2zd", - LSEncoderLinear( - dim_input=self.img_h_dim+num_topics, - z_dim=self.zd_dim)) + self.add_module( + "encoder_cat_topic_img_h2zd", + LSEncoderLinear(dim_input=self.img_h_dim + num_topics, z_dim=self.zd_dim), + ) def forward(self, img, vec_topic): """forward. 
diff --git a/domainlab/compos/vae/compos/encoder_xyd_parallel.py b/domainlab/compos/vae/compos/encoder_xyd_parallel.py index 0544e9a2e..2a89f9a95 100644 --- a/domainlab/compos/vae/compos/encoder_xyd_parallel.py +++ b/domainlab/compos/vae/compos/encoder_xyd_parallel.py @@ -1,8 +1,7 @@ import torch.nn as nn from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool -from domainlab.compos.vae.compos.encoder_zy import \ - EncoderConnectLastFeatLayer2Z +from domainlab.compos.vae.compos.encoder_zy import EncoderConnectLastFeatLayer2Z from domainlab.utils.utils_class import store_args @@ -11,6 +10,7 @@ class XYDEncoderParallel(nn.Module): calculate zx, zy, zd vars independently (without order, parallel): x->zx, x->zy, x->zd """ + def __init__(self, net_infer_zd, net_infer_zx, net_infer_zy): super().__init__() self.add_module("net_infer_zd", net_infer_zd) @@ -41,6 +41,7 @@ class XYDEncoderParallelUser(XYDEncoderParallel): """ This class only reimplemented constructor of parent class """ + @store_args def __init__(self, net_class_d, net_x, net_class_y): super().__init__(net_class_d, net_x, net_class_y) @@ -50,6 +51,7 @@ class XYDEncoderParallelConvBnReluPool(XYDEncoderParallel): """ This class only reimplemented constructor of parent class """ + @store_args def __init__(self, zd_dim, zx_dim, zy_dim, i_c, i_h, i_w, conv_stride=1): """ @@ -65,16 +67,16 @@ def __init__(self, zd_dim, zx_dim, zy_dim, i_c, i_h, i_w, conv_stride=1): # Calculated output size: (64x0x0). 
# Output size is too small net_infer_zd = LSEncoderConvBnReluPool( - self.zd_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) + self.zd_dim, self.i_c, self.i_w, self.i_h, conv_stride=conv_stride + ) # if self.zx_dim != 0: # pytorch can generate emtpy tensor, so no need to judge here net_infer_zx = LSEncoderConvBnReluPool( - self.zx_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) + self.zx_dim, self.i_c, self.i_w, self.i_h, conv_stride=conv_stride + ) net_infer_zy = LSEncoderConvBnReluPool( - self.zy_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) + self.zy_dim, self.i_c, self.i_w, self.i_h, conv_stride=conv_stride + ) super().__init__(net_infer_zd, net_infer_zx, net_infer_zy) @@ -84,9 +86,9 @@ class XYDEncoderParallelAlex(XYDEncoderParallel): at the end of the constructor of this class, the parent class contructor is called """ + @store_args - def __init__(self, zd_dim, zx_dim, zy_dim, i_c, i_h, i_w, args, - conv_stride=1): + def __init__(self, zd_dim, zx_dim, zy_dim, i_c, i_h, i_w, args, conv_stride=1): """ :param zd_dim: :param zx_dim: @@ -100,17 +102,23 @@ def __init__(self, zd_dim, zx_dim, zy_dim, i_c, i_h, i_w, args, # Calculated output size: (64x0x0). 
# Output size is too small net_infer_zd = LSEncoderConvBnReluPool( - self.zd_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) + self.zd_dim, self.i_c, self.i_w, self.i_h, conv_stride=conv_stride + ) # if self.zx_dim != 0: pytorch can generate emtpy tensor, # so no need to judge here net_infer_zx = LSEncoderConvBnReluPool( - self.zx_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) - net_infer_zy = EncoderConnectLastFeatLayer2Z(self.zy_dim, True, - i_c, i_h, i_w, args, - arg_name="nname", - arg_path_name="npath") + self.zx_dim, self.i_c, self.i_w, self.i_h, conv_stride=conv_stride + ) + net_infer_zy = EncoderConnectLastFeatLayer2Z( + self.zy_dim, + True, + i_c, + i_h, + i_w, + args, + arg_name="nname", + arg_path_name="npath", + ) super().__init__(net_infer_zd, net_infer_zx, net_infer_zy) @@ -120,26 +128,38 @@ class XYDEncoderParallelExtern(XYDEncoderParallel): at the end of the constructor of this class, the parent class contructor is called """ + @store_args - def __init__(self, zd_dim, zx_dim, zy_dim, args, - i_c, i_h, i_w, conv_stride=1): + def __init__(self, zd_dim, zx_dim, zy_dim, args, i_c, i_h, i_w, conv_stride=1): """ :param zd_dim: :param zx_dim: :param zy_dim: """ - net_infer_zd = EncoderConnectLastFeatLayer2Z(self.zd_dim, True, - i_c, i_h, i_w, args, - arg_name="nname_dom", - arg_path_name="npath_dom") + net_infer_zd = EncoderConnectLastFeatLayer2Z( + self.zd_dim, + True, + i_c, + i_h, + i_w, + args, + arg_name="nname_dom", + arg_path_name="npath_dom", + ) # if self.zx_dim != 0: pytorch can generate emtpy tensor, # so no need to judge zx_dim=0 here net_infer_zx = LSEncoderConvBnReluPool( - self.zx_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) + self.zx_dim, self.i_c, self.i_w, self.i_h, conv_stride=conv_stride + ) - net_infer_zy = EncoderConnectLastFeatLayer2Z(self.zy_dim, True, - i_c, i_h, i_w, args, - arg_name="nname", - arg_path_name="npath") + net_infer_zy = EncoderConnectLastFeatLayer2Z( + self.zy_dim, + 
True, + i_c, + i_h, + i_w, + args, + arg_name="nname", + arg_path_name="npath", + ) super().__init__(net_infer_zd, net_infer_zx, net_infer_zy) diff --git a/domainlab/compos/vae/compos/encoder_xydt_elevator.py b/domainlab/compos/vae/compos/encoder_xydt_elevator.py index 9d7b093e5..194964816 100644 --- a/domainlab/compos/vae/compos/encoder_xydt_elevator.py +++ b/domainlab/compos/vae/compos/encoder_xydt_elevator.py @@ -1,10 +1,8 @@ import torch.nn as nn from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool -from domainlab.compos.vae.compos.encoder_domain_topic import \ - EncoderImg2TopicDirZd -from domainlab.compos.vae.compos.encoder_zy import \ - EncoderConnectLastFeatLayer2Z +from domainlab.compos.vae.compos.encoder_domain_topic import EncoderImg2TopicDirZd +from domainlab.compos.vae.compos.encoder_zy import EncoderConnectLastFeatLayer2Z from domainlab.utils.utils_class import store_args @@ -12,6 +10,7 @@ class XYDTEncoderElevator(nn.Module): """ x->zx, x->zy, x->s, (x,s)->zd """ + def __init__(self, net_infer_zd_topic, net_infer_zx, net_infer_zy): super().__init__() self.add_module("net_infer_zd_topic", net_infer_zd_topic) @@ -42,10 +41,9 @@ class XYDTEncoderArg(XYDTEncoderElevator): """ This class only reimplemented constructor of parent class """ + @store_args - def __init__(self, device, topic_dim, zd_dim, - zx_dim, zy_dim, i_c, i_h, i_w, - args): + def __init__(self, device, topic_dim, zd_dim, zx_dim, zy_dim, i_c, i_h, i_w, args): """ :param zd_dim: :param zx_dim: @@ -64,59 +62,28 @@ def __init__(self, device, topic_dim, zd_dim, # if self.zx_dim != 0: pytorch can generate emtpy tensor, # so no need to judge here net_infer_zx = LSEncoderConvBnReluPool( - self.zx_dim, self.i_c, self.i_w, self.i_h, - conv_stride=1) + self.zx_dim, self.i_c, self.i_w, self.i_h, conv_stride=1 + ) net_infer_zy = EncoderConnectLastFeatLayer2Z( - self.zy_dim, True, i_c, i_h, i_w, args, - arg_name="nname", arg_path_name="npath") - - net_infer_zd_topic = 
EncoderImg2TopicDirZd(args=args, - num_topics=topic_dim, - zd_dim=self.zd_dim, - i_c=self.i_c, - i_w=self.i_w, - i_h=self.i_h, - device=device) - - super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy) + self.zy_dim, + True, + i_c, + i_h, + i_w, + args, + arg_name="nname", + arg_path_name="npath", + ) + net_infer_zd_topic = EncoderImg2TopicDirZd( + args=args, + num_topics=topic_dim, + zd_dim=self.zd_dim, + i_c=self.i_c, + i_w=self.i_w, + i_h=self.i_h, + device=device, + ) -# To remove -class XYDTEncoderConvBnReluPool(XYDTEncoderElevator): - """ - This class only reimplemented constructor of parent class - """ - @store_args - def __init__(self, device, topic_dim, zd_dim, zx_dim, zy_dim, - i_c, i_h, i_w, - conv_stride, - args): - """ - :param zd_dim: - :param zx_dim: - :param zy_dim: - :param i_c: number of image channels - :param i_h: image height - :param i_w: image width - """ - # conv_stride=2 on size 28 got RuntimeError: - # Given input size: (64x1x1). - # Calculated output size: (64x0x0). 
- # Output size is too small - net_infer_zd_topic = EncoderImg2TopicDirZd(args=args, - num_topics=topic_dim, - zd_dim=self.zd_dim, - i_c=self.i_c, - i_w=self.i_w, - i_h=self.i_h, - device=device) - # if self.zx_dim != 0: pytorch can generate emtpy tensor, - # so no need to judge here - net_infer_zx = LSEncoderConvBnReluPool( - self.zx_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) - net_infer_zy = LSEncoderConvBnReluPool( - self.zy_dim, self.i_c, self.i_w, self.i_h, - conv_stride=conv_stride) super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy) diff --git a/domainlab/compos/vae/compos/encoder_zy.py b/domainlab/compos/vae/compos/encoder_zy.py index cac366dfe..c406d0829 100644 --- a/domainlab/compos/vae/compos/encoder_zy.py +++ b/domainlab/compos/vae/compos/encoder_zy.py @@ -12,28 +12,32 @@ class EncoderConnectLastFeatLayer2Z(nn.Module): neural network to the latent representation This class should be transparent to where to fetch the network """ - def __init__(self, z_dim, flag_pretrain, - i_c, i_h, i_w, args, arg_name, - arg_path_name): + + def __init__( + self, z_dim, flag_pretrain, i_c, i_h, i_w, args, arg_name, arg_path_name + ): """__init__. 
:param hidden_size: """ super().__init__() net_builder = FeatExtractNNBuilderChainNodeGetter( - args, arg_name, arg_path_name)() # request + args, arg_name, arg_path_name + )() # request self.net_feat_extract = net_builder.init_business( - flag_pretrain=flag_pretrain, dim_out=z_dim, - remove_last_layer=True, args=args, isize=(i_c, i_h, i_w)) + flag_pretrain=flag_pretrain, + dim_out=z_dim, + remove_last_layer=True, + args=args, + isize=(i_c, i_h, i_w), + ) - size_last_layer_before_z = get_flat_dim( - self.net_feat_extract, i_c, i_h, i_w) + size_last_layer_before_z = get_flat_dim(self.net_feat_extract, i_c, i_h, i_w) - self.net_fc_mean = nn.Sequential( - nn.Linear(size_last_layer_before_z, z_dim)) + self.net_fc_mean = nn.Sequential(nn.Linear(size_last_layer_before_z, z_dim)) self.net_fc_scale = nn.Sequential( - nn.Linear(size_last_layer_before_z, z_dim), - nn.Softplus()) # for scale calculation + nn.Linear(size_last_layer_before_z, z_dim), nn.Softplus() + ) # for scale calculation torch.nn.init.xavier_uniform_(self.net_fc_mean[0].weight) self.net_fc_mean[0].bias.data.zero_() diff --git a/domainlab/compos/vae/utils_request_chain_builder.py b/domainlab/compos/vae/utils_request_chain_builder.py index c740d5e94..153399ad1 100644 --- a/domainlab/compos/vae/utils_request_chain_builder.py +++ b/domainlab/compos/vae/utils_request_chain_builder.py @@ -1,5 +1,9 @@ from domainlab.compos.vae.zoo_vae_builders_classif import ( - NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool) + NodeVAEBuilderArg, + NodeVAEBuilderImgAlex, + NodeVAEBuilderImgConvBnPool, + NodeVAEBuilderUser, +) from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic @@ -10,9 +14,9 @@ class VAEChainNodeGetter(object): 3. heavy weight business objective is returned by selected node 4. 
convert Scenario object to request object, so that class can be reused """ + def __init__(self, request, topic_dim=None): - """ - """ + """ """ self.request = request self.topic_dim = topic_dim diff --git a/domainlab/compos/vae/zoo_vae_builders_classif.py b/domainlab/compos/vae/zoo_vae_builders_classif.py index 1adf43c7e..ef765529d 100644 --- a/domainlab/compos/vae/zoo_vae_builders_classif.py +++ b/domainlab/compos/vae/zoo_vae_builders_classif.py @@ -1,20 +1,25 @@ """ Chain node VAE builders """ -from domainlab.compos.vae.c_vae_builder_classif import \ - ChainNodeVAEBuilderClassifCondPrior -from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import \ - DecoderConcatLatentFCReshapeConvGatedConv +from domainlab.compos.vae.c_vae_builder_classif import ( + ChainNodeVAEBuilderClassifCondPrior, +) +from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import ( + DecoderConcatLatentFCReshapeConvGatedConv, +) from domainlab.compos.vae.compos.encoder_xyd_parallel import ( - XYDEncoderParallelAlex, XYDEncoderParallelConvBnReluPool, - XYDEncoderParallelExtern, XYDEncoderParallelUser) + XYDEncoderParallelAlex, + XYDEncoderParallelConvBnReluPool, + XYDEncoderParallelExtern, + XYDEncoderParallelUser, +) -class ChainNodeVAEBuilderClassifCondPriorBase( - ChainNodeVAEBuilderClassifCondPrior): +class ChainNodeVAEBuilderClassifCondPriorBase(ChainNodeVAEBuilderClassifCondPrior): """ base class of AE builder """ + def config_img(self, flag, request): """config_img. 
@@ -41,14 +46,16 @@ def build_decoder(self): """build_decoder.""" decoder = DecoderConcatLatentFCReshapeConvGatedConv( z_dim=self.zd_dim + self.zx_dim + self.zy_dim, - i_c=self.i_c, i_w=self.i_w, - i_h=self.i_h) + i_c=self.i_c, + i_w=self.i_w, + i_h=self.i_h, + ) return decoder class NodeVAEBuilderArg(ChainNodeVAEBuilderClassifCondPriorBase): - """Build encoder decoder according to commandline arguments - """ + """Build encoder decoder according to commandline arguments""" + def is_myjob(self, request): """is_myjob. :param request: @@ -63,10 +70,14 @@ def is_myjob(self, request): def build_encoder(self): """build_encoder.""" encoder = XYDEncoderParallelExtern( - self.zd_dim, self.zx_dim, self.zy_dim, args=self.args, + self.zd_dim, + self.zx_dim, + self.zy_dim, + args=self.args, i_c=self.i_c, i_h=self.i_h, - i_w=self.i_w) + i_w=self.i_w, + ) return encoder @@ -80,9 +91,9 @@ def is_myjob(self, request): return flag def build_encoder(self): - encoder = XYDEncoderParallelUser(self.request.net_class_d, - self.request.net_x, - self.request.net_class_y) + encoder = XYDEncoderParallelUser( + self.request.net_class_d, self.request.net_x, self.request.net_class_y + ) return encoder @@ -92,18 +103,18 @@ def is_myjob(self, request): :param request: """ - flag = (request.args.nname == "conv_bn_pool_2" or - request.args.nname_dom == "conv_bn_pool_2") # @FIXME + flag = ( + request.args.nname == "conv_bn_pool_2" + or request.args.nname_dom == "conv_bn_pool_2" + ) # @FIXME self.config_img(flag, request) return flag def build_encoder(self): """build_encoder.""" encoder = XYDEncoderParallelConvBnReluPool( - self.zd_dim, self.zx_dim, self.zy_dim, - self.i_c, - self.i_h, - self.i_w) + self.zd_dim, self.zx_dim, self.zy_dim, self.i_c, self.i_h, self.i_w + ) return encoder @@ -116,15 +127,19 @@ def is_myjob(self, request): :param request: """ self.args = request.args - flag = (self.args.nname == "alexnet") # @FIXME + flag = self.args.nname == "alexnet" # @FIXME self.config_img(flag, 
request) return flag def build_encoder(self): """build_encoder.""" encoder = XYDEncoderParallelAlex( - self.zd_dim, self.zx_dim, self.zy_dim, + self.zd_dim, + self.zx_dim, + self.zy_dim, self.i_c, self.i_h, - self.i_w, args=self.args) + self.i_w, + args=self.args, + ) return encoder diff --git a/domainlab/compos/vae/zoo_vae_builders_classif_topic.py b/domainlab/compos/vae/zoo_vae_builders_classif_topic.py index 5f630c4c0..2508b40bd 100644 --- a/domainlab/compos/vae/zoo_vae_builders_classif_topic.py +++ b/domainlab/compos/vae/zoo_vae_builders_classif_topic.py @@ -1,14 +1,16 @@ """ Chain node VAE builders """ -from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import \ - DecoderConcatLatentFCReshapeConvGatedConv +from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import ( + DecoderConcatLatentFCReshapeConvGatedConv, +) from domainlab.compos.vae.compos.encoder_xydt_elevator import XYDTEncoderArg from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg class NodeVAEBuilderImgTopic(NodeVAEBuilderArg): """NodeVAEBuilderImgTopic.""" + def is_myjob(self, request): """is_myjob. 
@@ -25,13 +27,17 @@ def build_encoder(self, device, topic_dim): :param device: :param topic_dim: """ - encoder = XYDTEncoderArg(device, topic_dim, - self.zd_dim, self.zx_dim, - self.zy_dim, - self.i_c, - self.i_h, - self.i_w, - args=self.args) + encoder = XYDTEncoderArg( + device, + topic_dim, + self.zd_dim, + self.zx_dim, + self.zy_dim, + self.i_c, + self.i_h, + self.i_w, + args=self.args, + ) return encoder def build_decoder(self, topic_dim): @@ -40,7 +46,9 @@ def build_decoder(self, topic_dim): :param topic_dim: """ decoder = DecoderConcatLatentFCReshapeConvGatedConv( - z_dim=self.zd_dim+self.zx_dim+self.zy_dim+topic_dim, - i_c=self.i_c, i_w=self.i_w, - i_h=self.i_h) + z_dim=self.zd_dim + self.zx_dim + self.zy_dim + topic_dim, + i_c=self.i_c, + i_w=self.i_w, + i_h=self.i_h, + ) return decoder diff --git a/domainlab/compos/zoo_nn.py b/domainlab/compos/zoo_nn.py index 79cc7926d..ba3dfed3c 100644 --- a/domainlab/compos/zoo_nn.py +++ b/domainlab/compos/zoo_nn.py @@ -1,8 +1,10 @@ from domainlab.compos.builder_nn_alex import mkNodeFeatExtractNNBuilderNameAlex -from domainlab.compos.builder_nn_conv_bn_relu_2 import \ - mkNodeFeatExtractNNBuilderNameConvBnRelu2 -from domainlab.compos.builder_nn_external_from_file import \ - mkNodeFeatExtractNNBuilderExternFromFile +from domainlab.compos.builder_nn_conv_bn_relu_2 import ( + mkNodeFeatExtractNNBuilderNameConvBnRelu2, +) +from domainlab.compos.builder_nn_external_from_file import ( + mkNodeFeatExtractNNBuilderExternFromFile, +) class FeatExtractNNBuilderChainNodeGetter(object): @@ -10,8 +12,8 @@ class FeatExtractNNBuilderChainNodeGetter(object): 1. Hardcoded chain 3. Return selected node """ - def __init__(self, args, arg_name_of_net, - arg_path_of_net): + + def __init__(self, args, arg_name_of_net, arg_path_of_net): """__init__. :param args: command line arguments :param arg_name_of_net: args.npath to specify @@ -28,14 +30,14 @@ def __call__(self): 2. 
hard code seems to be the best solution """ chain = mkNodeFeatExtractNNBuilderNameConvBnRelu2( - self.arg_name_of_net, - arg_val="conv_bn_pool_2", conv_stride=1)(None) + self.arg_name_of_net, arg_val="conv_bn_pool_2", conv_stride=1 + )(None) chain = mkNodeFeatExtractNNBuilderNameConvBnRelu2( - arg_name4net="nname_dom", - arg_val="conv_bn_pool_2", conv_stride=1)(chain) - chain = mkNodeFeatExtractNNBuilderNameAlex( - self.arg_name_of_net, "alexnet")(chain) - chain = mkNodeFeatExtractNNBuilderExternFromFile( - self.arg_path_of_net)(chain) + arg_name4net="nname_dom", arg_val="conv_bn_pool_2", conv_stride=1 + )(chain) + chain = mkNodeFeatExtractNNBuilderNameAlex(self.arg_name_of_net, "alexnet")( + chain + ) + chain = mkNodeFeatExtractNNBuilderExternFromFile(self.arg_path_of_net)(chain) node = chain.handle(self.request) return node diff --git a/domainlab/dsets/a_dset_mnist_color_rgb_solo.py b/domainlab/dsets/a_dset_mnist_color_rgb_solo.py index 6b48f6408..07b88df31 100644 --- a/domainlab/dsets/a_dset_mnist_color_rgb_solo.py +++ b/domainlab/dsets/a_dset_mnist_color_rgb_solo.py @@ -22,6 +22,7 @@ class ADsetMNISTColorRGBSolo(Dataset, metaclass=abc.ABCMeta): 3. 
structure: each subdomain contains a combination of foreground+background color """ + @abc.abstractmethod def get_foreground_color(self, ind): raise NotImplementedError @@ -35,14 +36,17 @@ def get_num_colors(self): raise NotImplementedError @store_args - def __init__(self, ind_color, path="zoutput", - subset_step=100, - color_scheme="both", - label_transform=mk_fun_label2onehot(10), - list_transforms=None, - raw_split='train', - flag_rand_color=False, - ): + def __init__( + self, + ind_color, + path="zoutput", + subset_step=100, + color_scheme="both", + label_transform=mk_fun_label2onehot(10), + list_transforms=None, + raw_split="train", + flag_rand_color=False, + ): """ :param ind_color: index of a color palette :param path: disk storage directory @@ -60,12 +64,11 @@ def __init__(self, ind_color, path="zoutput", flag_train = True if raw_split != "train": flag_train = False - dataset = datasets.MNIST(root=dpath, - train=flag_train, - download=True, - transform=transforms.ToTensor()) + dataset = datasets.MNIST( + root=dpath, train=flag_train, download=True, transform=transforms.ToTensor() + ) - if color_scheme not in ['num', 'back', 'both']: + if color_scheme not in ["num", "back", "both"]: raise ValueError("color must be either 'num', 'back' or 'both") raw_path = os.path.dirname(dataset.raw_folder) self._collect_imgs_labels(raw_path, raw_split) @@ -79,21 +82,20 @@ def _collect_imgs_labels(self, path, raw_split): :param path: :param raw_split: """ - if raw_split == 'train': - fimages = os.path.join(path, 'raw', 'train-images-idx3-ubyte') - flabels = os.path.join(path, 'raw', 'train-labels-idx1-ubyte') + if raw_split == "train": + fimages = os.path.join(path, "raw", "train-images-idx3-ubyte") + flabels = os.path.join(path, "raw", "train-labels-idx1-ubyte") else: - fimages = os.path.join(path, 'raw', 't10k-images-idx3-ubyte') - flabels = os.path.join(path, 'raw', 't10k-labels-idx1-ubyte') + fimages = os.path.join(path, "raw", "t10k-images-idx3-ubyte") + flabels = 
os.path.join(path, "raw", "t10k-labels-idx1-ubyte") # Load images - with open(fimages, 'rb') as f_h: + with open(fimages, "rb") as f_h: _, _, rows, cols = struct.unpack(">IIII", f_h.read(16)) - self.images = np.fromfile(f_h, dtype=np.uint8).reshape( - -1, rows, cols) + self.images = np.fromfile(f_h, dtype=np.uint8).reshape(-1, rows, cols) # Load labels - with open(flabels, 'rb') as f_h: + with open(flabels, "rb") as f_h: struct.unpack(">II", f_h.read(8)) self.labels = np.fromfile(f_h, dtype=np.int8) self.images = np.tile(self.images[:, :, :, np.newaxis], 3) @@ -107,37 +109,41 @@ def _op_color_img(self, image): """ # randomcolor is a flag orthogonal to num-back-both if self.flag_rand_color: - c_f = self.get_foreground_color(np.random.randint(0, self.get_num_colors())) c_b = 0 - if self.color_scheme == 'both': + if self.color_scheme == "both": count = 0 while True: - c_b = self.get_background_color(np.random.randint(0, self.get_num_colors())) + c_b = self.get_background_color( + np.random.randint(0, self.get_num_colors()) + ) if c_b != c_f and count < 10: # exit loop if background color # is not equal to foreground break else: - if self.color_scheme == 'num': + if self.color_scheme == "num": # domain and class label has perfect mutual information: # assign color # according to their class (0,10) c_f = self.get_foreground_color(self.ind_color) - c_b = np.array([0]*3) - elif self.color_scheme == 'back': # only paint background - c_f = np.array([0]*3) + c_b = np.array([0] * 3) + elif self.color_scheme == "back": # only paint background + c_f = np.array([0] * 3) c_b = self.get_background_color(self.ind_color) else: # paint both background and foreground c_f = self.get_foreground_color(self.ind_color) c_b = self.get_background_color(self.ind_color) - image[:, :, 0] = image[:, :, 0] / 255 * c_f[0] + \ - (255 - image[:, :, 0]) / 255 * c_b[0] - image[:, :, 1] = image[:, :, 1] / 255 * c_f[1] + \ - (255 - image[:, :, 1]) / 255 * c_b[1] - image[:, :, 2] = image[:, :, 2] / 255 * 
c_f[2] + \ - (255 - image[:, :, 2]) / 255 * c_b[2] + image[:, :, 0] = ( + image[:, :, 0] / 255 * c_f[0] + (255 - image[:, :, 0]) / 255 * c_b[0] + ) + image[:, :, 1] = ( + image[:, :, 1] / 255 * c_f[1] + (255 - image[:, :, 1]) / 255 * c_b[1] + ) + image[:, :, 2] = ( + image[:, :, 2] / 255 * c_f[2] + (255 - image[:, :, 2]) / 255 * c_b[2] + ) return image def _color_imgs_onehot_labels(self): @@ -153,7 +159,7 @@ def __getitem__(self, idx): label = self.labels[idx] if self.label_transform is not None: label = self.label_transform(label) - image = Image.fromarray(image) # numpy array 28*28*3 -> 3*28*28 + image = Image.fromarray(image) # numpy array 28*28*3 -> 3*28*28 if self.list_transforms is not None: for trans in self.list_transforms: image = trans(image) diff --git a/domainlab/dsets/dset_img_path_list.py b/domainlab/dsets/dset_img_path_list.py index e94453d25..bf36feff1 100644 --- a/domainlab/dsets/dset_img_path_list.py +++ b/domainlab/dsets/dset_img_path_list.py @@ -16,14 +16,16 @@ def __init__(self, root_img, path2filelist, trans_img=None, trans_target=None): self.get_list_tuple_img_label() def get_list_tuple_img_label(self): - with open(self.path2filelist, 'r') as f_h: + with open(self.path2filelist, "r") as f_h: for str_line in f_h.readlines(): path_img, label_img = str_line.strip().split() - self.list_tuple_img_label.append((path_img, int(label_img))) # @FIXME: string to int, not necessarily continuous + self.list_tuple_img_label.append( + (path_img, int(label_img)) + ) # @FIXME: string to int, not necessarily continuous def __getitem__(self, index): path_img, target = self.list_tuple_img_label[index] - target = target - 1 # @FIXME: make this more general + target = target - 1 # @FIXME: make this more general img = fun_img_path_loader_default(os.path.join(self.root_img, path_img)) if self.trans_img is not None: img = self.trans_img(img) diff --git a/domainlab/dsets/dset_mnist_color_solo_default.py b/domainlab/dsets/dset_mnist_color_solo_default.py index 
9b14781ea..bbf2276c4 100644 --- a/domainlab/dsets/dset_mnist_color_solo_default.py +++ b/domainlab/dsets/dset_mnist_color_solo_default.py @@ -11,10 +11,10 @@ def get_num_colors(self): return len(self.palette) def get_background_color(self, ind): - if self.color_scheme == 'back': + if self.color_scheme == "back": return self.palette[ind] - if self.color_scheme == 'both': - return self.palette[-(ind-3)] + if self.color_scheme == "both": + return self.palette[-(ind - 3)] # only array can be multiplied with number 255 directly return self.palette[ind] # "num" do not use background at all diff --git a/domainlab/dsets/dset_poly_domains_mnist_color_default.py b/domainlab/dsets/dset_poly_domains_mnist_color_default.py index c51cc70c3..47240055c 100644 --- a/domainlab/dsets/dset_poly_domains_mnist_color_default.py +++ b/domainlab/dsets/dset_poly_domains_mnist_color_default.py @@ -4,8 +4,7 @@ import numpy as np from torch.utils.data import Dataset -from domainlab.dsets.dset_mnist_color_solo_default import \ - DsetMNISTColorSoloDefault +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault from domainlab.dsets.utils_data import mk_fun_label2onehot @@ -13,14 +12,15 @@ class DsetMNISTColorMix(Dataset): """ merge several solo-color mnist to form a mixed dataset """ - def __init__(self, n_domains, path, color_scheme='both'): + + def __init__(self, n_domains, path, color_scheme="both"): self.n_domains = n_domains self.list_dset = [None] * n_domains self.fun_dlabel2onehot = mk_fun_label2onehot(n_domains) for domain_ind in range(n_domains): - self.list_dset[domain_ind] = \ - DsetMNISTColorSoloDefault(domain_ind, path, - color_scheme=color_scheme) + self.list_dset[domain_ind] = DsetMNISTColorSoloDefault( + domain_ind, path, color_scheme=color_scheme + ) self.list_len = [len(ds) for ds in self.list_dset] self.size_single = min(self.list_len) @@ -29,7 +29,7 @@ def __len__(self): return sum(self.list_len) def __getitem__(self, idx): - rand_domain = 
np.random.random_integers(self.n_domains-1) # @FIXME + rand_domain = np.random.random_integers(self.n_domains - 1) # @FIXME idx_local = idx % self.size_single img, c_label = self.list_dset[rand_domain][idx_local] return img, c_label, self.fun_dlabel2onehot(rand_domain) @@ -39,6 +39,7 @@ class DsetMNISTColorMixNoDomainLabel(DsetMNISTColorMix): """ DsetMNISTColorMixNoDomainLabel """ + def __getitem__(self, idx): img, c_label, _ = super().__getitem__(idx) return img, c_label diff --git a/domainlab/dsets/dset_subfolder.py b/domainlab/dsets/dset_subfolder.py index 616157c3d..0d29f18b6 100644 --- a/domainlab/dsets/dset_subfolder.py +++ b/domainlab/dsets/dset_subfolder.py @@ -8,8 +8,10 @@ from typing import Any, Tuple from torchvision.datasets import DatasetFolder + from domainlab.utils.logger import Logger + def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: """ Checks if a file is an allowed extension. @@ -38,8 +40,10 @@ def fetch_img_paths(path_dir, class_to_idx, extensions=None, is_valid_file=None) # raise ValueError( # "Both extensions and is_valid_file cannot be None or not None at the same time") if extensions is not None: + def functor_is_valid_file(filena): return has_file_allowed_extension(filena, extensions) + is_valid_file = functor_is_valid_file for target in sorted(class_to_idx.keys()): apath = os.path.join(path_dir, target) @@ -50,7 +54,7 @@ def functor_is_valid_file(filena): path_file = os.path.join(root, fname) if is_valid_file(path_file): item = (path_file, class_to_idx[target]) - list_tuple_path_cls_ind.append(item) # @FIXME + list_tuple_path_cls_ind.append(item) # @FIXME return list_tuple_path_cls_ind @@ -59,27 +63,39 @@ class DsetSubFolder(DatasetFolder): Only use user provided class names, ignore the other subfolders :param list_class_dir: list of class directories to use as classes """ - def __init__(self, root, loader, list_class_dir, extensions=None, transform=None, - target_transform=None, is_valid_file=None): + + 
def __init__( + self, + root, + loader, + list_class_dir, + extensions=None, + transform=None, + target_transform=None, + is_valid_file=None, + ): self.list_class_dir = list_class_dir if is_valid_file is not None and extensions is not None: raise ValueError( - "Both extensions and is_valid_file cannot be not None at the same time") + "Both extensions and is_valid_file cannot be not None at the same time" + ) if is_valid_file is None and extensions is None: # setting default extensions - extensions = ('jpg', 'jpeg', 'png') + extensions = ("jpg", "jpeg", "png") logger = Logger.get_logger() logger.warn("no user provided extensions, set to be jpg, jpeg, png") warnings.warn("no user provided extensions, set to be jpg, jpeg, png") - super().__init__(root, - loader, - extensions=extensions, - transform=transform, - target_transform=target_transform, - is_valid_file=is_valid_file) + super().__init__( + root, + loader, + extensions=extensions, + transform=transform, + target_transform=target_transform, + is_valid_file=is_valid_file, + ) classes, class_to_idx = self._find_classes(self.root) samples = fetch_img_paths(self.root, class_to_idx, extensions, is_valid_file) self.classes = classes @@ -121,16 +137,24 @@ def _find_classes(self, mdir): # Faster and available in Python 3.5 and above list_subfolders = [subfolder.name for subfolder in list(os.scandir(mdir))] logger.info(f"list of subfolders {list_subfolders}") - classes = [d.name for d in os.scandir(mdir) \ - if d.is_dir() and d.name in self.list_class_dir] + classes = [ + d.name + for d in os.scandir(mdir) + if d.is_dir() and d.name in self.list_class_dir + ] else: - classes = [d for d in os.listdir(mdir) \ - if os.path.isdir(os.path.join(mdir, d)) and d in self.list_class_dir] - flag_user_input_classes_in_folder = (set(self.list_class_dir) <= set(classes)) + classes = [ + d + for d in os.listdir(mdir) + if os.path.isdir(os.path.join(mdir, d)) and d in self.list_class_dir + ] + flag_user_input_classes_in_folder = 
set(self.list_class_dir) <= set(classes) if not flag_user_input_classes_in_folder: logger.info(f"user provided class names: {self.list_class_dir}") logger.info(f"subfolder names from folder: {mdir} {classes}") - raise RuntimeError("user provided class names does not match the subfolder names") + raise RuntimeError( + "user provided class names does not match the subfolder names" + ) classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} return classes, class_to_idx diff --git a/domainlab/dsets/utils_color_palette.py b/domainlab/dsets/utils_color_palette.py index 6a0f6f106..5699d687c 100644 --- a/domainlab/dsets/utils_color_palette.py +++ b/domainlab/dsets/utils_color_palette.py @@ -1,12 +1,12 @@ default_rgb_palette = [ - [31, 119, 180], - [255, 127, 14], - [44, 160, 44], - [214, 39, 40], - [148, 103, 189], - [140, 86, 75], - [227, 119, 194], - [127, 127, 127], - [188, 189, 34], - [23, 190, 207] - ] + [31, 119, 180], + [255, 127, 14], + [44, 160, 44], + [214, 39, 40], + [148, 103, 189], + [140, 86, 75], + [227, 119, 194], + [127, 127, 127], + [188, 189, 34], + [23, 190, 207], +] diff --git a/domainlab/dsets/utils_data.py b/domainlab/dsets/utils_data.py index a363e9075..319b3d5a2 100644 --- a/domainlab/dsets/utils_data.py +++ b/domainlab/dsets/utils_data.py @@ -8,6 +8,7 @@ from PIL import Image from torch.utils.data import Dataset from torchvision.utils import save_image + from domainlab.utils.logger import Logger @@ -15,7 +16,7 @@ def fun_img_path_loader_default(path): """ https://discuss.pytorch.org/t/handling-rgba-images/88428/4 """ - return Image.open(path).convert('RGB') + return Image.open(path).convert("RGB") def mk_fun_label2onehot(dim): @@ -23,12 +24,14 @@ def mk_fun_label2onehot(dim): function generator index to onehot """ + def fun_label2onehot(label): """ :param label: """ m_eye = torch.eye(dim) return m_eye[label] + return fun_label2onehot @@ -67,6 +70,7 @@ class DsetInMemDecorator(Dataset): """ fetch all items of a dataset into 
memory """ + def __init__(self, dset, name=None): """ :param dset: x, y, *d diff --git a/domainlab/dsets/utils_wrapdset_patches.py b/domainlab/dsets/utils_wrapdset_patches.py index 529b227da..79fe6d7c1 100644 --- a/domainlab/dsets/utils_wrapdset_patches.py +++ b/domainlab/dsets/utils_wrapdset_patches.py @@ -5,35 +5,41 @@ https://github.com/fmcarlucci/JigenDG/blob/master/data/JigsawLoader.py """ import os + import numpy as np import torch import torchvision -from torchvision import transforms from torch.utils import data as torchdata +from torchvision import transforms class WrapDsetPatches(torchdata.Dataset): """ given dataset of images, return permuations of tiles of images re-weaved """ - def __init__(self, dataset, - num_perms2classify, - prob_no_perm, - grid_len, - ppath=None, - flag_do_not_weave_tiles=False): + + def __init__( + self, + dataset, + num_perms2classify, + prob_no_perm, + grid_len, + ppath=None, + flag_do_not_weave_tiles=False, + ): """ :param prob_no_perm: probability of no permutation: permutation will change the image, so the class label classifier will behave very differently compared to no permutation """ if ppath is None and grid_len != 3: - raise RuntimeError("please provide npy file of numpy array with each row \ + raise RuntimeError( + "please provide npy file of numpy array with each row \ being a permutation of the number of tiles, currently \ - we only support grid length 3") + we only support grid length 3" + ) self.dataset = dataset self._to_tensor = transforms.Compose([transforms.ToTensor()]) - self.arr1perm_per_row = self.__retrieve_permutations( - num_perms2classify, ppath) + self.arr1perm_per_row = self.__retrieve_permutations(num_perms2classify, ppath) # for 3*3 tiles, there are 9*8*7*6*5*...*1 >> 100, # we load from disk instead only 100 permutations # each row of the loaded array is a permutation of the 3*3 tile @@ -44,12 +50,13 @@ def __init__(self, dataset, if flag_do_not_weave_tiles: self.fun_weave_imgs = lambda x: x else: 
+ def make_grid(img): """ sew tiles together to be an image """ - return torchvision.utils.make_grid( - img, nrow=self.grid_len, padding=0) + return torchvision.utils.make_grid(img, nrow=self.grid_len, padding=0) + self.fun_weave_imgs = make_grid def get_tile(self, img, ind_tile): @@ -70,10 +77,14 @@ def get_tile(self, img, ind_tile): img_pil = functor_tr(img) # PIL.crop((left, top, right, bottom)) # get rectangular region from box of [left, upper, right, lower] - tile = img_pil.crop([ind_horizontal * num_tiles, - ind_vertical * num_tiles, - (ind_horizontal + 1) * num_tiles, - (ind_vertical + 1) * num_tiles]) + tile = img_pil.crop( + [ + ind_horizontal * num_tiles, + ind_vertical * num_tiles, + (ind_horizontal + 1) * num_tiles, + (ind_vertical + 1) * num_tiles, + ] + ) tile = self._to_tensor(tile) return tile @@ -86,14 +97,13 @@ def __getitem__(self, index): dlabel = domain[0] else: dlabel = None - num_grids = self.grid_len ** 2 + num_grids = self.grid_len**2 # divide image into grid_len^2 tiles list_tiles = [None] * num_grids # list of length num_grids of image tiles for ind_tile in range(num_grids): - list_tiles[ind_tile] = self.get_tile(img, ind_tile) # populate tile list - ind_which_perm = np.random.randint( - self.arr1perm_per_row.shape[0] + 1) + list_tiles[ind_tile] = self.get_tile(img, ind_tile) # populate tile list + ind_which_perm = np.random.randint(self.arr1perm_per_row.shape[0] + 1) # +1 in line above is for when image is not permutated, which # also need to be classified corrected by the permutation classifier # let len(self.arr1perm_per_row)=31 @@ -113,18 +123,20 @@ def __getitem__(self, index): list_reordered_tiles = None if ind_which_perm == 0: list_reordered_tiles = list_tiles # no permutation of images - else: # default + else: # default perm_chosen = self.arr1perm_per_row[ind_which_perm - 1] - list_reordered_tiles = [list_tiles[perm_chosen[ind_tile]] - for ind_tile in range(num_grids)] + list_reordered_tiles = [ + 
list_tiles[perm_chosen[ind_tile]] for ind_tile in range(num_grids) + ] stacked_tiles = torch.stack(list_reordered_tiles, 0) # NOTE: stacked_tiles will be [9, 3, 30, 30], which will be weaved to # be a whole image again by self.fun_weave_imgs # NOTE: ind_which_perm = 0 means no permutation, the classifier need to # judge if the image has not been permutated as well re_tiled_img = self.fun_weave_imgs(stacked_tiles) - img_re_tiled_re_shaped = \ - torchvision.transforms.RandomResizedCrop(original_size)(re_tiled_img) + img_re_tiled_re_shaped = torchvision.transforms.RandomResizedCrop( + original_size + )(re_tiled_img) return img_re_tiled_re_shaped, label, dlabel, int(ind_which_perm) # ind_which_perm is the ground truth for the permutation index @@ -140,7 +152,7 @@ def __retrieve_permutations(self, num_perms_as_classes, ppath=None): # @FIXME: this assumes always a relative path mdir = os.path.dirname(os.path.realpath(__file__)) if ppath is None: - ppath = f'zdata/patches_permutation4jigsaw/permutations_{num_perms_as_classes}.npy' + ppath = f"zdata/patches_permutation4jigsaw/permutations_{num_perms_as_classes}.npy" mpath = os.path.join(mdir, "..", ppath) arr_permutation_rows = np.load(mpath) # from range [1,9] to [0,8] since python array start with 0 diff --git a/domainlab/exp/exp_main.py b/domainlab/exp/exp_main.py index 147400083..51397131b 100755 --- a/domainlab/exp/exp_main.py +++ b/domainlab/exp/exp_main.py @@ -5,20 +5,20 @@ import os import warnings - from domainlab.algos.zoo_algos import AlgoBuilderChainNodeGetter from domainlab.exp.exp_utils import AggWriter from domainlab.tasks.zoo_tasks import TaskChainNodeGetter -from domainlab.utils.sanity_check import SanityCheck from domainlab.utils.logger import Logger +from domainlab.utils.sanity_check import SanityCheck -os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # debug +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # debug -class Exp(): +class Exp: """ Exp is combination of Task, Algorithm, and Configuration (including random 
seed) """ + def __init__(self, args, task=None, model=None, observer=None, visitor=AggWriter): """ :param args: @@ -31,11 +31,15 @@ def __init__(self, args, task=None, model=None, observer=None, visitor=AggWriter self.task = TaskChainNodeGetter(args)() self.args = args - algo_builder = AlgoBuilderChainNodeGetter(self.args.model, self.args.apath)() # request + algo_builder = AlgoBuilderChainNodeGetter( + self.args.model, self.args.apath + )() # request # the critical logic below is to avoid circular dependence between task initialization # and trainer initialization: - self.trainer, self.model, observer_default, device = algo_builder.init_business(self) + self.trainer, self.model, observer_default, device = algo_builder.init_business( + self + ) # sanity check has to be done after init_business # jigen algorithm builder has method dset_decoration_args_algo, which could AOP # into the task intilization process @@ -64,18 +68,20 @@ def execute(self, num_epochs=None): num_epochs = self.epochs + 1 t_0 = datetime.datetime.now() logger = Logger.get_logger() - logger.info(f'\n Experiment start at: {str(t_0)}') + logger.info(f"\n Experiment start at: {str(t_0)}") t_c = t_0 self.trainer.before_tr() for epoch in range(1, num_epochs): t_before_epoch = t_c flag_stop = self.trainer.tr_epoch(epoch) t_c = datetime.datetime.now() - logger.info(f"after epoch: {epoch}," - f"now: {str(t_c)}," - f"epoch time: {t_c - t_before_epoch}," - f"used: {t_c - t_0}," - f"model: {self.visitor.model_name}") + logger.info( + f"after epoch: {epoch}," + f"now: {str(t_c)}," + f"epoch time: {t_c - t_before_epoch}," + f"used: {t_c - t_0}," + f"model: {self.visitor.model_name}" + ) logger.info(f"working direcotry: {self.curr_dir}") # current time, time since experiment start, epoch time if flag_stop: @@ -86,8 +92,10 @@ def execute(self, num_epochs=None): self.epoch_counter = self.epochs else: self.epoch_counter += 1 - logger.info(f"Experiment finished at epoch: {self.epoch_counter} " - f"with time: {t_c - 
t_0} at {t_c}") + logger.info( + f"Experiment finished at epoch: {self.epoch_counter} " + f"with time: {t_c - t_0} at {t_c}" + ) self.trainer.post_tr() def clean_up(self): diff --git a/domainlab/exp/exp_utils.py b/domainlab/exp/exp_utils.py index 0782a9357..2af681731 100644 --- a/domainlab/exp/exp_utils.py +++ b/domainlab/exp/exp_utils.py @@ -5,9 +5,9 @@ import copy import datetime import os -import numpy as np from pathlib import Path +import numpy as np import torch from sklearn.metrics import ConfusionMatrixDisplay @@ -15,10 +15,11 @@ from domainlab.utils.logger import Logger -class ExpModelPersistVisitor(): +class ExpModelPersistVisitor: """ This class couples with Task class attributes """ + model_dir = "saved_models" model_suffix = ".model" @@ -31,17 +32,15 @@ def __init__(self, host): """ self.host = host self.out = host.args.out - self.model_dir = os.path.join(self.out, - ExpModelPersistVisitor.model_dir) + self.model_dir = os.path.join(self.out, ExpModelPersistVisitor.model_dir) self.git_tag = get_git_tag() - self.task_name = self.host.task.get_na(self.host.args.tr_d, - self.host.args.te_d) + self.task_name = self.host.task.get_na(self.host.args.tr_d, self.host.args.te_d) self.algo_name = self.host.args.model self.seed = self.host.args.seed self.model_name = self.mk_model_na(self.git_tag) - self.model_path = os.path.join(self.model_dir, - self.model_name + - ExpModelPersistVisitor.model_suffix) + self.model_path = os.path.join( + self.model_dir, self.model_name + ExpModelPersistVisitor.model_suffix + ) Path(os.path.dirname(self.model_path)).mkdir(parents=True, exist_ok=True) self.model = copy.deepcopy(self.host.trainer.model) @@ -59,19 +58,22 @@ def mk_model_na(self, tag=None, dd_cut=19): suffix_t = str(datetime.datetime.now())[:dd_cut].replace(" ", "_") suffix_t = suffix_t.replace("-", "md_") suffix_t = suffix_t.replace(":", "_") - list4mname = [self.task_name, - self.algo_name, - tag, suffix_t, - "seed", - str(self.seed)] + list4mname = [ + 
self.task_name, + self.algo_name, + tag, + suffix_t, + "seed", + str(self.seed), + ] # the sequence of components (e.g. seed in the last place) # in model name is not crutial model_name = "_".join(list4mname) if self.host.args.debug: model_name = "debug_" + model_name - slurm = os.environ.get('SLURM_JOB_ID') + slurm = os.environ.get("SLURM_JOB_ID") if slurm: - model_name = model_name + '_' + slurm + model_name = model_name + "_" + slurm logger = Logger.get_logger() logger.info(f"model name: {model_name}") return model_name @@ -128,6 +130,7 @@ class AggWriter(ExpModelPersistVisitor): 1. aggregate results to text file. 2. all dependencies are in the constructor """ + def __init__(self, host): super().__init__(host) self.agg_tag = self.host.args.aggtag @@ -159,13 +162,14 @@ def get_cols(self): """ epos_name = "epos" dict_cols = { - "algo": self.algo_name, - epos_name: None, - "seed": self.seed, - "aggtag": self.agg_tag, - # algorithm configuration for instance - "mname": "mname_" + self.model_name, - "commit": "commit_" + self.git_tag} + "algo": self.algo_name, + epos_name: None, + "seed": self.seed, + "aggtag": self.agg_tag, + # algorithm configuration for instance + "mname": "mname_" + self.model_name, + "commit": "commit_" + self.git_tag, + } return dict_cols, epos_name def _gen_line(self, dict_metric): @@ -186,9 +190,10 @@ def get_fpath(self, dirname="aggrsts"): for writing and reading, the same function is called to ensure name change in the future will not break the software """ - list4fname = [self.task_name, - self.exp_tag, - ] + list4fname = [ + self.task_name, + self.exp_tag, + ] fname = "_".join(list4fname) + ".csv" if self.debug: fname = "_".join(["debug_agg", fname]) @@ -203,7 +208,7 @@ def to_file(self, str_line): Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) logger = Logger.get_logger() logger.info(f"results aggregation path: {file_path}") - with open(file_path, 'a', encoding="utf8") as f_h: + with open(file_path, "a", 
encoding="utf8") as f_h: print(str_line, file=f_h) def confmat_to_file(self, confmat, confmat_filename): @@ -222,7 +227,9 @@ def confmat_to_file(self, confmat, confmat_filename): # @FIXME: still we want to have mname_ as a variable defined in some # configuration file in the future. confmat_filename = confmat_filename.removeprefix("mname_") - file_path = os.path.join(os.path.dirname(file_path), f"{confmat_filename}_conf_mat.png") + file_path = os.path.join( + os.path.dirname(file_path), f"{confmat_filename}_conf_mat.png" + ) logger = Logger.get_logger() logger.info(f"confusion matrix saved in file: {file_path}") disp.figure_.savefig(file_path) @@ -233,6 +240,7 @@ class ExpProtocolAggWriter(AggWriter): AggWriter tailored to experimental protocol Output contains additionally index, exp task, te_d and params. """ + def get_cols(self): """columns""" epos_name = "epos" @@ -245,7 +253,7 @@ def get_cols(self): epos_name: None, "te_d": self.host.args.te_d, "seed": self.seed, - "params": f"\"{self.host.args.params}\"", + "params": f'"{self.host.args.params}"', } return dict_cols, epos_name @@ -268,11 +276,10 @@ def confmat_to_file(self, confmat, confmat_filename): confmat_filename = confmat_filename.removeprefix("mname_") path4file = os.path.join(path4file, "confusion_matrix") os.makedirs(path4file, exist_ok=True) - file_path = os.path.join(path4file, - f"{index}.txt") - with open(file_path, 'a', encoding="utf8") as f_h: + file_path = os.path.join(path4file, f"{index}.txt") + with open(file_path, "a", encoding="utf8") as f_h: print(confmat_filename, file=f_h) for line in np.matrix(confmat): - np.savetxt(f_h, line, fmt='%.2f') + np.savetxt(f_h, line, fmt="%.2f") logger = Logger.get_logger() logger.info(f"confusion matrix saved in file: {file_path}") diff --git a/domainlab/exp_protocol/aggregate_results.py b/domainlab/exp_protocol/aggregate_results.py index a084f3651..d67cf5a0a 100644 --- a/domainlab/exp_protocol/aggregate_results.py +++ 
b/domainlab/exp_protocol/aggregate_results.py @@ -22,9 +22,9 @@ def agg_results(input_files: List[str], output_file: str): has_header = False # logger = Logger.get_logger() # logger.debug(f"exp_results={input.exp_results}") - with open(output_file, 'w') as out_stream: + with open(output_file, "w") as out_stream: for res in input_files: - with open(res, 'r') as in_stream: + with open(res, "r") as in_stream: if has_header: # skip header line in_stream.readline() @@ -46,4 +46,4 @@ def agg_main(bm_dir: str, skip_plotting: bool = False): agg_output = f"{bm_dir}/results.csv" agg_input = f"{bm_dir}/rule_results" agg_from_directory(agg_input, agg_output) - gen_benchmark_plots(agg_output, f'{bm_dir}/graphics', skip_plotting) + gen_benchmark_plots(agg_output, f"{bm_dir}/graphics", skip_plotting) diff --git a/domainlab/exp_protocol/run_experiment.py b/domainlab/exp_protocol/run_experiment.py index 45144104b..81cdd091f 100644 --- a/domainlab/exp_protocol/run_experiment.py +++ b/domainlab/exp_protocol/run_experiment.py @@ -3,18 +3,19 @@ and each random seed. 
""" import ast -import gc import copy +import gc + import numpy as np import pandas as pd import torch -from domainlab.arg_parser import mk_parser_main, apply_dict_to_args +from domainlab.arg_parser import apply_dict_to_args, mk_parser_main from domainlab.exp.exp_cuda_seed import set_seed from domainlab.exp.exp_main import Exp from domainlab.exp.exp_utils import ExpProtocolAggWriter -from domainlab.utils.logger import Logger from domainlab.utils.hyperparameter_sampling import G_METHOD_NA +from domainlab.utils.logger import Logger def load_parameters(file: str, index: int) -> tuple: @@ -47,13 +48,13 @@ def convert_dict2float(dict_in): def run_experiment( - config: dict, - param_file: str, - param_index: int, - out_file: str, - start_seed=None, - misc=None, - num_gpus=1 + config: dict, + param_file: str, + param_index: int, + out_file: str, + start_seed=None, + misc=None, + num_gpus=1, ): """ Runs the experiment several times: @@ -78,40 +79,62 @@ def run_experiment( misc = {} str_algo_as_task, hyperparameters = load_parameters(param_file, param_index) logger = Logger.get_logger() - logger.debug("\n*******************************************************************") - logger.debug(f"{str_algo_as_task}, param_index={param_index}, params={hyperparameters}") - logger.debug("*******************************************************************\n") - misc['result_file'] = out_file - misc['params'] = hyperparameters - misc['benchmark_task_name'] = str_algo_as_task - misc['param_index'] = param_index - misc['keep_model'] = False + logger.debug( + "\n*******************************************************************" + ) + logger.debug( + f"{str_algo_as_task}, param_index={param_index}, params={hyperparameters}" + ) + logger.debug( + "*******************************************************************\n" + ) + misc["result_file"] = out_file + misc["params"] = hyperparameters + misc["benchmark_task_name"] = str_algo_as_task + misc["param_index"] = param_index + 
misc["keep_model"] = False parser = mk_parser_main() args = parser.parse_args(args=[]) args_algo_specific = config[str_algo_as_task].copy() - if 'hyperparameters' in args_algo_specific: - del args_algo_specific['hyperparameters'] + if "hyperparameters" in args_algo_specific: + del args_algo_specific["hyperparameters"] args_domainlab_common_raw = config.get("domainlab_args", {}) args_domainlab_common = convert_dict2float(args_domainlab_common_raw) # check if some of the hyperparameters are already specified # in args_domainlab_common or args_algo_specific - if np.intersect1d(list(args_algo_specific.keys()), - list(hyperparameters.keys())).shape[0] > 0: - logger.error(f"the hyperparameter " - f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" - f" has already been fixed to a value in the algorithm section.") - raise RuntimeError(f"the hyperparameter " - f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" - f" has already been fixed to a value in the algorithm section.") - if np.intersect1d(list(args_domainlab_common.keys()), - list(hyperparameters.keys())).shape[0] > 0: - logger.error(f"the hyperparameter " - f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" - f" has already been fixed to a value in the domainlab_args section.") - raise RuntimeError(f"the hyperparameter " - f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" - f" has already been fixed to a value in the domainlab_args section.") + if ( + np.intersect1d( + list(args_algo_specific.keys()), list(hyperparameters.keys()) + ).shape[0] + > 0 + ): + logger.error( + f"the hyperparameter " + f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" + f" has already been fixed to a value in the algorithm section." 
+ ) + raise RuntimeError( + f"the hyperparameter " + f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" + f" has already been fixed to a value in the algorithm section." + ) + if ( + np.intersect1d( + list(args_domainlab_common.keys()), list(hyperparameters.keys()) + ).shape[0] + > 0 + ): + logger.error( + f"the hyperparameter " + f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" + f" has already been fixed to a value in the domainlab_args section." + ) + raise RuntimeError( + f"the hyperparameter " + f"{np.intersect1d(list(args_algo_specific.keys()), list(hyperparameters.keys()))}" + f" has already been fixed to a value in the domainlab_args section." + ) apply_dict_to_args(args, args_domainlab_common) args_algo_specific_scientific_notation = convert_dict2float(args_algo_specific) apply_dict_to_args(args, args_algo_specific_scientific_notation, extend=True) @@ -125,12 +148,12 @@ def run_experiment( logger.info("before experiment loop: ") logger.info(str(torch.cuda.memory_summary())) if start_seed is None: - start_seed = config['startseed'] - end_seed = config['endseed'] + start_seed = config["startseed"] + end_seed = config["endseed"] else: - end_seed = start_seed + (config['endseed'] - config['startseed']) + end_seed = start_seed + (config["endseed"] - config["startseed"]) for seed in range(start_seed, end_seed + 1): - for te_d in config['test_domains']: + for te_d in config["test_domains"]: args.te_d = te_d set_seed(seed) args.seed = seed @@ -145,7 +168,7 @@ def run_experiment( exp = Exp(args=args, visitor=ExpProtocolAggWriter) # NOTE: if key "testing" is set in benchmark, then do not execute # experiment - if not misc.get('testing', False): + if not misc.get("testing", False): exp.execute() try: if torch.cuda.is_available(): diff --git a/domainlab/mk_exp.py b/domainlab/mk_exp.py index 086af3937..6123df8c4 100644 --- a/domainlab/mk_exp.py +++ b/domainlab/mk_exp.py @@ -22,7 +22,9 @@ def mk_exp(task, 
model, trainer: str, test_domain: str, batchsize: int, nocu=Fal Returns: experiment """ - str_arg = f"--model=apimodel --trainer={trainer} --te_d={test_domain} --bs={batchsize}" + str_arg = ( + f"--model=apimodel --trainer={trainer} --te_d={test_domain} --bs={batchsize}" + ) if nocu: str_arg += " --nocu" parser = mk_parser_main() diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index fe249bd12..beb867167 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -5,6 +5,7 @@ import abc from torch import nn + from domainlab import g_list_model_penalized_reg_agg @@ -12,6 +13,7 @@ class AModel(nn.Module, metaclass=abc.ABCMeta): """ operations that all models (classification, segmentation, seq2seq) """ + def __init__(self): super().__init__() self._decoratee = None @@ -44,7 +46,9 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None): """ calculate the loss """ - list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + list_loss, list_multiplier = self.cal_reg_loss( + tensor_x, tensor_y, tensor_d, others + ) loss_reg = self.list_inner_product(list_loss, list_multiplier) loss_task_alone = self.cal_task_loss(tensor_x, tensor_y) loss_task = self.multiplier4task_loss * loss_task_alone @@ -59,7 +63,7 @@ def list_inner_product(self, list_loss, list_multiplier): here only aggregate along the list """ list_tuple = zip(list_loss, list_multiplier) - list_penalized_reg = [mtuple[0]*mtuple[1] for mtuple in list_tuple] + list_penalized_reg = [mtuple[0] * mtuple[1] for mtuple in list_tuple] tensor_batch_penalized_loss = g_list_model_penalized_reg_agg(list_penalized_reg) # return value of list_inner_product should keep the minibatch structure, thus aggregation # here only aggregate along the list @@ -85,10 +89,8 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ task independent regularization loss for domain generalization """ - loss_reg, mu = self._extend_loss( - tensor_x, 
tensor_y, tensor_d, others) - loss_reg_, mu_ = self._cal_reg_loss( - tensor_x, tensor_y, tensor_d, others) + loss_reg, mu = self._extend_loss(tensor_x, tensor_y, tensor_d, others) + loss_reg_, mu_ = self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others) if loss_reg is not None: return loss_reg_ + loss_reg, mu_ + mu return loss_reg_, mu_ @@ -98,8 +100,7 @@ def _extend_loss(self, tensor_x, tensor_y, tensor_d, others=None): combine losses from two models """ if self._decoratee is not None: - return self._decoratee.cal_reg_loss( - tensor_x, tensor_y, tensor_d, others) + return self._decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) return None, None def forward(self, tensor_x, tensor_y, tensor_d, others=None): @@ -115,8 +116,8 @@ def extract_semantic_feat(self, tensor_x): """ extract semantic feature (not domain feature), note that extract semantic feature is an action, it is more general than - calling a static network(module)'s forward function since - there are extra action like reshape the tensor + calling a static network(module)'s forward function since + there are extra action like reshape the tensor """ if self._decoratee is not None: return self._decoratee.extract_semantic_feat(tensor_x) diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index c378569e9..470f0e6d8 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -4,6 +4,7 @@ import abc import math + import numpy as np import pandas as pd import torch @@ -11,17 +12,18 @@ from torch.nn import functional as F from domainlab.models.a_model import AModel -from domainlab.utils.utils_class import store_args -from domainlab.utils.utils_classif import get_label_na, logit2preds_vpic +from domainlab.utils.logger import Logger from domainlab.utils.perf import PerfClassif from domainlab.utils.perf_metrics import PerfMetricClassif -from domainlab.utils.logger import Logger +from domainlab.utils.utils_class import store_args +from 
domainlab.utils.utils_classif import get_label_na, logit2preds_vpic class AModelClassif(AModel, metaclass=abc.ABCMeta): """ operations that all classification model should have """ + match_feat_fun_na = "cal_logit_y" def extend(self, model): @@ -143,9 +145,7 @@ def cal_task_loss(self, tensor_x, tensor_y): # cross entropy always return a scalar, no need for inside instance reduction return lc_y - def pred2file(self, loader_te, device, filename, - metric_te, - spliter="#"): + def pred2file(self, loader_te, device, filename, metric_te, spliter="#"): """ pred2file dump predicted label to file as sanity check """ @@ -157,31 +157,40 @@ def pred2file(self, loader_te, device, filename, _, prob, *_ = model_local.infer_y_vpicn(x_s) list_pred_prob_list = prob.tolist() list_target_list = y_s.tolist() - list_target_scalar = [np.asarray(label).argmax() for label in list_target_list] + list_target_scalar = [ + np.asarray(label).argmax() for label in list_target_list + ] tuple_zip = zip(path4instance, list_target_scalar, list_pred_prob_list) list_pair_path_pred = list(tuple_zip) - with open(filename, 'a', encoding="utf8") as handle_file: + with open(filename, "a", encoding="utf8") as handle_file: for list4one_obs_path_prob_target in list_pair_path_pred: list_str_one_obs_path_target_predprob = [ - str(ele) for ele in list4one_obs_path_prob_target] - str_line = (" "+spliter+" ").join(list_str_one_obs_path_target_predprob) + str(ele) for ele in list4one_obs_path_prob_target + ] + str_line = (" " + spliter + " ").join( + list_str_one_obs_path_target_predprob + ) str_line = str_line.replace("[", "") str_line = str_line.replace("]", "") print(str_line, file=handle_file) logger.info(f"prediction saved in file {filename}") file_acc = self.read_prediction_file(filename, spliter) - acc_metric_te = metric_te['acc'] + acc_metric_te = metric_te["acc"] flag1 = math.isclose(file_acc, acc_metric_te, rel_tol=1e-9, abs_tol=0.01) acc_raw1 = PerfClassif.cal_acc(self, loader_te, device) acc_raw2 = 
PerfClassif.cal_acc(self, loader_te, device) - flag_raw_consistency = math.isclose(acc_raw1, acc_raw2, rel_tol=1e-9, abs_tol=0.01) + flag_raw_consistency = math.isclose( + acc_raw1, acc_raw2, rel_tol=1e-9, abs_tol=0.01 + ) flag2 = math.isclose(file_acc, acc_raw1, rel_tol=1e-9, abs_tol=0.01) if not (flag1 & flag2 & flag_raw_consistency): - str_info = f"inconsistent acc: \n" \ - f"prediction file acc generated using the current model is {file_acc} \n" \ - f"input torchmetric acc to the current function: {acc_metric_te} \n" \ - f"raw acc 1 {acc_raw1} \n" \ - f"raw acc 2 {acc_raw2} \n" + str_info = ( + f"inconsistent acc: \n" + f"prediction file acc generated using the current model is {file_acc} \n" + f"input torchmetric acc to the current function: {acc_metric_te} \n" + f"raw acc 1 {acc_raw1} \n" + f"raw acc 2 {acc_raw2} \n" + ) raise RuntimeError(str_info) return file_acc @@ -189,7 +198,7 @@ def read_prediction_file(self, filename, spliter): """ check if the written fiel could calculate acc """ - with open(filename, 'r', encoding="utf8") as handle_file: + with open(filename, "r", encoding="utf8") as handle_file: list_lines = [line.strip().split(spliter) for line in handle_file] count_correct = 0 for line in list_lines: diff --git a/domainlab/models/args_jigen.py b/domainlab/models/args_jigen.py index c72083e5b..91c2c90ae 100644 --- a/domainlab/models/args_jigen.py +++ b/domainlab/models/args_jigen.py @@ -1,19 +1,29 @@ """ hyper-parameters for JiGen """ + + def add_args2parser_jigen(parser): """ hyper-parameters for JiGen """ - parser.add_argument('--nperm', type=int, default=31, - help='number of permutations') - parser.add_argument('--pperm', type=float, default=0.1, - help='probability of permutating the tiles \ - of an image') - parser.add_argument('--jigen_ppath', type=str, default=None, - help='npy file path to load numpy array with each row being \ + parser.add_argument("--nperm", type=int, default=31, help="number of permutations") + parser.add_argument( + 
"--pperm", + type=float, + default=0.1, + help="probability of permutating the tiles \ + of an image", + ) + parser.add_argument( + "--jigen_ppath", + type=str, + default=None, + help="npy file path to load numpy array with each row being \ permutation index, if not None, nperm and grid_len has to agree \ - with the number of row and columns of the input array') - parser.add_argument('--grid_len', type=int, default=3, - help='length of image in tile unit') + with the number of row and columns of the input array", + ) + parser.add_argument( + "--grid_len", type=int, default=3, help="length of image in tile unit" + ) return parser diff --git a/domainlab/models/args_vae.py b/domainlab/models/args_vae.py index fecd0bd8c..4d56f96b5 100644 --- a/domainlab/models/args_vae.py +++ b/domainlab/models/args_vae.py @@ -1,44 +1,76 @@ def add_args2parser_vae(parser): - parser.add_argument('--zd_dim', type=int, default=64, - help='diva: size of latent space for domain') - parser.add_argument('--zx_dim', type=int, default=0, - help='diva: size of latent space for unobserved') - parser.add_argument('--zy_dim', type=int, default=64, - help='diva, hduva: size of latent space for class') + parser.add_argument( + "--zd_dim", type=int, default=64, help="diva: size of latent space for domain" + ) + parser.add_argument( + "--zx_dim", + type=int, + default=0, + help="diva: size of latent space for unobserved", + ) + parser.add_argument( + "--zy_dim", + type=int, + default=64, + help="diva, hduva: size of latent space for class", + ) # HDUVA - parser.add_argument('--topic_dim', type=int, default=3, - help='hduva: number of topics') + parser.add_argument( + "--topic_dim", type=int, default=3, help="hduva: number of topics" + ) - parser.add_argument('--nname_encoder_x2topic_h', - type=str, default=None, - help='hduva: network from image to topic distribution') + parser.add_argument( + "--nname_encoder_x2topic_h", + type=str, + default=None, + help="hduva: network from image to topic 
distribution", + ) - parser.add_argument('--npath_encoder_x2topic_h', - type=str, default=None, - help='hduva: network from image to topic distribution') + parser.add_argument( + "--npath_encoder_x2topic_h", + type=str, + default=None, + help="hduva: network from image to topic distribution", + ) - parser.add_argument('--nname_encoder_sandwich_x2h4zd', - type=str, default=None, - help='hduva: network from image and topic to zd') - parser.add_argument('--npath_encoder_sandwich_x2h4zd', - type=str, default=None, - help='hduva: network from image and topic to zd') + parser.add_argument( + "--nname_encoder_sandwich_x2h4zd", + type=str, + default=None, + help="hduva: network from image and topic to zd", + ) + parser.add_argument( + "--npath_encoder_sandwich_x2h4zd", + type=str, + default=None, + help="hduva: network from image and topic to zd", + ) # ERM, ELBO - parser.add_argument('--gamma_y', type=float, default=None, - help='diva, hduva: multiplier for y classifier') - parser.add_argument('--gamma_d', type=float, default=None, - help='diva: multiplier for d classifier from zd') - - + parser.add_argument( + "--gamma_y", + type=float, + default=None, + help="diva, hduva: multiplier for y classifier", + ) + parser.add_argument( + "--gamma_d", + type=float, + default=None, + help="diva: multiplier for d classifier from zd", + ) # Beta VAE part - parser.add_argument('--beta_t', type=float, default=1., - help='hduva: multiplier for KL topic') - parser.add_argument('--beta_d', type=float, default=1., - help='diva: multiplier for KL d') - parser.add_argument('--beta_x', type=float, default=1., - help='diva: multiplier for KL x') - parser.add_argument('--beta_y', type=float, default=1., - help='diva, hduva: multiplier for KL y') + parser.add_argument( + "--beta_t", type=float, default=1.0, help="hduva: multiplier for KL topic" + ) + parser.add_argument( + "--beta_d", type=float, default=1.0, help="diva: multiplier for KL d" + ) + parser.add_argument( + "--beta_x", type=float, 
default=1.0, help="diva: multiplier for KL x" + ) + parser.add_argument( + "--beta_y", type=float, default=1.0, help="diva, hduva: multiplier for KL y" + ) return parser diff --git a/domainlab/models/interface_vae_xyd.py b/domainlab/models/interface_vae_xyd.py index adf101b54..2a587b677 100644 --- a/domainlab/models/interface_vae_xyd.py +++ b/domainlab/models/interface_vae_xyd.py @@ -7,13 +7,13 @@ from domainlab.utils.utils_class import store_args -class InterfaceVAEXYD(): +class InterfaceVAEXYD: """ Interface (without constructor and inheritance) for XYD VAE """ + def init(self): - self.chain_node_builder.init_business( - self.zd_dim, self.zx_dim, self.zy_dim) + self.chain_node_builder.init_business(self.zd_dim, self.zx_dim, self.zy_dim) self.i_c = self.chain_node_builder.i_c self.i_h = self.chain_node_builder.i_h self.i_w = self.chain_node_builder.i_w @@ -27,9 +27,10 @@ def _init_components(self): """ self.add_module("encoder", self.chain_node_builder.build_encoder()) self.add_module("decoder", self.chain_node_builder.build_decoder()) - self.add_module("net_p_zy", - self.chain_node_builder.construct_cond_prior( - self.dim_y, self.zy_dim)) + self.add_module( + "net_p_zy", + self.chain_node_builder.construct_cond_prior(self.dim_y, self.zy_dim), + ) def init_p_zx4batch(self, batch_size, device): """ diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py index 58f54cab2..b487cd586 100644 --- a/domainlab/models/model_dann.py +++ b/domainlab/models/model_dann.py @@ -48,8 +48,17 @@ class ModelDAN(parent_class): """ anonymous """ - def __init__(self, list_str_y, list_d_tr, - alpha, net_encoder, net_classifier, net_discriminator, builder=None): + + def __init__( + self, + list_str_y, + list_d_tr, + alpha, + net_encoder, + net_classifier, + net_discriminator, + builder=None, + ): """ See documentation above in mk_dann() function """ @@ -67,14 +76,18 @@ def reset_aux_net(self): """ if self.builder is None: return - self.net_discriminator = 
self.builder.reset_aux_net(self.extract_semantic_feat) + self.net_discriminator = self.builder.reset_aux_net( + self.extract_semantic_feat + ) def hyper_update(self, epoch, fun_scheduler): """hyper_update. :param epoch: :param fun_scheduler: the hyperparameter scheduler object """ - dict_rst = fun_scheduler(epoch) # the __call__ method of hyperparameter scheduler + dict_rst = fun_scheduler( + epoch + ) # the __call__ method of hyperparameter scheduler self.alpha = dict_rst["alpha"] def hyper_init(self, functor_scheduler): @@ -87,9 +100,12 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others): _ = others _ = tensor_y feat = self.extract_semantic_feat(tensor_x) - net_grad_additive_reverse = AutoGradFunReverseMultiply.apply(feat, self.alpha) + net_grad_additive_reverse = AutoGradFunReverseMultiply.apply( + feat, self.alpha + ) logit_d = self.net_discriminator(net_grad_additive_reverse) _, d_target = tensor_d.max(dim=1) lc_d = F.cross_entropy(logit_d, d_target, reduction=g_str_cross_entropy_agg) return [lc_d], [self.alpha] + return ModelDAN diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py index 31be30966..9ccdad896 100644 --- a/domainlab/models/model_diva.py +++ b/domainlab/models/model_diva.py @@ -55,29 +55,42 @@ class ModelDIVA(parent_class): """ DIVA """ + @store_args - def __init__(self, chain_node_builder, - zd_dim, zy_dim, zx_dim, - list_str_y, list_d_tr, - gamma_d, gamma_y, - beta_d, beta_x, beta_y, multiplier_recon=1.0): + def __init__( + self, + chain_node_builder, + zd_dim, + zy_dim, + zx_dim, + list_str_y, + list_d_tr, + gamma_d, + gamma_y, + beta_d, + beta_x, + beta_y, + multiplier_recon=1.0, + ): """ gamma: classification loss coefficient """ - super().__init__(chain_node_builder, - zd_dim, zy_dim, zx_dim, - list_str_y) + super().__init__(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y) self.list_d_tr = list_d_tr self.dim_d_tr = len(self.list_d_tr) if self.zd_dim > 0: self.add_module( "net_p_zd", 
self.chain_node_builder.construct_cond_prior( - self.dim_d_tr, self.zd_dim)) + self.dim_d_tr, self.zd_dim + ), + ) self.add_module( "net_classif_d", self.chain_node_builder.construct_classifier( - self.zd_dim, self.dim_d_tr)) + self.zd_dim, self.dim_d_tr + ), + ) def hyper_update(self, epoch, fun_scheduler): """hyper_update. @@ -97,8 +110,8 @@ def hyper_init(self, functor_scheduler): :param functor_scheduler: the class name of the scheduler """ return functor_scheduler( - trainer=None, - beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x) + trainer=None, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x + ) def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): q_zd, zd_q, q_zx, zx_q, q_zy, zy_q = self.encoder(tensor_x) @@ -115,21 +128,36 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): loss_recon_x, _, _ = self.decoder(z_concat, tensor_x) zd_p_minus_zd_q = g_inst_component_loss_agg( - p_zd.log_prob(zd_q) - q_zd.log_prob(zd_q), 1) + p_zd.log_prob(zd_q) - q_zd.log_prob(zd_q), 1 + ) # without aggregation, shape is [batchsize, zd_dim] zx_p_minus_zx_q = torch.zeros_like(zd_p_minus_zd_q) if self.zx_dim > 0: # torch.sum will return 0 for empty tensor, # torch.mean will return nan zx_p_minus_zx_q = g_inst_component_loss_agg( - p_zx.log_prob(zx_q) - q_zx.log_prob(zx_q), 1) + p_zx.log_prob(zx_q) - q_zx.log_prob(zx_q), 1 + ) zy_p_minus_zy_q = g_inst_component_loss_agg( - p_zy.log_prob(zy_q) - q_zy.log_prob(zy_q), 1) + p_zy.log_prob(zy_q) - q_zy.log_prob(zy_q), 1 + ) _, d_target = tensor_d.max(dim=1) lc_d = F.cross_entropy(logit_d, d_target, reduction=g_str_cross_entropy_agg) - return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \ - [self.multiplier_recon, -self.beta_d, -self.beta_x, -self.beta_y, self.gamma_d] + return [ + loss_recon_x, + zd_p_minus_zd_q, + zx_p_minus_zx_q, + zy_p_minus_zy_q, + lc_d, + ], [ + self.multiplier_recon, + -self.beta_d, + -self.beta_x, + -self.beta_y, + self.gamma_d, + 
] + return ModelDIVA diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py index 043a4c6c1..2a650925a 100644 --- a/domainlab/models/model_erm.py +++ b/domainlab/models/model_erm.py @@ -1,9 +1,9 @@ """ Emperical risk minimization """ +from domainlab.compos.nn_zoo.nn import LayerId from domainlab.models.a_model_classif import AModelClassif from domainlab.utils.override_interface import override_interface -from domainlab.compos.nn_zoo.nn import LayerId def mk_erm(parent_class=AModelClassif): @@ -34,7 +34,10 @@ class ModelERM(parent_class): """ anonymous """ - def __init__(self, net=None, net_feat=None, net_classifier=None, list_str_y=None): + + def __init__( + self, net=None, net_feat=None, net_classifier=None, list_str_y=None + ): if net_feat is None and net_classifier is None and net is not None: net_feat = net net_classifier = LayerId() @@ -42,11 +45,14 @@ def __init__(self, net=None, net_feat=None, net_classifier=None, list_str_y=None elif net_classifier is not None: dim_y = list(net_classifier.modules())[-1].out_features else: - raise RuntimeError("specify either a whole network for classification or separate \ - feature and classifier") + raise RuntimeError( + "specify either a whole network for classification or separate \ + feature and classifier" + ) if list_str_y is None: list_str_y = [f"class{i}" for i in range(dim_y)] super().__init__(list_str_y) self._net_classifier = net_classifier self._net_invar_feat = net_feat + return ModelERM diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index f22df41e1..1314a4e6e 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -4,9 +4,9 @@ import torch from torch.distributions import Dirichlet +from domainlab import g_inst_component_loss_agg from domainlab.models.model_vae_xyd_classif import VAEXYDClassif from domainlab.utils.utils_class import store_args -from domainlab import g_inst_component_loss_agg def 
mk_hduva(parent_class=VAEXYDClassif): @@ -56,13 +56,16 @@ class ModelHDUVA(parent_class): """ Hierarchical Domain Unsupervised Variational Auto-Encoding """ + def hyper_update(self, epoch, fun_scheduler): """hyper_update. :param epoch: :param fun_scheduler: """ - dict_rst = fun_scheduler(epoch) # the __call__ function of hyper-para-scheduler object + dict_rst = fun_scheduler( + epoch + ) # the __call__ function of hyper-para-scheduler object self.beta_d = dict_rst["beta_d"] self.beta_y = dict_rst["beta_y"] self.beta_x = dict_rst["beta_x"] @@ -77,30 +80,40 @@ def hyper_init(self, functor_scheduler): # constructor signature is def __init__(self, **kwargs): return functor_scheduler( trainer=None, - beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x, - beta_t=self.beta_t) + beta_d=self.beta_d, + beta_y=self.beta_y, + beta_x=self.beta_x, + beta_t=self.beta_t, + ) @store_args - def __init__(self, chain_node_builder, - zy_dim, zd_dim, - list_str_y, - gamma_d, gamma_y, - beta_d, beta_x, beta_y, - beta_t, - device, - zx_dim=0, - topic_dim=3, - multiplier_recon=1.0): - """ - """ - super().__init__(chain_node_builder, - zd_dim, zy_dim, zx_dim, - list_str_y) + def __init__( + self, + chain_node_builder, + zy_dim, + zd_dim, + list_str_y, + gamma_d, + gamma_y, + beta_d, + beta_x, + beta_y, + beta_t, + device, + zx_dim=0, + topic_dim=3, + multiplier_recon=1.0, + ): + """ """ + super().__init__(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y) # topic to zd follows Gaussian distribution - self.add_module("net_p_zd", - self.chain_node_builder.construct_cond_prior( - self.topic_dim, self.zd_dim)) + self.add_module( + "net_p_zd", + self.chain_node_builder.construct_cond_prior( + self.topic_dim, self.zd_dim + ), + ) # override interface def _init_components(self): @@ -109,16 +122,21 @@ def _init_components(self): p(zy) q_{classif}(zy) """ - self.add_module("encoder", self.chain_node_builder.build_encoder( - self.device, self.topic_dim)) - self.add_module("decoder", 
self.chain_node_builder.build_decoder( - self.topic_dim)) - self.add_module("net_p_zy", - self.chain_node_builder.construct_cond_prior( - self.dim_y, self.zy_dim)) - self.add_module("net_classif_y", - self.chain_node_builder.construct_classifier( - self.zy_dim, self.dim_y)) + self.add_module( + "encoder", + self.chain_node_builder.build_encoder(self.device, self.topic_dim), + ) + self.add_module( + "decoder", self.chain_node_builder.build_decoder(self.topic_dim) + ) + self.add_module( + "net_p_zy", + self.chain_node_builder.construct_cond_prior(self.dim_y, self.zy_dim), + ) + self.add_module( + "net_classif_y", + self.chain_node_builder.construct_classifier(self.zy_dim, self.dim_y), + ) self._net_classifier = self.net_classif_y def init_p_topic_batch(self, batch_size, device): @@ -129,10 +147,7 @@ def init_p_topic_batch(self, batch_size, device): return prior def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None): - q_topic, topic_q, \ - qzd, zd_q, \ - qzx, zx_q, \ - qzy, zy_q = self.encoder(tensor_x) + q_topic, topic_q, qzd, zd_q, qzx, zx_q, qzy, zy_q = self.encoder(tensor_x) batch_size = zd_q.shape[0] device = zd_q.device @@ -146,26 +161,30 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None): if (tensor_y.shape[-1] == 1) | (len(tensor_y.shape) == 1): tensor_y_onehot = torch.nn.functional.one_hot( - tensor_y, - num_classes=len(self.list_str_y)) + tensor_y, num_classes=len(self.list_str_y) + ) tensor_y_onehot = tensor_y_onehot.to(torch.float32) else: tensor_y_onehot = tensor_y p_zy = self.net_p_zy(tensor_y_onehot) zy_p_minus_zy_q = g_inst_component_loss_agg( - p_zy.log_prob(zy_q) - qzy.log_prob(zy_q), 1) + p_zy.log_prob(zy_q) - qzy.log_prob(zy_q), 1 + ) # zx KL divergence zx_p_minus_q = torch.zeros_like(zy_p_minus_zy_q) if self.zx_dim > 0: p_zx = self.init_p_zx4batch(batch_size, device) zx_p_minus_q = g_inst_component_loss_agg( - p_zx.log_prob(zx_q) - qzx.log_prob(zx_q), 1) + p_zx.log_prob(zx_q) - qzx.log_prob(zx_q), 1 + ) # 
zd KL diverence p_zd = self.net_p_zd(topic_q) - zd_p_minus_q = g_inst_component_loss_agg(p_zd.log_prob(zd_q) - qzd.log_prob(zd_q), 1) + zd_p_minus_q = g_inst_component_loss_agg( + p_zd.log_prob(zd_q) - qzd.log_prob(zd_q), 1 + ) # topic KL divergence # @FIXME: why topic is still there? @@ -174,8 +193,19 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None): # reconstruction z_concat = self.decoder.concat_ytdx(zy_q, topic_q, zd_q, zx_q) loss_recon_x, _, _ = self.decoder(z_concat, tensor_x) - return [loss_recon_x, zx_p_minus_q, zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q], \ - [self.multiplier_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t] + return [ + loss_recon_x, + zx_p_minus_q, + zy_p_minus_zy_q, + zd_p_minus_q, + topic_p_minus_q, + ], [ + self.multiplier_recon, + -self.beta_x, + -self.beta_y, + -self.beta_d, + -self.beta_t, + ] def extract_semantic_feat(self, tensor_x): """ diff --git a/domainlab/models/model_jigen.py b/domainlab/models/model_jigen.py index ca14b57f2..98cb41150 100644 --- a/domainlab/models/model_jigen.py +++ b/domainlab/models/model_jigen.py @@ -2,12 +2,13 @@ Jigen Model Similar to DANN model """ import warnings + from torch.nn import functional as F from domainlab import g_str_cross_entropy_agg +from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches from domainlab.models.a_model_classif import AModelClassif from domainlab.models.model_dann import mk_dann -from domainlab.dsets.utils_wrapdset_patches import WrapDsetPatches def mk_jigen(parent_class=AModelClassif): @@ -56,20 +57,27 @@ class ModelJiGen(class_dann): """ Jigen Model Similar to DANN model """ - def __init__(self, list_str_y, - net_encoder, - net_classifier_class, - net_classifier_permutation, - coeff_reg, n_perm=31, - prob_permutation=0.1, - overwrite_args=False, - meta_info=None): - super().__init__(list_str_y, - list_d_tr=None, - alpha=coeff_reg, - net_encoder=net_encoder, - net_classifier=net_classifier_class, - 
net_discriminator=net_classifier_permutation) + + def __init__( + self, + list_str_y, + net_encoder, + net_classifier_class, + net_classifier_permutation, + coeff_reg, + n_perm=31, + prob_permutation=0.1, + overwrite_args=False, + meta_info=None, + ): + super().__init__( + list_str_y, + list_d_tr=None, + alpha=coeff_reg, + net_encoder=net_encoder, + net_classifier=net_classifier_class, + net_discriminator=net_classifier_permutation, + ) self.net_encoder = net_encoder self.net_classifier_class = net_classifier_class self.net_classifier_permutation = net_classifier_permutation @@ -83,30 +91,40 @@ def dset_decoration_args_algo(self, args, ddset): JiGen need to shuffle the tiles of the original image """ if self.meta_info is not None: - args.nperm = self.meta_info["nperm"] -1 \ - if "nperm" in self.meta_info else args.nperm - args.pperm = self.meta_info["pperm"] \ - if "pperm" in self.meta_info else args.pperm + args.nperm = ( + self.meta_info["nperm"] - 1 + if "nperm" in self.meta_info + else args.nperm + ) + args.pperm = ( + self.meta_info["pperm"] if "pperm" in self.meta_info else args.pperm + ) nperm = self.n_perm if args.nperm != nperm and not self.flag_overwrite_args: - warnings.warn(f"number of permutations specified differently \ + warnings.warn( + f"number of permutations specified differently \ in model {nperm} and args {args.nperm}, \ - going to take args specification") + going to take args specification" + ) nperm = args.nperm - + pperm = self.prob_perm if args.pperm != pperm and not self.flag_overwrite_args: - warnings.warn(f"probability of reshuffling specified differently \ + warnings.warn( + f"probability of reshuffling specified differently \ in model {pperm} and args: {args.pperm}, \ - going to take model specification") + going to take model specification" + ) pperm = args.pperm - ddset_new = WrapDsetPatches(ddset, - num_perms2classify=nperm, - prob_no_perm=1-pperm, - grid_len=args.grid_len, - ppath=args.jigen_ppath) + ddset_new = WrapDsetPatches( + 
ddset, + num_perms2classify=nperm, + prob_no_perm=1 - pperm, + grid_len=args.grid_len, + ppath=args.jigen_ppath, + ) return ddset_new def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others): @@ -132,6 +150,10 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others): batch_target_scalar = vec_perm_ind batch_target_scalar = batch_target_scalar.to(tensor_x.device) loss_perm = F.cross_entropy( - logits_which_permutation, batch_target_scalar, reduction=g_str_cross_entropy_agg) + logits_which_permutation, + batch_target_scalar, + reduction=g_str_cross_entropy_agg, + ) return [loss_perm], [self.alpha] + return ModelJiGen diff --git a/domainlab/models/model_vae_xyd_classif.py b/domainlab/models/model_vae_xyd_classif.py index 0b9917425..29f31499a 100644 --- a/domainlab/models/model_vae_xyd_classif.py +++ b/domainlab/models/model_vae_xyd_classif.py @@ -10,10 +10,9 @@ class VAEXYDClassif(AModelClassif, InterfaceVAEXYD): """ Base Class for DIVA and HDUVA """ + @store_args - def __init__(self, chain_node_builder, - zd_dim, zy_dim, zx_dim, - list_str_y): + def __init__(self, chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y): """ :param chain_node_builder: constructed object """ @@ -33,8 +32,9 @@ def multiplier4task_loss(self): def _init_components(self): super()._init_components() - self.add_module("net_classif_y", - self.chain_node_builder.construct_classifier( - self.zy_dim, self.dim_y)) + self.add_module( + "net_classif_y", + self.chain_node_builder.construct_classifier(self.zy_dim, self.dim_y), + ) # property setter only for other object, internally, one shoud use _net_classifier self._net_classifier = self.net_classif_y diff --git a/domainlab/tasks/a_task.py b/domainlab/tasks/a_task.py index 3808aa30b..5b03923c4 100644 --- a/domainlab/tasks/a_task.py +++ b/domainlab/tasks/a_task.py @@ -1,8 +1,8 @@ """ Base class for Task """ -from abc import abstractmethod import warnings +from abc import abstractmethod from domainlab.compos.pcr.p_chain_handler import 
AbstractChainNodeHandler from domainlab.tasks.task_utils import parse_domain_id @@ -13,6 +13,7 @@ class NodeTaskDG(AbstractChainNodeHandler): """ Domain Generalization Classification Task """ + def __init__(self, succ=None): super().__init__(succ) self._loader_tr = None @@ -30,12 +31,16 @@ def __init__(self, succ=None): self.dim_d_tr = None # public, only used for diva self._im_size = None self._dict_domains2imgroot = {} - self._dict_domain_folder_name2class = {} # {"domain1": {"class1":car, "class2":dog}} + self._dict_domain_folder_name2class = ( + {} + ) # {"domain1": {"class1":car, "class2":dog}} self._dict_domain_img_trans = {} self.dict_att = {} self.img_trans_te = None self.dict_domain2imgroot = {} - self._dict_domain2filepath_list_im_tr = {} # {"photo": "xxx/yyy/file_of_path2imgs"} + self._dict_domain2filepath_list_im_tr = ( + {} + ) # {"photo": "xxx/yyy/file_of_path2imgs"} self._dict_domain2filepath_list_im_val = {} self._dict_domain2filepath_list_im_te = {} self.dict_class_label_ind2name = None @@ -145,14 +150,14 @@ def get_list_domains_tr_te(self, tr_id, te_id): assert set(list_domain_te).issubset(set(list_domains)) if tr_id is None: - list_domain_tr = [did for did in list_domains if - did not in list_domain_te] + list_domain_tr = [did for did in list_domains if did not in list_domain_te] else: list_domain_tr = parse_domain_id(tr_id, list_domains) if not set(list_domain_tr).issubset(set(list_domains)): raise RuntimeError( f"training domain {list_domain_tr} is not \ - subset of available domains {list_domains}") + subset of available domains {list_domains}" + ) if set(list_domain_tr) & set(list_domain_te): logger = Logger.get_logger() @@ -163,7 +168,7 @@ def get_list_domains_tr_te(self, tr_id, te_id): warnings.warn( "The sets of training and test domains overlap -- " "be aware of data leakage or training to the test!", - RuntimeWarning + RuntimeWarning, ) self.dim_d_tr = len(list_domain_tr) @@ -176,5 +181,5 @@ def __str__(self): """ strout = "list of 
domains: \n" strout += str(self.get_list_domains()) - strout += (f"\n input tensor size: {self.isize}") + strout += f"\n input tensor size: {self.isize}" return strout diff --git a/domainlab/tasks/a_task_classif.py b/domainlab/tasks/a_task_classif.py index 59424526a..50ce926ca 100644 --- a/domainlab/tasks/a_task_classif.py +++ b/domainlab/tasks/a_task_classif.py @@ -3,14 +3,15 @@ """ import os -from domainlab.tasks.utils_task import img_loader2dir from domainlab.tasks.a_task import NodeTaskDG +from domainlab.tasks.utils_task import img_loader2dir class NodeTaskDGClassif(NodeTaskDG): """ abstract class for classification task """ + def __init__(self, succ=None): # just for declaration of variables self._list_str_y = None @@ -48,7 +49,9 @@ def dim_y(self, dim_y): """ if self.list_str_y is not None: if len(self.list_str_y) is not dim_y: - raise RuntimeError(f"dim y={dim_y} not equal to self.list_str_y={self.list_str_y}") + raise RuntimeError( + f"dim y={dim_y} not equal to self.list_str_y={self.list_str_y}" + ) self._dim_y = dim_y def sample_sav(self, root, batches=5, subfolder_na="task_sample"): @@ -57,15 +60,19 @@ def sample_sav(self, root, batches=5, subfolder_na="task_sample"): """ folder_na = os.path.join(root, self.task_name, subfolder_na) - img_loader2dir(self.loader_te, - list_domain_na=self.get_list_domains(), - list_class_na=self.list_str_y, - folder=folder_na, - batches=batches, - test=True) + img_loader2dir( + self.loader_te, + list_domain_na=self.get_list_domains(), + list_class_na=self.list_str_y, + folder=folder_na, + batches=batches, + test=True, + ) - img_loader2dir(self.loader_tr, - list_domain_na=self.get_list_domains(), - list_class_na=self.list_str_y, - folder=folder_na, - batches=batches) + img_loader2dir( + self.loader_tr, + list_domain_na=self.get_list_domains(), + list_class_na=self.list_str_y, + folder=folder_na, + batches=batches, + ) diff --git a/domainlab/tasks/b_task.py b/domainlab/tasks/b_task.py index 94a1e4de6..e532fbf35 100644 --- 
a/domainlab/tasks/b_task.py +++ b/domainlab/tasks/b_task.py @@ -4,14 +4,14 @@ from torch.utils.data.dataset import ConcatDataset from domainlab.tasks.a_task import NodeTaskDG -from domainlab.tasks.utils_task import (DsetDomainVecDecorator, mk_loader, - mk_onehot) +from domainlab.tasks.utils_task import DsetDomainVecDecorator, mk_loader, mk_onehot class NodeTaskDict(NodeTaskDG): """ Use dictionaries to create train and test domain split """ + def get_dset_by_domain(self, args, na_domain, split=False): """ each domain correspond to one dataset, must be implemented by child class @@ -27,19 +27,27 @@ def init_business(self, args, trainer=None): """ create a dictionary of datasets """ - list_domain_tr, list_domain_te = self.get_list_domains_tr_te(args.tr_d, args.te_d) + list_domain_tr, list_domain_te = self.get_list_domains_tr_te( + args.tr_d, args.te_d + ) self.dict_dset_tr = {} self.dict_dset_val = {} dim_d = len(list_domain_tr) - for (ind_domain_dummy, na_domain) in enumerate(list_domain_tr): - dset_tr, dset_val = self.get_dset_by_domain(args, na_domain, split=args.split) + for ind_domain_dummy, na_domain in enumerate(list_domain_tr): + dset_tr, dset_val = self.get_dset_by_domain( + args, na_domain, split=args.split + ) vec_domain = mk_onehot(dim_d, ind_domain_dummy) # for diva, dann ddset_tr = DsetDomainVecDecorator(dset_tr, vec_domain, na_domain) ddset_val = DsetDomainVecDecorator(dset_val, vec_domain, na_domain) if trainer is not None and hasattr(trainer, "dset_decoration_args_algo"): ddset_tr = trainer.dset_decoration_args_algo(args, ddset_tr) ddset_val = trainer.dset_decoration_args_algo(args, ddset_val) - if trainer is not None and trainer.model is not None and hasattr(trainer.model, "dset_decoration_args_algo"): + if ( + trainer is not None + and trainer.model is not None + and hasattr(trainer.model, "dset_decoration_args_algo") + ): ddset_tr = trainer.model.dset_decoration_args_algo(args, ddset_tr) ddset_val = trainer.model.dset_decoration_args_algo(args, 
ddset_val) self.dict_dset_tr.update({na_domain: ddset_tr}) @@ -48,9 +56,9 @@ def init_business(self, args, trainer=None): self._loader_tr = mk_loader(ddset_mix, args.bs) ddset_mix_val = ConcatDataset(tuple(self.dict_dset_val.values())) - self._loader_val = mk_loader(ddset_mix_val, args.bs, - shuffle=False, - drop_last=False) + self._loader_val = mk_loader( + ddset_mix_val, args.bs, shuffle=False, drop_last=False + ) self.dict_dset_te = {} # No need to have domain Label for test @@ -60,6 +68,4 @@ def init_business(self, args, trainer=None): # train and validation, this is not needed in test domain self.dict_dset_te.update({na_domain: dset_te}) dset_te = ConcatDataset(tuple(self.dict_dset_te.values())) - self._loader_te = mk_loader(dset_te, args.bs, - shuffle=False, - drop_last=False) \ No newline at end of file + self._loader_te = mk_loader(dset_te, args.bs, shuffle=False, drop_last=False) diff --git a/domainlab/tasks/b_task_classif.py b/domainlab/tasks/b_task_classif.py index c7f3c95c2..15fbb911d 100644 --- a/domainlab/tasks/b_task_classif.py +++ b/domainlab/tasks/b_task_classif.py @@ -11,6 +11,7 @@ class NodeTaskDictClassif(NodeTaskDict, NodeTaskDGClassif): """ Use dictionaries to create train and test domain split """ + def init_business(self, args, trainer=None): """ create a dictionary of datasets diff --git a/domainlab/tasks/task_dset.py b/domainlab/tasks/task_dset.py index 906b4e7b1..9c415ec5e 100644 --- a/domainlab/tasks/task_dset.py +++ b/domainlab/tasks/task_dset.py @@ -4,21 +4,25 @@ from domainlab.tasks.b_task_classif import NodeTaskDictClassif # abstract class -def mk_task_dset(isize, - taskna="task_custom", # name of the task - dim_y=None, - list_str_y=None, - parent=NodeTaskDictClassif, - succ=None): +def mk_task_dset( + isize, + taskna="task_custom", # name of the task + dim_y=None, + list_str_y=None, + parent=NodeTaskDictClassif, + succ=None, +): """ make a task via a dictionary of dataset where the key is domain value is a tuple of dataset for 
training and dataset for validation (can be identical to training) """ + class NodeTaskDset(parent): """ Use dictionaries to create train and test domain split """ + def conf_without_args(self): """ set member variables @@ -26,9 +30,13 @@ def conf_without_args(self): self._name = taskna if list_str_y is None and dim_y is None: - raise RuntimeError("arguments list_str_y and dim_y can not be both None!") + raise RuntimeError( + "arguments list_str_y and dim_y can not be both None!" + ) - self.list_str_y = list_str_y # list_str_y has to be initialized before dim_y + self.list_str_y = ( + list_str_y # list_str_y has to be initialized before dim_y + ) self.dim_y = dim_y if self.list_str_y is None: diff --git a/domainlab/tasks/task_folder.py b/domainlab/tasks/task_folder.py index f6bde8332..94b0c134d 100644 --- a/domainlab/tasks/task_folder.py +++ b/domainlab/tasks/task_folder.py @@ -4,9 +4,11 @@ from torchvision import transforms from domainlab.dsets.dset_subfolder import DsetSubFolder -from domainlab.dsets.utils_data import (DsetInMemDecorator, - fun_img_path_loader_default, - mk_fun_label2onehot) +from domainlab.dsets.utils_data import ( + DsetInMemDecorator, + fun_img_path_loader_default, + mk_fun_label2onehot, +) from domainlab.tasks.b_task_classif import NodeTaskDictClassif from domainlab.tasks.utils_task import DsetClassVecDecoratorImgPath from domainlab.utils.logger import Logger @@ -17,6 +19,7 @@ class NodeTaskFolder(NodeTaskDictClassif): create dataset by loading files from an organized folder then each domain correspond to one dataset """ + @property def dict_domain2imgroot(self): """ @@ -47,19 +50,22 @@ def extensions(self, str_format): def get_dset_by_domain(self, args, na_domain, split=False): if float(args.split): raise RuntimeError( - "this task does not support spliting training domain yet") + "this task does not support spliting training domain yet" + ) if self._dict_domain_img_trans: trans = self._dict_domain_img_trans[na_domain] if na_domain not in 
self.list_domain_tr: trans = self.img_trans_te else: trans = transforms.ToTensor() - dset = DsetSubFolder(root=self.dict_domain2imgroot[na_domain], - list_class_dir=self.list_str_y, - loader=fun_img_path_loader_default, - extensions=self.extensions, - transform=trans, - target_transform=mk_fun_label2onehot(len(self.list_str_y))) + dset = DsetSubFolder( + root=self.dict_domain2imgroot[na_domain], + list_class_dir=self.list_str_y, + loader=fun_img_path_loader_default, + extensions=self.extensions, + transform=trans, + target_transform=mk_fun_label2onehot(len(self.list_str_y)), + ) return dset, dset # @FIXME: validation by default set to be training set @@ -68,14 +74,15 @@ class NodeTaskFolderClassNaMismatch(NodeTaskFolder): when the folder names of the same class from different domains have different names """ + def get_dset_by_domain(self, args, na_domain, split=False): if float(args.split): raise RuntimeError( - "this task does not support spliting training domain yet") + "this task does not support spliting training domain yet" + ) logger = Logger.get_logger() logger.info(f"reading domain: {na_domain}") - domain_class_dirs = \ - self._dict_domain_folder_name2class[na_domain].keys() + domain_class_dirs = self._dict_domain_folder_name2class[na_domain].keys() if self._dict_domain_img_trans: trans = self._dict_domain_img_trans[na_domain] if na_domain not in self.list_domain_tr: @@ -84,20 +91,21 @@ def get_dset_by_domain(self, args, na_domain, split=False): trans = transforms.ToTensor() ext = None if self.extensions is None else self.extensions[na_domain] - dset = DsetSubFolder(root=self.dict_domain2imgroot[na_domain], - list_class_dir=list(domain_class_dirs), - loader=fun_img_path_loader_default, - extensions=ext, - transform=trans, - target_transform=mk_fun_label2onehot( - len(self.list_str_y))) + dset = DsetSubFolder( + root=self.dict_domain2imgroot[na_domain], + list_class_dir=list(domain_class_dirs), + loader=fun_img_path_loader_default, + extensions=ext, + 
transform=trans, + target_transform=mk_fun_label2onehot(len(self.list_str_y)), + ) # dset.path2imgs - dict_folder_name2class_global = \ - self._dict_domain_folder_name2class[na_domain] + dict_folder_name2class_global = self._dict_domain_folder_name2class[na_domain] dset = DsetClassVecDecoratorImgPath( - dset, dict_folder_name2class_global, self.list_str_y) + dset, dict_folder_name2class_global, self.list_str_y + ) # Always use the DsetInMemDecorator at the last step # since it does not have other needed attributes in bewteen if args.dmem: dset = DsetInMemDecorator(dset, na_domain) - return dset, dset # @FIXME: validation by default set to be training set + return dset, dset # @FIXME: validation by default set to be training set diff --git a/domainlab/tasks/task_folder_mk.py b/domainlab/tasks/task_folder_mk.py index f0c3cbb1f..5fad128da 100644 --- a/domainlab/tasks/task_folder_mk.py +++ b/domainlab/tasks/task_folder_mk.py @@ -4,22 +4,24 @@ from domainlab.tasks.task_folder import NodeTaskFolderClassNaMismatch -def mk_task_folder(extensions, - list_str_y, - dict_domain_folder_name2class, - dict_domain_img_trans, - img_trans_te, - isize, - dict_domain2imgroot, - taskna, - succ=None): +def mk_task_folder( + extensions, + list_str_y, + dict_domain_folder_name2class, + dict_domain_img_trans, + img_trans_te, + isize, + dict_domain2imgroot, + taskna, + succ=None, +): """ Make task by specifying each domain with folder structures :param extensions: Different Options: 1. a python dictionary with key as the domain name and value (str or tuple[str]) as the file extensions of the image. 2. a str or tuple[str] with file extensions for all domains. 3. None: in each domain all files with an extension in ('jpg', 'jpeg', 'png') are loaded. - :param + :param list_str_y: a python list with user defined class name where the order of the list matters. 
:param dict_domain_folder_name2class: a python dictionary, with key @@ -37,6 +39,7 @@ def mk_task_folder(extensions, names and values as the absolute path to each domain's data. :taskna: user defined task name """ + class NodeTaskFolderDummy(NodeTaskFolderClassNaMismatch): @property def task_name(self): @@ -54,4 +57,5 @@ def conf_without_args(self): self.dict_domain2imgroot = dict_domain2imgroot self._dict_domain_img_trans = dict_domain_img_trans self.img_trans_te = img_trans_te + return NodeTaskFolderDummy(succ=succ) diff --git a/domainlab/tasks/task_mini_vlcs.py b/domainlab/tasks/task_mini_vlcs.py index 386496c5b..bea20ce8e 100644 --- a/domainlab/tasks/task_mini_vlcs.py +++ b/domainlab/tasks/task_mini_vlcs.py @@ -15,44 +15,47 @@ def addtask2chain(chain): """ given a chain of responsibility for task selection, add another task into the chain """ - new_chain = mk_task_folder(extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, - list_str_y=["chair", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", - "stuhl": "chair"}, - "sun": {"vehicle": "car", - "sofa": "chair"}, - "labelme": {"drive": "car", - "sit": "chair"} - }, - dict_domain_img_trans={ - "caltech": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "sun": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "labelme": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - }, - img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": os.path.join( - path_this_file, - os.path.normpath("../"), - os.path.normpath("zdata/vlcs_mini/caltech/")), - "sun": os.path.join( - path_this_file, - os.path.normpath("../"), - os.path.normpath("zdata/vlcs_mini/sun/")), - "labelme": os.path.join( - path_this_file, - os.path.normpath("../"), - os.path.normpath("zdata/vlcs_mini/labelme/"))}, - 
taskna="mini_vlcs", - succ=chain) + new_chain = mk_task_folder( + extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, + list_str_y=["chair", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "stuhl": "chair"}, + "sun": {"vehicle": "car", "sofa": "chair"}, + "labelme": {"drive": "car", "sit": "chair"}, + }, + dict_domain_img_trans={ + "caltech": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "sun": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "labelme": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + }, + img_trans_te=transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + isize=ImSize(3, 224, 224), + dict_domain2imgroot={ + "caltech": os.path.join( + path_this_file, + os.path.normpath("../"), + os.path.normpath("zdata/vlcs_mini/caltech/"), + ), + "sun": os.path.join( + path_this_file, + os.path.normpath("../"), + os.path.normpath("zdata/vlcs_mini/sun/"), + ), + "labelme": os.path.join( + path_this_file, + os.path.normpath("../"), + os.path.normpath("zdata/vlcs_mini/labelme/"), + ), + }, + taskna="mini_vlcs", + succ=chain, + ) return new_chain diff --git a/domainlab/tasks/task_mnist_color.py b/domainlab/tasks/task_mnist_color.py index 826eae021..f01d77eab 100644 --- a/domainlab/tasks/task_mnist_color.py +++ b/domainlab/tasks/task_mnist_color.py @@ -3,8 +3,7 @@ """ from torch.utils.data import random_split -from domainlab.dsets.dset_mnist_color_solo_default import \ - DsetMNISTColorSoloDefault +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault from domainlab.dsets.utils_color_palette import default_rgb_palette # @FIXME from domainlab.tasks.b_task_classif import NodeTaskDictClassif from domainlab.tasks.utils_task import ImSize @@ -15,6 +14,7 @@ class NodeTaskMNISTColor10(NodeTaskDictClassif): """ Use the deafult palette with 10 colors """ + @property def 
list_str_y(self): return mk_dummy_label_list_str("digit", 10) @@ -30,13 +30,15 @@ def get_list_domains(self): 2. better use method than property so new domains can be added """ list_domains = [] - for rgb_list in default_rgb_palette: # 10 colors + for rgb_list in default_rgb_palette: # 10 colors domain = "_".join([str(c) for c in rgb_list]) domain = "rgb_" + domain list_domains.append(domain) return list_domains - def get_dset_by_domain(self, args, na_domain, split=True): # @FIXME: different number of arguments than parent + def get_dset_by_domain( + self, args, na_domain, split=True + ): # @FIXME: different number of arguments than parent """get_dset_by_domain. :param args: :param na_domain: diff --git a/domainlab/tasks/task_pathlist.py b/domainlab/tasks/task_pathlist.py index 607e3dff4..ab00c7647 100644 --- a/domainlab/tasks/task_pathlist.py +++ b/domainlab/tasks/task_pathlist.py @@ -13,7 +13,7 @@ from domainlab.dsets.utils_data import mk_fun_label2onehot from domainlab.tasks.b_task_classif import NodeTaskDictClassif -torch.multiprocessing.set_sharing_strategy('file_system') +torch.multiprocessing.set_sharing_strategy("file_system") # "too many opened files" https://github.com/pytorch/pytorch/issues/11201 @@ -21,20 +21,23 @@ class NodeTaskPathListDummy(NodeTaskDictClassif): """ typedef class so that other function can use isinstance """ + def get_dset_by_domain(self, args, na_domain, split=False): raise NotImplementedError -def mk_node_task_path_list(isize, - img_trans_te, - list_str_y, - img_trans_tr, - dict_class_label_ind2name, - dict_domain2imgroot, - dict_d2filepath_list_img_tr, - dict_d2filepath_list_img_val, - dict_d2filepath_list_img_te, - succ=None): +def mk_node_task_path_list( + isize, + img_trans_te, + list_str_y, + img_trans_tr, + dict_class_label_ind2name, + dict_domain2imgroot, + dict_d2filepath_list_img_tr, + dict_d2filepath_list_img_val, + dict_d2filepath_list_img_te, + succ=None, +): """mk_node_task_path_list. 
:param isize: @@ -47,6 +50,7 @@ def mk_node_task_path_list(isize, :param dict_d2filepath_list_img_te: :param succ: """ + class NodeTaskPathList(NodeTaskPathListDummy): """ The class TaskPathList provides the user an interface to provide a file @@ -56,6 +60,7 @@ class NodeTaskPathList(NodeTaskPathListDummy): slot contains the class label as a numerical string. e.g.: /path/2/file/art_painting/dog/pic_376.jpg 1 """ + def _get_complete_domain(self, na_domain, dict_domain2pathfilepath): """_get_complete_domain. @@ -72,9 +77,12 @@ def _get_complete_domain(self, na_domain, dict_domain2pathfilepath): path2filelist = dict_domain2pathfilepath[na_domain] path2filelist = os.path.expanduser(path2filelist) root_img = os.path.expanduser(root_img) - dset = DsetImPathList(root_img, path2filelist, trans_img=trans, - trans_target=mk_fun_label2onehot( - len(self.list_str_y))) + dset = DsetImPathList( + root_img, + path2filelist, + trans_img=trans, + trans_target=mk_fun_label2onehot(len(self.list_str_y)), + ) return dset def get_dset_by_domain(self, args, na_domain, split=True): @@ -89,20 +97,22 @@ def get_dset_by_domain(self, args, na_domain, split=True): # if split=False, then only te is used, which contains # the whole dataset dset = self._get_complete_domain( - na_domain, - self._dict_domain2filepath_list_im_te) + na_domain, self._dict_domain2filepath_list_im_te + ) # test set contains train+validation return dset, dset # @FIXME: avoid returning two identical dset = self._get_complete_domain( na_domain, # read training set from user configuration - self._dict_domain2filepath_list_im_tr) + self._dict_domain2filepath_list_im_tr, + ) dset_val = self._get_complete_domain( na_domain, # read validation set from user configuration - self._dict_domain2filepath_list_im_val) + self._dict_domain2filepath_list_im_val, + ) return dset, dset_val diff --git a/domainlab/tasks/utils_task.py b/domainlab/tasks/utils_task.py index 58166b91d..850364f0a 100644 --- a/domainlab/tasks/utils_task.py +++ 
b/domainlab/tasks/utils_task.py @@ -12,7 +12,7 @@ from domainlab.utils.utils_class import store_args -class ImSize(): +class ImSize: """ImSize.""" @store_args @@ -20,6 +20,7 @@ def __init__(self, i_c, i_h, i_w): """ store channel, height, width """ + @property def c(self): """image channel""" @@ -46,10 +47,7 @@ def mk_onehot(dim, ind): return vec -def mk_loader(dset, bsize, - drop_last=True, - shuffle=True, - num_workers=int(0)): +def mk_loader(dset, bsize, drop_last=True, shuffle=True, num_workers=int(0)): """ :param bs: batch size """ @@ -60,8 +58,9 @@ def mk_loader(dset, bsize, batch_size=bsize, shuffle=shuffle, # @FIXME: shuffle must be true so the last incomplete batch get used in another epoch? - num_workers=num_workers, # @FIXME: num_workers=int(0) can be slow? - drop_last=drop_last) + num_workers=num_workers, # @FIXME: num_workers=int(0) can be slow? + drop_last=drop_last, + ) return loader @@ -69,6 +68,7 @@ class DsetDomainVecDecorator(Dataset): """ decorate a pytorch dataset with a fixed vector representation of domain """ + def __init__(self, dset, vec_domain, na_domain): """ :param dset: x, y @@ -106,6 +106,7 @@ class DsetDomainVecDecoratorImgPath(DsetDomainVecDecorator): returned currently not in use since it is mostly important to print predictions together with path for the test domain """ + def __getitem__(self, idx): """ :param idx: @@ -118,6 +119,7 @@ class DsetClassVecDecorator(Dataset): """ decorate a pytorch dataset with a new class name """ + def __init__(self, dset, dict_folder_name2class_global, list_str_y): """ :param dset: x, y, *d @@ -125,15 +127,19 @@ def __init__(self, dset, dict_folder_name2class_global, list_str_y): class folder of domain to glbal class """ self.dset = dset - self.class2idx = {k:v for (k,v) in self.dset.class_to_idx.items() \ - if k in self.dset.list_class_dir} + self.class2idx = { + k: v + for (k, v) in self.dset.class_to_idx.items() + if k in self.dset.list_class_dir + } assert self.class2idx 
self.dict_folder_name2class_global = dict_folder_name2class_global self.list_str_y = list_str_y # inverst key:value to value:key for backward map self.dict_old_idx2old_class = dict((v, k) for k, v in self.class2idx.items()) dict_class_na_local2vec_new = dict( - (k, self.fun_class_local_na2vec_new(k)) for k, v in self.class2idx.items()) + (k, self.fun_class_local_na2vec_new(k)) for k, v in self.class2idx.items() + ) self.dict_class_na_local2vec_new = dict_class_na_local2vec_new @property @@ -177,10 +183,11 @@ def __getitem__(self, idx): return tensor, vec_class_new, path[0] -class LoaderDomainLabel(): +class LoaderDomainLabel: """ wraps a dataset with domain label and into a loader """ + def __init__(self, batch_size, dim_d): """__init__. @@ -214,14 +221,12 @@ def tensor1hot2ind(tensor_label): npa_label_ind = label_ind.numpy() return npa_label_ind + # @FIXME: this function couples strongly with the task, # should be a class method of task -def img_loader2dir(loader, - folder, - test=False, - list_domain_na=None, - list_class_na=None, - batches=5): +def img_loader2dir( + loader, folder, test=False, list_domain_na=None, list_class_na=None, batches=5 +): """ save images from loader to directory so speculate if loader is correct :param loader: pytorch data loader @@ -253,7 +258,7 @@ def img_loader2dir(loader, class_label_scalar = class_label_ind.item() if list_class_na is None: - str_class_label = "class_"+str(class_label_scalar) + str_class_label = "class_" + str(class_label_scalar) else: # @FIXME: where is the correspndance between # class ind_label and class str_label? 
@@ -270,12 +275,15 @@ def img_loader2dir(loader, arr = img[b_ind] img_vision = torchvision.transforms.ToPILImage()(arr) f_n = "_".join( - ["class", - str_class_label, - "domain", - str_domain_label, - "n", - str(counter)]) + [ + "class", + str_class_label, + "domain", + str_domain_label, + "n", + str(counter), + ] + ) counter += 1 path = os.path.join(folder, f_n + ".png") img_vision.save(path) diff --git a/domainlab/tasks/utils_task_dset.py b/domainlab/tasks/utils_task_dset.py index 3ed642054..0b3577c99 100644 --- a/domainlab/tasks/utils_task_dset.py +++ b/domainlab/tasks/utils_task_dset.py @@ -2,6 +2,7 @@ task specific dataset operation """ import random + from torch.utils.data import Dataset @@ -9,6 +10,7 @@ class DsetIndDecorator4XYD(Dataset): """ For dataset of x, y, d, decorate it wih index """ + def __init__(self, dset): """ :param dset: x,y,d @@ -17,7 +19,8 @@ def __init__(self, dset): if len(tuple_m) < 3: raise RuntimeError( "dataset to be wrapped should output at least x, y, and d; got length ", - len(tuple_m)) + len(tuple_m), + ) self.dset = dset def __getitem__(self, index): @@ -37,6 +40,7 @@ class DsetZip(Dataset): to avoid always the same match, the second dataset does not use the same idx in __get__item() but instead, a random one """ + def __init__(self, dset1, dset2, name=None): """ :param dset1: x1, y1, *d1 @@ -56,7 +60,16 @@ def __getitem__(self, idx): idx2 = idx2 % self.len2 tensor_x_1, vec_y_1, vec_d_1, *others_1 = self.dset1.__getitem__(idx) tensor_x_2, vec_y_2, vec_d_2, *others_2 = self.dset2.__getitem__(idx2) - return tensor_x_1, vec_y_1, vec_d_1, others_1, tensor_x_2, vec_y_2, vec_d_2, others_2 + return ( + tensor_x_1, + vec_y_1, + vec_d_1, + others_1, + tensor_x_2, + vec_y_2, + vec_d_2, + others_2, + ) def __len__(self): len1 = self.dset1.__len__() diff --git a/domainlab/tasks/zoo_tasks.py b/domainlab/tasks/zoo_tasks.py index 5fb1d6c87..ca050520a 100644 --- a/domainlab/tasks/zoo_tasks.py +++ b/domainlab/tasks/zoo_tasks.py @@ -4,10 +4,10 @@ 
from domainlab.arg_parser import mk_parser_main from domainlab.compos.pcr.request import RequestTask +from domainlab.tasks.task_mini_vlcs import addtask2chain from domainlab.tasks.task_mnist_color import NodeTaskMNISTColor10 -from domainlab.utils.u_import import import_path from domainlab.utils.logger import Logger -from domainlab.tasks.task_mini_vlcs import addtask2chain +from domainlab.utils.u_import import import_path class TaskChainNodeGetter(object): @@ -15,6 +15,7 @@ class TaskChainNodeGetter(object): 1. Hardcoded chain 3. Return selected node """ + def __init__(self, args): self.args = args tpath = args.tpath @@ -38,8 +39,9 @@ def __call__(self): if self.args.task is None: logger = Logger.get_logger() logger.info("") - logger.info(f"overriding args.task {self.args.task} " - f"to {node.task_name}") + logger.info( + f"overriding args.task {self.args.task} " f"to {node.task_name}" + ) logger.info("") self.request = node.task_name # @FIXME node = chain.handle(self.request) diff --git a/domainlab/utils/flows_gen_img_model.py b/domainlab/utils/flows_gen_img_model.py index 207c9788e..0933c5bdc 100644 --- a/domainlab/utils/flows_gen_img_model.py +++ b/domainlab/utils/flows_gen_img_model.py @@ -6,7 +6,7 @@ from domainlab.utils.utils_img_sav import mk_fun_sav_img -class FlowGenImgs(): +class FlowGenImgs: def __init__(self, model, device): model = model.to(device) self.obj_recon = ReconVAEXYD(model) @@ -56,54 +56,69 @@ def gen_img_xyd(self, img, vec_y, vec_d, device, path, folder_na): def _flow_vanilla(self, img, vec_y, vec_d, device, num_sample=10): x_recon_img, str_type = self.obj_recon.recon(img) - self._save_pair(x_recon_img, device, img, str_type + '.png') + self._save_pair(x_recon_img, device, img, str_type + ".png") x_recon_img, str_type = self.obj_recon.recon(img, vec_y) - self._save_pair(x_recon_img, device, img, str_type + '.png') + self._save_pair(x_recon_img, device, img, str_type + ".png") if vec_d is not None: x_recon_img, str_type = 
self.obj_recon.recon(img, None, vec_d) - self._save_pair(x_recon_img, device, img, str_type + '.png') + self._save_pair(x_recon_img, device, img, str_type + ".png") for i in range(num_sample): x_recon_img, str_type = self.obj_recon.recon(img, vec_y, vec_d, True, True) x_recon_img = x_recon_img.to(device) comparison = torch.cat([img, x_recon_img]) - self.sav_fun(comparison, str_type + str(i) + '.png') + self.sav_fun(comparison, str_type + str(i) + ".png") def _flow_cf_y(self, img, vec_y, vec_d, device): """ scan possible values of vec_y """ - recon_list, str_type = self.obj_recon.recon_cf(img, "y", vec_y.shape[1], device, - zx2fill=None) - self._save_list(recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png") - recon_list, str_type = self.obj_recon.recon_cf(img, "y", vec_y.shape[1], device, zx2fill=0) - self._save_list(recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png") + recon_list, str_type = self.obj_recon.recon_cf( + img, "y", vec_y.shape[1], device, zx2fill=None + ) + self._save_list( + recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png" + ) + recon_list, str_type = self.obj_recon.recon_cf( + img, "y", vec_y.shape[1], device, zx2fill=0 + ) + self._save_list( + recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png" + ) if vec_d is not None: recon_list, str_type = self.obj_recon.recon_cf( - img, "y", vec_y.shape[1], device, - vec_d=vec_d, - zx2fill=0) - self._save_list(recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png") + img, "y", vec_y.shape[1], device, vec_d=vec_d, zx2fill=0 + ) + self._save_list( + recon_list, device, img, "_".join(["recon_cf_y", str_type]) + ".png" + ) def _flow_cf_d(self, img, vec_y, vec_d, device): """ scan possible values of vec_y """ - recon_list, str_type = self.obj_recon.recon_cf(img, "d", vec_d.shape[1], device, - zx2fill=None) - self._save_list(recon_list, device, img, "_".join(["recon_cf_d", str_type]) +".png") - recon_list, str_type = 
self.obj_recon.recon_cf(img, "d", vec_d.shape[1], device, zx2fill=0) - self._save_list(recon_list, device, img, "_".join(["recon_cf_d", str_type]) +".png") + recon_list, str_type = self.obj_recon.recon_cf( + img, "d", vec_d.shape[1], device, zx2fill=None + ) + self._save_list( + recon_list, device, img, "_".join(["recon_cf_d", str_type]) + ".png" + ) + recon_list, str_type = self.obj_recon.recon_cf( + img, "d", vec_d.shape[1], device, zx2fill=0 + ) + self._save_list( + recon_list, device, img, "_".join(["recon_cf_d", str_type]) + ".png" + ) def fun_gen(model, device, node, args, subfolder_na, output_folder_na="gen"): flow = FlowGenImgs(model, device) - path = os.path.join(args.out, output_folder_na, node.task_name, args.model, subfolder_na) - flow.gen_img_loader(node.loader_te, device, - path=path, - domain="_".join(args.te_d)) - flow.gen_img_loader(node.loader_tr, device, - path=path, - domain="_".join(node.list_domain_tr)) + path = os.path.join( + args.out, output_folder_na, node.task_name, args.model, subfolder_na + ) + flow.gen_img_loader(node.loader_te, device, path=path, domain="_".join(args.te_d)) + flow.gen_img_loader( + node.loader_tr, device, path=path, domain="_".join(node.list_domain_tr) + ) diff --git a/domainlab/utils/generate_benchmark_plots.py b/domainlab/utils/generate_benchmark_plots.py index 5f63973bb..e8196d8e3 100644 --- a/domainlab/utils/generate_benchmark_plots.py +++ b/domainlab/utils/generate_benchmark_plots.py @@ -1,16 +1,18 @@ -''' +""" generate the benchmark plots by calling the gen_bencmark_plots(...) 
function -''' +""" import os -from ast import literal_eval # literal_eval can safe evaluate python expression -import matplotlib.pyplot as plt +from ast import literal_eval # literal_eval can safe evaluate python expression + import matplotlib +import matplotlib.pyplot as plt +import numpy as np import pandas as pd import seaborn as sns -import numpy as np + from domainlab.utils.logger import Logger -matplotlib.use('Agg') +matplotlib.use("Agg") # header of the csv file: # param_index, task, algo, epos, te_d, seed, params, acc, precision, recall, specificity, f1, auroc @@ -18,13 +20,14 @@ COLNAME_METHOD = "method" COLNAME_IDX_PARAM = "param_index" COLNAME_PARAM = "params" -G_DF_TASK_COL = 1 # column in which the method name is saved -G_DF_PLOT_COL_METRIC_START = 9 # first 0-6 columns are not metric +G_DF_TASK_COL = 1 # column in which the method name is saved +G_DF_PLOT_COL_METRIC_START = 9 # first 0-6 columns are not metric - -def gen_benchmark_plots(agg_results: str, output_dir: str, use_param_index: bool = True): - ''' +def gen_benchmark_plots( + agg_results: str, output_dir: str, use_param_index: bool = True +): + """ generate the benchmark plots from a csv file containing the aggregated restults. The csv file must have the columns: [param_index, task, algo, epos, te_d, seed, params, ...] @@ -34,158 +37,246 @@ def gen_benchmark_plots(agg_results: str, output_dir: str, use_param_index: bool agg_results: path to the csv file output_dir: path to a folder which shall contain the results skip_gen: Skips the actual plotting, used to speed up testing. 
- ''' - raw_df = pd.read_csv(agg_results, index_col=False, - converters={COLNAME_PARAM: literal_eval}, - # literal_eval can safe evaluate python expression - skipinitialspace=True) - - raw_df[COLNAME_PARAM] = round_vals_in_dict(raw_df[[COLNAME_IDX_PARAM, COLNAME_PARAM]], - use_param_index) + """ + raw_df = pd.read_csv( + agg_results, + index_col=False, + converters={COLNAME_PARAM: literal_eval}, + # literal_eval can safe evaluate python expression + skipinitialspace=True, + ) + + raw_df[COLNAME_PARAM] = round_vals_in_dict( + raw_df[[COLNAME_IDX_PARAM, COLNAME_PARAM]], use_param_index + ) # generating plot gen_plots(raw_df, output_dir, use_param_index) def round_vals_in_dict(df_column_in, use_param_index): - ''' + """ replaces the dictionary by a string containing only the significant digits of the hyperparams or (if use_param_index = True) by the parameter index df_column_in: columns of the dataframe containing the param index and the dictionary of hyperparams in the form [param_index, params] use_param_index: usage of param_index instead of exact values - ''' + """ df_column = df_column_in.copy() df_column_out = df_column_in[COLNAME_IDX_PARAM].copy() df_column_out = df_column_out.astype(str) for i in range(df_column.shape[0]): if not use_param_index: - string = '' + string = "" for num, val in enumerate(list(df_column[COLNAME_PARAM][i].values())): key = list(df_column[COLNAME_PARAM][i].keys())[num] - val = np.format_float_scientific(val, precision=1, unique=False, trim='0') - string += str(key) + ': ' + str(val) + ', ' + val = np.format_float_scientific( + val, precision=1, unique=False, trim="0" + ) + string += str(key) + ": " + str(val) + ", " df_column_out[i] = string[:-2] # remove ', ' from the end of the string else: - string = 'idx: ' + str(df_column[COLNAME_IDX_PARAM][i]) + string = "idx: " + str(df_column[COLNAME_IDX_PARAM][i]) df_column_out[i] = string return df_column_out def gen_plots(dataframe: pd.DataFrame, output_dir: str, use_param_index: bool): - 
''' + """ dataframe: dataframe with columns ['param_index','task',' algo',' epos',' te_d',' seed',' params',' acc','precision',...] - ''' + """ os.makedirs(output_dir, exist_ok=True) obj = dataframe.columns[G_DF_PLOT_COL_METRIC_START:] # boxplots for objective in obj: - boxplot(dataframe, objective, file=output_dir + '/variational_plots/' + objective) + boxplot( + dataframe, objective, file=output_dir + "/variational_plots/" + objective + ) # scatterplot matrices - scatterplot_matrix(dataframe, use_param_index, - file=output_dir + '/sp_matrix_reg.png', - kind='reg', distinguish_param_setups=False) - scatterplot_matrix(dataframe, use_param_index, - file=output_dir + '/sp_matrix.png', - kind='scatter', distinguish_param_setups=False) - scatterplot_matrix(dataframe, use_param_index, - file=output_dir + '/sp_matrix_dist_reg.png', - kind='reg', distinguish_param_setups=True) - scatterplot_matrix(dataframe, use_param_index, - file=output_dir + '/sp_matrix_dist.png', - kind='scatter', distinguish_param_setups=True) + scatterplot_matrix( + dataframe, + use_param_index, + file=output_dir + "/sp_matrix_reg.png", + kind="reg", + distinguish_param_setups=False, + ) + scatterplot_matrix( + dataframe, + use_param_index, + file=output_dir + "/sp_matrix.png", + kind="scatter", + distinguish_param_setups=False, + ) + scatterplot_matrix( + dataframe, + use_param_index, + file=output_dir + "/sp_matrix_dist_reg.png", + kind="reg", + distinguish_param_setups=True, + ) + scatterplot_matrix( + dataframe, + use_param_index, + file=output_dir + "/sp_matrix_dist.png", + kind="scatter", + distinguish_param_setups=True, + ) # radar plots - radar_plot(dataframe, file=output_dir + '/radar_dist.png', distinguish_hyperparam=True) - radar_plot(dataframe, file=output_dir + '/radar.png', distinguish_hyperparam=False) + radar_plot( + dataframe, file=output_dir + "/radar_dist.png", distinguish_hyperparam=True + ) + radar_plot(dataframe, file=output_dir + "/radar.png", distinguish_hyperparam=False) # 
scatter plots for parirs of objectives - os.makedirs(output_dir + '/scatterpl', exist_ok=True) + os.makedirs(output_dir + "/scatterpl", exist_ok=True) for i, obj_i in enumerate(obj): - for j in range(i+1, len(obj)): + for j in range(i + 1, len(obj)): try: - scatterplot(dataframe, [obj_i, obj[j]], - file=output_dir + '/scatterpl/' + obj_i + '_' + obj[j] + '.png') + scatterplot( + dataframe, + [obj_i, obj[j]], + file=output_dir + "/scatterpl/" + obj_i + "_" + obj[j] + ".png", + ) except IndexError: logger = Logger.get_logger() - logger.warning(f'disabling kde because cov matrix is singular for objectives ' - f'{obj_i} & {obj[j]}') - scatterplot(dataframe, [obj_i, obj[j]], - file=output_dir + '/scatterpl/' + obj_i + '_' + obj[j] + '.png', - kde=False) + logger.warning( + f"disabling kde because cov matrix is singular for objectives " + f"{obj_i} & {obj[j]}" + ) + scatterplot( + dataframe, + [obj_i, obj[j]], + file=output_dir + "/scatterpl/" + obj_i + "_" + obj[j] + ".png", + kde=False, + ) # create plots for the different algortihms for algorithm in dataframe[COLNAME_METHOD].unique(): - os.makedirs(output_dir + '/' + str(algorithm), exist_ok=True) + os.makedirs(output_dir + "/" + str(algorithm), exist_ok=True) dataframe_algo = dataframe[dataframe[COLNAME_METHOD] == algorithm] # boxplots for objective in obj: - boxplot(dataframe_algo, objective, - file=output_dir + '/' + str(algorithm) + '/variational_plots/' + objective) + boxplot( + dataframe_algo, + objective, + file=output_dir + + "/" + + str(algorithm) + + "/variational_plots/" + + objective, + ) # scatterplot matrices - scatterplot_matrix(dataframe_algo, use_param_index, - file=output_dir + '/' + str(algorithm) + '/sp_matrix_reg.png', - kind='reg', distinguish_param_setups=False) - scatterplot_matrix(dataframe_algo, use_param_index, - file=output_dir + '/' + str(algorithm) + '/sp_matrix.png', - kind='scatter', distinguish_param_setups=False) - scatterplot_matrix(dataframe_algo, use_param_index, - file=output_dir 
+ '/' + str(algorithm) + '/sp_matrix_dist_reg.png', - kind='reg', distinguish_param_setups=True) - scatterplot_matrix(dataframe_algo, use_param_index, - file=output_dir + '/' + str(algorithm) + '/sp_matrix_dist.png', - kind='scatter', distinguish_param_setups=True) + scatterplot_matrix( + dataframe_algo, + use_param_index, + file=output_dir + "/" + str(algorithm) + "/sp_matrix_reg.png", + kind="reg", + distinguish_param_setups=False, + ) + scatterplot_matrix( + dataframe_algo, + use_param_index, + file=output_dir + "/" + str(algorithm) + "/sp_matrix.png", + kind="scatter", + distinguish_param_setups=False, + ) + scatterplot_matrix( + dataframe_algo, + use_param_index, + file=output_dir + "/" + str(algorithm) + "/sp_matrix_dist_reg.png", + kind="reg", + distinguish_param_setups=True, + ) + scatterplot_matrix( + dataframe_algo, + use_param_index, + file=output_dir + "/" + str(algorithm) + "/sp_matrix_dist.png", + kind="scatter", + distinguish_param_setups=True, + ) # radar plots - radar_plot(dataframe_algo, file=output_dir + '/' + str(algorithm) + '/radar_dist.png', - distinguish_hyperparam=True) - radar_plot(dataframe_algo, file=output_dir + '/' + str(algorithm) + '/radar.png', - distinguish_hyperparam=False) + radar_plot( + dataframe_algo, + file=output_dir + "/" + str(algorithm) + "/radar_dist.png", + distinguish_hyperparam=True, + ) + radar_plot( + dataframe_algo, + file=output_dir + "/" + str(algorithm) + "/radar.png", + distinguish_hyperparam=False, + ) # scatter plots for parirs of objectives - os.makedirs(output_dir + '/' + str(algorithm) + '/scatterpl', exist_ok=True) + os.makedirs(output_dir + "/" + str(algorithm) + "/scatterpl", exist_ok=True) for i, obj_i in enumerate(obj): for j in range(i + 1, len(obj)): try: - scatterplot(dataframe_algo, [obj_i, obj[j]], - file=output_dir + '/' + str(algorithm) + - '/scatterpl/' + obj_i + '_' + obj[j] + '.png', - distinguish_hyperparam=True) + scatterplot( + dataframe_algo, + [obj_i, obj[j]], + file=output_dir + + "/" 
+ + str(algorithm) + + "/scatterpl/" + + obj_i + + "_" + + obj[j] + + ".png", + distinguish_hyperparam=True, + ) except IndexError: logger = Logger.get_logger() - logger.warning(f'WARNING: disabling kde because cov matrix is singular ' - f'for objectives {obj_i} & {obj[j]}') - scatterplot(dataframe_algo, [obj_i, obj[j]], - file=output_dir + '/' + str(algorithm) + - '/scatterpl/' + obj_i + '_' + obj[j] + '.png', - kde=False, - distinguish_hyperparam=True) - - -def scatterplot_matrix(dataframe_in, use_param_index, file=None, kind='reg', - distinguish_param_setups=True): - ''' + logger.warning( + f"WARNING: disabling kde because cov matrix is singular " + f"for objectives {obj_i} & {obj[j]}" + ) + scatterplot( + dataframe_algo, + [obj_i, obj[j]], + file=output_dir + + "/" + + str(algorithm) + + "/scatterpl/" + + obj_i + + "_" + + obj[j] + + ".png", + kde=False, + distinguish_hyperparam=True, + ) + + +def scatterplot_matrix( + dataframe_in, use_param_index, file=None, kind="reg", distinguish_param_setups=True +): + """ dataframe: dataframe containing the data with columns [algo, epos, te_d, seed, params, obj1, ..., obj2] file: filename to save the plots (if None, the plot will not be saved) reg: if True a regression line will be plotted over the data distinguish_param_setups: if True the plot will not only distinguish between models, but also between the parameter setups - ''' + """ dataframe = dataframe_in.copy() index = list(range(G_DF_PLOT_COL_METRIC_START, dataframe.shape[1])) if distinguish_param_setups: dataframe_ = dataframe.iloc[:, index] - dataframe_.insert(0, 'label', - dataframe[COLNAME_METHOD].astype(str) + ', ' + - dataframe[COLNAME_PARAM].astype(str)) - - g_p = sns.pairplot(data=dataframe_, hue='label', corner=True, kind=kind) + dataframe_.insert( + 0, + "label", + dataframe[COLNAME_METHOD].astype(str) + + ", " + + dataframe[COLNAME_PARAM].astype(str), + ) + + g_p = sns.pairplot(data=dataframe_, hue="label", corner=True, kind=kind) else: index_ = 
list(range(G_DF_PLOT_COL_METRIC_START, dataframe.shape[1])) index_.insert(0, G_DF_TASK_COL) @@ -202,9 +293,9 @@ def scatterplot_matrix(dataframe_in, use_param_index, file=None, kind='reg', g_p.fig.set_size_inches(12.5, 12) if use_param_index and distinguish_param_setups: - sns.move_legend(g_p, loc='upper right', bbox_to_anchor=(1., 1.), ncol=3) + sns.move_legend(g_p, loc="upper right", bbox_to_anchor=(1.0, 1.0), ncol=3) else: - sns.move_legend(g_p, loc='upper right', bbox_to_anchor=(1., 1.), ncol=1) + sns.move_legend(g_p, loc="upper right", bbox_to_anchor=(1.0, 1.0), ncol=1) plt.tight_layout() if file is not None: @@ -212,7 +303,7 @@ def scatterplot_matrix(dataframe_in, use_param_index, file=None, kind='reg', def scatterplot(dataframe_in, obj, file=None, kde=True, distinguish_hyperparam=False): - ''' + """ dataframe: dataframe containing the data with columns [algo, epos, te_d, seed, params, obj1, ..., obj2] obj1 & obj2: name of the objectives which shall be plotted against each other @@ -220,7 +311,7 @@ def scatterplot(dataframe_in, obj, file=None, kde=True, distinguish_hyperparam=F kde: if True the distribution of the points will be estimated and plotted as kde plot distinguish_param_setups: if True the plot will not only distinguish between models, but also between the parameter setups - ''' + """ obj1, obj2 = obj dataframe = dataframe_in.copy() @@ -228,70 +319,127 @@ def scatterplot(dataframe_in, obj, file=None, kde=True, distinguish_hyperparam=F if distinguish_hyperparam: if kde: - g_p = sns.jointplot(data=dataframe, x=obj1, y=obj2, hue=COLNAME_PARAM, - xlim=(-0.1, 1.1), ylim=(-0.1, 1.1), kind='kde', - zorder=0, levels=8, alpha=0.35, warn_singular=False) - gg_p = sns.scatterplot(data=dataframe, x=obj1, y=obj2, hue=COLNAME_PARAM, - ax=g_p.ax_joint) + g_p = sns.jointplot( + data=dataframe, + x=obj1, + y=obj2, + hue=COLNAME_PARAM, + xlim=(-0.1, 1.1), + ylim=(-0.1, 1.1), + kind="kde", + zorder=0, + levels=8, + alpha=0.35, + warn_singular=False, + ) + gg_p = 
sns.scatterplot( + data=dataframe, x=obj1, y=obj2, hue=COLNAME_PARAM, ax=g_p.ax_joint + ) else: - g_p = sns.jointplot(data=dataframe, x=obj1, y=obj2, hue=COLNAME_PARAM, - xlim=(-0.1, 1.1), ylim=(-0.1, 1.1)) + g_p = sns.jointplot( + data=dataframe, + x=obj1, + y=obj2, + hue=COLNAME_PARAM, + xlim=(-0.1, 1.1), + ylim=(-0.1, 1.1), + ) gg_p = g_p.ax_joint else: if kde: - g_p = sns.jointplot(data=dataframe, x=obj1, y=obj2, hue=COLNAME_METHOD, - xlim=(-0.1, 1.1), ylim=(-0.1, 1.1), kind='kde', - zorder=0, levels=8, alpha=0.35, warn_singular=False) - gg_p = sns.scatterplot(data=dataframe, x=obj1, y=obj2, hue=COLNAME_METHOD, - style=COLNAME_PARAM, - ax=g_p.ax_joint) + g_p = sns.jointplot( + data=dataframe, + x=obj1, + y=obj2, + hue=COLNAME_METHOD, + xlim=(-0.1, 1.1), + ylim=(-0.1, 1.1), + kind="kde", + zorder=0, + levels=8, + alpha=0.35, + warn_singular=False, + ) + gg_p = sns.scatterplot( + data=dataframe, + x=obj1, + y=obj2, + hue=COLNAME_METHOD, + style=COLNAME_PARAM, + ax=g_p.ax_joint, + ) else: - g_p = sns.jointplot(data=dataframe, x=obj1, y=obj2, hue=COLNAME_METHOD, - xlim=(-0.1, 1.1), ylim=(-0.1, 1.1)) - gg_p = sns.scatterplot(data=dataframe, x=obj1, y=obj2, style=COLNAME_PARAM, - ax=g_p.ax_joint) - - gg_p.set_aspect('equal') - gg_p.legend(fontsize=6, loc='best') + g_p = sns.jointplot( + data=dataframe, + x=obj1, + y=obj2, + hue=COLNAME_METHOD, + xlim=(-0.1, 1.1), + ylim=(-0.1, 1.1), + ) + gg_p = sns.scatterplot( + data=dataframe, x=obj1, y=obj2, style=COLNAME_PARAM, ax=g_p.ax_joint + ) + + gg_p.set_aspect("equal") + gg_p.legend(fontsize=6, loc="best") if file is not None: plt.savefig(file, dpi=300) def max_0_x(x_arg): - ''' + """ max(0, x_arg) - ''' + """ return max(0, x_arg) def radar_plot(dataframe_in, file=None, distinguish_hyperparam=True): - ''' + """ dataframe_in: dataframe containing the data with columns [algo, epos, te_d, seed, params, obj1, ..., obj2] file: filename to save the plots (if None, the plot will not be saved) distinguish_param_setups: if True 
the plot will not only distinguish between models, but also between the parameter setups - ''' + """ dataframe = dataframe_in.copy() if distinguish_hyperparam: - dataframe.insert(0, 'label', - dataframe[COLNAME_METHOD].astype(str) + ', ' + - dataframe[COLNAME_PARAM].astype(str)) + dataframe.insert( + 0, + "label", + dataframe[COLNAME_METHOD].astype(str) + + ", " + + dataframe[COLNAME_PARAM].astype(str), + ) else: - dataframe.insert(0, 'label', dataframe[COLNAME_METHOD]) + dataframe.insert(0, "label", dataframe[COLNAME_METHOD]) # we need "G_DF_PLOT_COL_METRIC_START + 1" as we did insert the columns 'label' at index 0 index = list(range(G_DF_PLOT_COL_METRIC_START + 1, dataframe.shape[1])) - num_lines = len(dataframe['label'].unique()) - _, axis = plt.subplots(figsize=(9, 9 + (0.28 * num_lines)), subplot_kw=dict(polar=True)) + num_lines = len(dataframe["label"].unique()) + _, axis = plt.subplots( + figsize=(9, 9 + (0.28 * num_lines)), subplot_kw=dict(polar=True) + ) num = 0 # Split the circle into even parts and save the angles # so we know where to put each axis. - angles = list(np.linspace(0, 2 * np.pi, len(dataframe.columns[index]), endpoint=False)) - for algo_name in dataframe['label'].unique(): - mean = dataframe.loc[dataframe['label'] == algo_name].iloc[:, index].mean().to_list() - std = dataframe.loc[dataframe['label'] == algo_name].iloc[:, index].std().to_list() + angles = list( + np.linspace(0, 2 * np.pi, len(dataframe.columns[index]), endpoint=False) + ) + for algo_name in dataframe["label"].unique(): + mean = ( + dataframe.loc[dataframe["label"] == algo_name] + .iloc[:, index] + .mean() + .to_list() + ) + std = ( + dataframe.loc[dataframe["label"] == algo_name] + .iloc[:, index] + .std() + .to_list() + ) angles_ = angles # The plot is a circle, so we need to "complete the loop" @@ -301,15 +449,22 @@ def radar_plot(dataframe_in, file=None, distinguish_hyperparam=True): angles_ = np.array(angles_ + angles_[:1]) # Draw the outline of the data. 
- axis.plot(angles_, mean, - color=list(plt.rcParams["axes.prop_cycle"])[num]['color'], - linewidth=2, label=algo_name) + axis.plot( + angles_, + mean, + color=list(plt.rcParams["axes.prop_cycle"])[num]["color"], + linewidth=2, + label=algo_name, + ) # Fill it in. - axis.fill_between(angles_, list(map(max_0_x, mean - std)), - y2=mean + std, - color=list(plt.rcParams["axes.prop_cycle"])[num]['color'], - alpha=0.1) + axis.fill_between( + angles_, + list(map(max_0_x, mean - std)), + y2=mean + std, + color=list(plt.rcParams["axes.prop_cycle"])[num]["color"], + alpha=0.1, + ) num += 1 num = num % len(list(plt.rcParams["axes.prop_cycle"])) @@ -322,126 +477,180 @@ def radar_plot(dataframe_in, file=None, distinguish_hyperparam=True): axis.set_ylim((0, 1)) - plt.legend(loc='lower right', bbox_to_anchor=(1., 1.035), - ncol=1, fontsize=10) + plt.legend(loc="lower right", bbox_to_anchor=(1.0, 1.035), ncol=1, fontsize=10) if file is not None: plt.savefig(file, dpi=300) def boxplot(dataframe_in, obj, file=None): - ''' + """ generate the boxplots dataframe_in: dataframe containing the data with columns [param_idx, task , algo, epos, te_d, seed, params, obj1, ..., obj2] obj: objective to be considered in the plot (needs to be contained in dataframe_in) file: foldername to save the plots (if None, the plot will not be saved) - ''' + """ boxplot_stochastic(dataframe_in, obj, file=file) boxplot_systematic(dataframe_in, obj, file=file) + def boxplot_stochastic(dataframe_in, obj, file=None): - ''' + """ generate boxplot for stochastic variation dataframe_in: dataframe containing the data with columns [param_idx, task , algo, epos, te_d, seed, params, obj1, ..., obj2] obj: objective to be considered in the plot (needs to be contained in dataframe_in) file: foldername to save the plots (if None, the plot will not be saved) - ''' + """ dataframe = dataframe_in.copy() os.makedirs(file, exist_ok=True) ### stochastic variation - _, axes = plt.subplots(1, 
len(dataframe[COLNAME_METHOD].unique()), sharey=True, - figsize=(3 * len(dataframe[COLNAME_METHOD].unique()), 6)) + _, axes = plt.subplots( + 1, + len(dataframe[COLNAME_METHOD].unique()), + sharey=True, + figsize=(3 * len(dataframe[COLNAME_METHOD].unique()), 6), + ) # iterate over all algorithms for num, algo in enumerate(list(dataframe[COLNAME_METHOD].unique())): # distinguish if the algorithm does only have one param setup or multiple if len(dataframe[COLNAME_METHOD].unique()) > 1: # generate boxplot and swarmplot - sns.boxplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_IDX_PARAM, y=obj, - ax=axes[num], showfliers=False, - boxprops={"facecolor": (.4, .6, .8, .5)}) - sns.swarmplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_IDX_PARAM, y=obj, - legend=False, ax=axes[num]) + sns.boxplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_IDX_PARAM, + y=obj, + ax=axes[num], + showfliers=False, + boxprops={"facecolor": (0.4, 0.6, 0.8, 0.5)}, + ) + sns.swarmplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_IDX_PARAM, + y=obj, + legend=False, + ax=axes[num], + ) # remove legend, set ylim, set x-label and remove y-label axes[num].legend([], [], frameon=False) axes[num].set_ylim([-0.1, 1.1]) axes[num].set_xlabel(algo) if num != 0: - axes[num].set_ylabel('') + axes[num].set_ylabel("") else: - sns.boxplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_IDX_PARAM, y=obj, - ax=axes, showfliers=False, - boxprops={"facecolor": (.4, .6, .8, .5)}) - sns.swarmplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_IDX_PARAM, y=obj, hue=COLNAME_IDX_PARAM, - legend=False, ax=axes, - palette=sns.cubehelix_palette(n_colors=len( - dataframe[dataframe[COLNAME_METHOD] == algo] - [COLNAME_IDX_PARAM].unique()))) + sns.boxplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_IDX_PARAM, + y=obj, + ax=axes, + showfliers=False, + boxprops={"facecolor": (0.4, 0.6, 0.8, 0.5)}, + 
) + sns.swarmplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_IDX_PARAM, + y=obj, + hue=COLNAME_IDX_PARAM, + legend=False, + ax=axes, + palette=sns.cubehelix_palette( + n_colors=len( + dataframe[dataframe[COLNAME_METHOD] == algo][ + COLNAME_IDX_PARAM + ].unique() + ) + ), + ) axes.legend([], [], frameon=False) axes.set_ylim([-0.1, 1.1]) axes.set_xlabel(algo) plt.tight_layout() if file is not None: - plt.savefig(file + '/stochastic_variation.png', dpi=300) + plt.savefig(file + "/stochastic_variation.png", dpi=300) def boxplot_systematic(dataframe_in, obj, file=None): - ''' + """ generate boxplot for ssystemtic variation dataframe_in: dataframe containing the data with columns [param_idx, task , algo, epos, te_d, seed, params, obj1, ..., obj2] obj: objective to be considered in the plot (needs to be contained in dataframe_in) file: foldername to save the plots (if None, the plot will not be saved) - ''' + """ dataframe = dataframe_in.copy() os.makedirs(file, exist_ok=True) ### systematic variation - _, axes = plt.subplots(1, len(dataframe[COLNAME_METHOD].unique()), sharey=True, - figsize=(3 * len(dataframe[COLNAME_METHOD].unique()), 6)) + _, axes = plt.subplots( + 1, + len(dataframe[COLNAME_METHOD].unique()), + sharey=True, + figsize=(3 * len(dataframe[COLNAME_METHOD].unique()), 6), + ) for num, algo in enumerate(list(dataframe[COLNAME_METHOD].unique())): # distinguish if the algorithm does only have one param setup or multiple if len(dataframe[COLNAME_METHOD].unique()) > 1: # generate boxplot and swarmplot - sns.boxplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_METHOD, y=obj, - ax=axes[num], showfliers=False, - boxprops={"facecolor": (.4, .6, .8, .5)}) - sns.swarmplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_METHOD, y=obj, hue=COLNAME_IDX_PARAM, - legend=False, ax=axes[num], - palette=sns.cubehelix_palette(n_colors=len( - dataframe[dataframe[COLNAME_METHOD] == algo] - [COLNAME_IDX_PARAM].unique()))) + 
sns.boxplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_METHOD, + y=obj, + ax=axes[num], + showfliers=False, + boxprops={"facecolor": (0.4, 0.6, 0.8, 0.5)}, + ) + sns.swarmplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_METHOD, + y=obj, + hue=COLNAME_IDX_PARAM, + legend=False, + ax=axes[num], + palette=sns.cubehelix_palette( + n_colors=len( + dataframe[dataframe[COLNAME_METHOD] == algo][ + COLNAME_IDX_PARAM + ].unique() + ) + ), + ) # remove legend, set ylim, set x-label and remove y-label axes[num].legend([], [], frameon=False) axes[num].set_ylim([-0.1, 1.1]) - axes[num].set_xlabel(' ') + axes[num].set_xlabel(" ") if num != 0: - axes[num].set_ylabel('') + axes[num].set_ylabel("") else: - sns.boxplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_METHOD, y=obj, - ax=axes, showfliers=False, - boxprops={"facecolor": (.4, .6, .8, .5)}) - sns.swarmplot(data=dataframe[dataframe[COLNAME_METHOD] == algo], - x=COLNAME_METHOD, y=obj, hue=COLNAME_IDX_PARAM, - legend=False, ax=axes, - palette=sns.cubehelix_palette(n_colors=len( - dataframe[dataframe[COLNAME_METHOD] == algo] - [COLNAME_IDX_PARAM].unique()))) + sns.boxplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_METHOD, + y=obj, + ax=axes, + showfliers=False, + boxprops={"facecolor": (0.4, 0.6, 0.8, 0.5)}, + ) + sns.swarmplot( + data=dataframe[dataframe[COLNAME_METHOD] == algo], + x=COLNAME_METHOD, + y=obj, + hue=COLNAME_IDX_PARAM, + legend=False, + ax=axes, + palette=sns.cubehelix_palette( + n_colors=len( + dataframe[dataframe[COLNAME_METHOD] == algo][ + COLNAME_IDX_PARAM + ].unique() + ) + ), + ) axes.legend([], [], frameon=False) axes.set_ylim([-0.1, 1.1]) - axes.set_xlabel(' ') + axes.set_xlabel(" ") plt.tight_layout() if file is not None: - plt.savefig(file + '/systematic_variation.png', dpi=300) - + plt.savefig(file + "/systematic_variation.png", dpi=300) diff --git a/domainlab/utils/get_git_tag.py b/domainlab/utils/get_git_tag.py index 
3a7c0204f..977dda1f2 100644 --- a/domainlab/utils/get_git_tag.py +++ b/domainlab/utils/get_git_tag.py @@ -9,8 +9,7 @@ def get_git_tag(print_diff=False): flag_not_commited = False logger = Logger.get_logger() try: - subprocess.check_output( - ['git', 'diff-index', '--quiet', 'HEAD']) + subprocess.check_output(["git", "diff-index", "--quiet", "HEAD"]) except CalledProcessError: logger.warning("\n\n") logger.warning("!!!: not committed yet") @@ -18,7 +17,7 @@ def get_git_tag(print_diff=False): flag_not_commited = True logger.warning("\n\n") try: - diff_byte = subprocess.check_output(['git', 'diff']) + diff_byte = subprocess.check_output(["git", "diff"]) if print_diff: logger.info(str(diff_byte)) # print is currently ugly, do not use! except Exception: @@ -26,8 +25,7 @@ def get_git_tag(print_diff=False): logger.warning("not in a git repository") warnings.warn("not in a git repository") try: - tag_byte = subprocess.check_output( - ["git", "describe", "--always"]).strip() + tag_byte = subprocess.check_output(["git", "describe", "--always"]).strip() logger.info(str(tag_byte)) tag_str = str(tag_byte) git_str = tag_str.replace("'", "") diff --git a/domainlab/utils/hyperparameter_gridsearch.py b/domainlab/utils/hyperparameter_gridsearch.py index c83c1868f..5ca2ef264 100644 --- a/domainlab/utils/hyperparameter_gridsearch.py +++ b/domainlab/utils/hyperparameter_gridsearch.py @@ -1,17 +1,18 @@ -''' +""" gridsearch for the hyperparameter space def add_next_param_from_list is an recursive function to make cartesian product along all the scalar hyper-parameters, this resursive function is used in def grid_task -''' +""" import copy -import os import json +import os import warnings import numpy as np import pandas as pd + import domainlab.utils.hyperparameter_sampling as sampling from domainlab.utils.get_git_tag import get_git_tag from domainlab.utils.logger import Logger @@ -19,20 +20,24 @@ def add_next_param_from_list is an recursive function to make cartesian product G_MODEL_NA 
= "model" G_METHOD_NA = "method" + def round_to_discreate_grid_uniform(grid, param_config): - ''' + """ round the values of the grid to the grid spacing specified in the config for uniform and loguniform grids - ''' - if float(param_config['step']) == 0: + """ + if float(param_config["step"]) == 0: return grid - mini = float(param_config['min']) - maxi = float(param_config['max']) - if maxi - mini < float(param_config['step']): - raise RuntimeError('distance between max and min to small for defined step size') - - discreate_gird = np.arange(mini, maxi + float(param_config['step']), - step=float(param_config['step'])) + mini = float(param_config["min"]) + maxi = float(param_config["max"]) + if maxi - mini < float(param_config["step"]): + raise RuntimeError( + "distance between max and min to small for defined step size" + ) + + discreate_gird = np.arange( + mini, maxi + float(param_config["step"]), step=float(param_config["step"]) + ) for num, elem in enumerate(list(grid)): # search for the closest allowed grid point to the scalar elem grid[num] = discreate_gird[(np.abs(discreate_gird - elem)).argmin()] @@ -40,77 +45,83 @@ def round_to_discreate_grid_uniform(grid, param_config): grid_out = grid_unique return grid_out + def round_to_discreate_grid_normal(grid, param_config): - ''' + """ round the values of the grid to the grid spacing specified in the config for normal and lognormal grids - ''' - if float(param_config['step']) == 0: + """ + if float(param_config["step"]) == 0: return grid # for normal and lognormal no min and max is provided # in this case the grid is constructed around the mean - neg_steps = np.ceil((float(param_config['mean']) - np.min(grid)) / - float(param_config['step'])) - pos_steps = np.ceil((np.max(grid) - float(param_config['mean'])) / - float(param_config['step'])) - mini = float(param_config['mean']) - float(param_config['step']) * neg_steps - maxi = float(param_config['mean']) + float(param_config['step']) * pos_steps - - discreate_gird = 
np.arange(mini, maxi, step=float(param_config['step'])) + neg_steps = np.ceil( + (float(param_config["mean"]) - np.min(grid)) / float(param_config["step"]) + ) + pos_steps = np.ceil( + (np.max(grid) - float(param_config["mean"])) / float(param_config["step"]) + ) + mini = float(param_config["mean"]) - float(param_config["step"]) * neg_steps + maxi = float(param_config["mean"]) + float(param_config["step"]) * pos_steps + + discreate_gird = np.arange(mini, maxi, step=float(param_config["step"])) for num, elem in enumerate(list(grid)): grid[num] = discreate_gird[(np.abs(discreate_gird - elem)).argmin()] return np.unique(grid) + def uniform_grid(param_config): - ''' + """ get a uniform distributed grid given the specifications in the param_config param_config: config which needs to contain 'num', 'max', 'min', 'step' - ''' - num = int(param_config['num']) - maxi = float(param_config['max']) - mini = float(param_config['min']) + """ + num = int(param_config["num"]) + maxi = float(param_config["max"]) + mini = float(param_config["min"]) step = (maxi - mini) / num # linspace does include the end of the interval and include the beginning # we move away from mini and maxi to sample inside the open interval (mini, maxi) grid = np.linspace(mini + step / 2, maxi - step / 2, num) - if 'step' in param_config.keys(): - return round_to_discreate_grid_uniform( - grid, param_config) + if "step" in param_config.keys(): + return round_to_discreate_grid_uniform(grid, param_config) return grid + def loguniform_grid(param_config): - ''' + """ get a loguniform distributed grid given the specifications in the param_config param_config: config which needs to contain 'num', 'max', 'min' - ''' - num = int(param_config['num']) - maxi = np.log10(float(param_config['max'])) - mini = np.log10(float(param_config['min'])) + """ + num = int(param_config["num"]) + maxi = np.log10(float(param_config["max"])) + mini = np.log10(float(param_config["min"])) step = (maxi - mini) / num # linspace does 
exclude the end of the interval and include the beginning grid = 10 ** np.linspace(mini + step / 2, maxi - step / 2, num) - if 'step' in param_config.keys(): + if "step" in param_config.keys(): return round_to_discreate_grid_uniform(grid, param_config) return grid + def normal_grid(param_config, lognormal=False): - ''' + """ get a normal distributed grid given the specifications in the param_config param_config: config which needs to contain 'num', 'mean', 'std' - ''' - if int(param_config['num']) == 1: - return np.array([float(param_config['mean'])]) + """ + if int(param_config["num"]) == 1: + return np.array([float(param_config["mean"])]) # Box–Muller transform to get from a uniform distribution to a normal distribution - num = int(np.floor(int(param_config['num']) / 2)) - step = 2 / (int(param_config['num']) + 1) + num = int(np.floor(int(param_config["num"]) / 2)) + step = 2 / (int(param_config["num"]) + 1) # for a even number of samples - if int(param_config['num']) % 2 == 0: + if int(param_config["num"]) % 2 == 0: param_grid = np.arange(step, 1, step=step)[:num] stnormal_grid = np.sqrt(-2 * np.log(param_grid)) stnormal_grid = np.append(stnormal_grid, -stnormal_grid) stnormal_grid = stnormal_grid / np.std(stnormal_grid) - stnormal_grid = float(param_config['std']) * stnormal_grid + \ - float(param_config['mean']) + stnormal_grid = float(param_config["std"]) * stnormal_grid + float( + param_config["mean"] + ) # for a odd number of samples else: param_grid = np.arange(step, 1, step=step)[:num] @@ -118,26 +129,28 @@ def normal_grid(param_config, lognormal=False): stnormal_grid = np.append(stnormal_grid, -stnormal_grid) stnormal_grid = np.append(stnormal_grid, 0) stnormal_grid = stnormal_grid / np.std(stnormal_grid) - stnormal_grid = float(param_config['std']) * stnormal_grid + \ - float(param_config['mean']) + stnormal_grid = float(param_config["std"]) * stnormal_grid + float( + param_config["mean"] + ) - if 'step' in param_config.keys() and lognormal is False: + 
if "step" in param_config.keys() and lognormal is False: return round_to_discreate_grid_normal(stnormal_grid, param_config) return stnormal_grid + def lognormal_grid(param_config): - ''' + """ get a normal distributed grid given the specifications in the param_config param_config: config which needs to contain 'num', 'mean', 'std' - ''' + """ grid = 10 ** normal_grid(param_config, lognormal=True) - if 'step' in param_config.keys(): + if "step" in param_config.keys(): return round_to_discreate_grid_normal(grid, param_config) return grid -def add_next_param_from_list(param_grid: dict, grid: dict, - grid_df: pd.DataFrame): - ''' + +def add_next_param_from_list(param_grid: dict, grid: dict, grid_df: pd.DataFrame): + """ can be used in a recoursive fassion to add all combinations of the parameters in param_grid to grid_df param_grid: dictionary with all possible values for each parameter @@ -147,7 +160,7 @@ def add_next_param_from_list(param_grid: dict, grid: dict, grid_df: dataframe which will save the finished grids task_name: task name also: G_MODEL_NA name - ''' + """ if len(param_grid.keys()) != 0: # specify the parameter to be used param_name = list(param_grid.keys())[0] @@ -165,13 +178,15 @@ def add_next_param_from_list(param_grid: dict, grid: dict, # add sample to grid_df grid_df.loc[len(grid_df.index)] = [grid] -def add_references_and_check_constraints(grid_df_prior, grid_df, referenced_params, - config, task_name): - ''' + +def add_references_and_check_constraints( + grid_df_prior, grid_df, referenced_params, config, task_name +): + """ in the last step all parameters which are referenced need to be add to the grid. All gridpoints not satisfying the constraints are removed afterwards. 
- ''' - for dictio in grid_df_prior['params']: + """ + for dictio in grid_df_prior["params"]: for key, val in dictio.items(): exec(f"{key} = val") # add referenced params @@ -180,154 +195,174 @@ def add_references_and_check_constraints(grid_df_prior, grid_df, referenced_para dictio.update({rev_param: val}) exec(f"{rev_param} = val") # check constraints - if 'hyperparameters' in config.keys(): - constraints = config['hyperparameters'].get('constraints', None) + if "hyperparameters" in config.keys(): + constraints = config["hyperparameters"].get("constraints", None) else: - constraints = config.get('constraints', None) + constraints = config.get("constraints", None) if constraints is not None: accepted = True for constr in constraints: if not eval(constr): accepted = False if accepted: - grid_df.loc[len(grid_df.index)] = [task_name, config['model'], dictio] + grid_df.loc[len(grid_df.index)] = [task_name, config["model"], dictio] else: - grid_df.loc[len(grid_df.index)] = [task_name, config['model'], dictio] + grid_df.loc[len(grid_df.index)] = [task_name, config["model"], dictio] + def sample_grid(param_config): - ''' + """ given the parameter config, this function samples all parameters which are distributed according the the categorical, uniform, loguniform, normal or lognormal distribution. 
- ''' + """ # sample cathegorical parameter - if param_config['distribution'] == 'categorical': - param_grid = sampling.CategoricalHyperparameter('', param_config).allowed_values + if param_config["distribution"] == "categorical": + param_grid = sampling.CategoricalHyperparameter("", param_config).allowed_values # sample uniform parameter - elif param_config['distribution'] == 'uniform': + elif param_config["distribution"] == "uniform": param_grid = uniform_grid(param_config) # sample loguniform parameter - elif param_config['distribution'] == 'loguniform': + elif param_config["distribution"] == "loguniform": param_grid = loguniform_grid(param_config) # sample normal parameter - elif param_config['distribution'] == 'normal': + elif param_config["distribution"] == "normal": param_grid = normal_grid(param_config) # sample lognormal parameter - elif param_config['distribution'] == 'lognormal': + elif param_config["distribution"] == "lognormal": param_grid = lognormal_grid(param_config) else: - raise RuntimeError(f'distribution \"{param_config["distribution"]}\" not ' - f'implemented use a distribution from ' - f'[categorical, uniform, loguniform, normal, lognormal]') + raise RuntimeError( + f'distribution "{param_config["distribution"]}" not ' + f"implemented use a distribution from " + f"[categorical, uniform, loguniform, normal, lognormal]" + ) # ensure that the gird does have the correct datatype # (only check for int, othervise float is used) - if 'datatype' in param_config.keys(): - if param_config['datatype'] == 'int': + if "datatype" in param_config.keys(): + if param_config["datatype"] == "int": param_grid = np.array(param_grid) param_grid = param_grid.astype(int) # NOTE: converting int to float will cause error for VAE, avoid do # it here return param_grid + def build_param_grid_of_shared_params(shared_df): - ''' + """ go back from the data frame format of the shared hyperparamters to a list format - ''' + """ if shared_df is None: return None shared_grid = 
{} - for key in shared_df['params'].iloc[0].keys(): + for key in shared_df["params"].iloc[0].keys(): grid_points = [] - for i in shared_df['params'].keys(): - grid_points.append(shared_df['params'][i][key]) + for i in shared_df["params"].keys(): + grid_points.append(shared_df["params"][i][key]) shared_grid[key] = np.array(grid_points) return shared_grid + def rais_error_if_num_not_specified(param_name: str, param_config: dict): - ''' + """ for each parameter a number of grid points needs to be specified This function raises an error if this is not the case param_name: parameter name under consideration param_config: config of this parameter - ''' - if not param_name == 'constraints': - if not 'num' in param_config.keys() \ - and not 'reference' in param_config.keys() \ - and not param_config['distribution'] == 'categorical': - raise RuntimeError(f"the number of parameters in the grid direction " - f"of {param_name} needs to be specified") + """ + if not param_name == "constraints": + if ( + not "num" in param_config.keys() + and not "reference" in param_config.keys() + and not param_config["distribution"] == "categorical" + ): + raise RuntimeError( + f"the number of parameters in the grid direction " + f"of {param_name} needs to be specified" + ) + def add_shared_params_to_param_grids(shared_df, dict_param_grids, config): - ''' + """ use the parameters in the dataframe of shared parameters and add them to the dictionary of parameters for the current task only the shared parameters specified in the config are respected shared_df: Dataframe of shared hyperparameters dict_param_grids: dictionary of the parameter grids config: config for the current task - ''' + """ dict_shared_grid = build_param_grid_of_shared_params(shared_df) - if 'shared' in config.keys(): - list_names = config['shared'] - dict_shared_grid = {key: dict_shared_grid[key] for key in config['shared']} + if "shared" in config.keys(): + list_names = config["shared"] + dict_shared_grid = {key: 
dict_shared_grid[key] for key in config["shared"]} if dict_shared_grid is not None: for key in dict_shared_grid.keys(): dict_param_grids[key] = dict_shared_grid[key] return dict_param_grids -def grid_task(grid_df: pd.DataFrame, task_name: str, config: dict, shared_df: pd.DataFrame): + +def grid_task( + grid_df: pd.DataFrame, task_name: str, config: dict, shared_df: pd.DataFrame +): """create grid for one sampling task for a method and add it to the dataframe""" - if 'hyperparameters' in config.keys(): + if "hyperparameters" in config.keys(): dict_param_grids = {} referenced_params = {} - for param_name in config['hyperparameters'].keys(): - param_config = config['hyperparameters'][param_name] + for param_name in config["hyperparameters"].keys(): + param_config = config["hyperparameters"][param_name] rais_error_if_num_not_specified(param_name, param_config) # constraints are not parameters - if not param_name == 'constraints': + if not param_name == "constraints": # remember all parameters which are reverenced - if 'datatype' not in param_config.keys(): - warnings.warn(f"datatype not specified in {param_config} \ - for {param_name}, take float as default") - param_config['datatype'] = 'float' - - if 'reference' in param_config.keys(): - referenced_params.update({param_name: param_config['reference']}) + if "datatype" not in param_config.keys(): + warnings.warn( + f"datatype not specified in {param_config} \ + for {param_name}, take float as default" + ) + param_config["datatype"] = "float" + + if "reference" in param_config.keys(): + referenced_params.update({param_name: param_config["reference"]}) # sample other parameter - elif param_name != 'constraints': + elif param_name != "constraints": dict_param_grids.update({param_name: sample_grid(param_config)}) # create the grid from the individual parameter grids # constraints are not respected in this step - grid_df_prior = pd.DataFrame(columns=['params']) + grid_df_prior = pd.DataFrame(columns=["params"]) # add 
shared parameters to dict_param_grids dict_param_grids = add_shared_params_to_param_grids( - shared_df, dict_param_grids, config) + shared_df, dict_param_grids, config + ) add_next_param_from_list(dict_param_grids, {}, grid_df_prior) # add referenced params and check constraints - add_references_and_check_constraints(grid_df_prior, grid_df, referenced_params, - config, task_name) - if grid_df[grid_df[G_MODEL_NA] == config['model']].shape[0] == 0: - raise RuntimeError('No valid value found for this grid spacing, refine grid') + add_references_and_check_constraints( + grid_df_prior, grid_df, referenced_params, config, task_name + ) + if grid_df[grid_df[G_MODEL_NA] == config["model"]].shape[0] == 0: + raise RuntimeError( + "No valid value found for this grid spacing, refine grid" + ) return grid_df - elif 'shared' in config.keys(): + elif "shared" in config.keys(): shared_grid = shared_df.copy() - shared_grid[G_MODEL_NA] = config['model'] + shared_grid[G_MODEL_NA] = config["model"] shared_grid[G_METHOD_NA] = task_name - if 'constraints' in config.keys(): - config['hyperparameters'] = {'constraints': config['constraints']} - add_references_and_check_constraints(shared_grid, grid_df, {}, config, task_name) + if "constraints" in config.keys(): + config["hyperparameters"] = {"constraints": config["constraints"]} + add_references_and_check_constraints( + shared_grid, grid_df, {}, config, task_name + ) return grid_df else: # add single line if no varying hyperparameters are specified. - grid_df.loc[len(grid_df.index)] = [task_name, config['model'], {}] + grid_df.loc[len(grid_df.index)] = [task_name, config["model"], {}] return grid_df -def sample_gridsearch(config: dict, - dest: str = None) -> pd.DataFrame: +def sample_gridsearch(config: dict, dest: str = None) -> pd.DataFrame: """ create the hyperparameters grid according to the given config, which should be the dictionary of the full @@ -339,52 +374,60 @@ def sample_gridsearch(config: dict, only with trusted config files. 
""" if dest is None: - dest = config['output_dir'] + os.sep + 'hyperparameters.csv' + dest = config["output_dir"] + os.sep + "hyperparameters.csv" logger = Logger.get_logger() - samples = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, 'params']) - shared_samples_full = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, 'params']) + samples = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, "params"]) + shared_samples_full = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, "params"]) - if 'Shared params' in config.keys(): - shared_val = {'model': 'all', 'hyperparameters': config['Shared params']} + if "Shared params" in config.keys(): + shared_val = {"model": "all", "hyperparameters": config["Shared params"]} # fill up the dataframe shared samples - shared_samples_full = grid_task(shared_samples_full, 'all', shared_val, None) + shared_samples_full = grid_task(shared_samples_full, "all", shared_val, None) else: shared_samples_full = None for key, val in config.items(): if sampling.is_dict_with_key(val, "model"): if shared_samples_full is not None: shared_samples = shared_samples_full.copy(deep=True) - if 'shared' in val.keys(): - shared = val['shared'] + if "shared" in val.keys(): + shared = val["shared"] else: shared = [] for line_num in range(shared_samples.shape[0]): - hyper_p_dict = shared_samples.iloc[line_num]['params'].copy() + hyper_p_dict = shared_samples.iloc[line_num]["params"].copy() key_list = copy.deepcopy(list(hyper_p_dict.keys())) if not all(x in key_list for x in shared): - raise RuntimeError(f"shared keys: {shared} not included in global shared keys {key_list}") + raise RuntimeError( + f"shared keys: {shared} not included in global shared keys {key_list}" + ) for key_ in key_list: if key_ not in shared: del hyper_p_dict[key_] - shared_samples.iloc[line_num]['params'] = hyper_p_dict + shared_samples.iloc[line_num]["params"] = hyper_p_dict # remove all duplicates - shared_samples = shared_samples.drop_duplicates(subset='params') + shared_samples = 
shared_samples.drop_duplicates(subset="params") else: shared_samples = None samples = grid_task(samples, key, val, shared_samples) - logger.info(f'number of gridpoints for {key} : ' - f'{samples[samples[G_MODEL_NA] == val["model"]].shape[0]}') + logger.info( + f"number of gridpoints for {key} : " + f'{samples[samples[G_MODEL_NA] == val["model"]].shape[0]}' + ) os.makedirs(os.path.dirname(dest), exist_ok=True) - logger.info(f'number of total sampled gridpoints: {samples.shape[0]}') + logger.info(f"number of total sampled gridpoints: {samples.shape[0]}") samples.to_csv(dest) - # create a txt file with the commit information - with open(config["output_dir"] + os.sep + 'commit.txt', 'w', encoding="utf8") as file: + # create a txt file with the commit information + with open( + config["output_dir"] + os.sep + "commit.txt", "w", encoding="utf8" + ) as file: file.writelines("use git log |grep \n") file.writelines("consider remove leading b in the line below \n") file.write(get_git_tag()) - with open(config["output_dir"] + os.sep + 'config.txt', 'w', encoding="utf8") as file: + with open( + config["output_dir"] + os.sep + "config.txt", "w", encoding="utf8" + ) as file: json.dump(config, file) return samples diff --git a/domainlab/utils/hyperparameter_sampling.py b/domainlab/utils/hyperparameter_sampling.py index cc1bae5ce..14c1f42e4 100644 --- a/domainlab/utils/hyperparameter_sampling.py +++ b/domainlab/utils/hyperparameter_sampling.py @@ -7,17 +7,17 @@ # Functions to sample hyper-parameters and log into csv file """ import copy -import os import json +import os +from ast import literal_eval # literal_eval can safe evaluate python expression from pydoc import locate from typing import List -from ast import literal_eval # literal_eval can safe evaluate python expression import numpy as np import pandas as pd -from domainlab.utils.logger import Logger from domainlab.utils.get_git_tag import get_git_tag +from domainlab.utils.logger import Logger G_MODEL_NA = "model" 
G_METHOD_NA = "method" @@ -33,6 +33,7 @@ class Hyperparameter: p2: max or scale reference: None or name of referenced hyperparameter """ + def __init__(self, name: str): self.name = name self.val = 0 @@ -61,19 +62,22 @@ class SampledHyperparameter(Hyperparameter): """ A numeric hyperparameter that shall be sampled """ + def __init__(self, name: str, config: dict): super().__init__(name) - self.step = config.get('step', 0) + self.step = config.get("step", 0) try: - self.distribution = config['distribution'] - if self.distribution in {'uniform', 'loguniform'}: - self.p_1 = config['min'] - self.p_2 = config['max'] - elif self.distribution in {'normal', 'lognormal'}: - self.p_1 = config['mean'] - self.p_2 = config['std'] + self.distribution = config["distribution"] + if self.distribution in {"uniform", "loguniform"}: + self.p_1 = config["min"] + self.p_2 = config["max"] + elif self.distribution in {"normal", "lognormal"}: + self.p_1 = config["mean"] + self.p_2 = config["std"] else: - raise RuntimeError(f"Unsupported distribution type: {self.distribution}.") + raise RuntimeError( + f"Unsupported distribution type: {self.distribution}." + ) except KeyError as ex: raise RuntimeError(f"Missing required key for parameter {name}.") from ex @@ -83,7 +87,7 @@ def __init__(self, name: str, config: dict): def _ensure_step(self): """Make sure that the hyperparameter sticks to the discrete grid""" if self.step == 0: - return # continous parameter + return # continous parameter # round to next discrete value. 
# p_1 is the lower bound of the hyper-parameter range, p_2 the upper bound @@ -101,13 +105,13 @@ def _ensure_step(self): def sample(self): """Sample this parameter, respecting properties""" - if self.distribution == 'uniform': + if self.distribution == "uniform": self.val = np.random.uniform(self.p_1, self.p_2) - elif self.distribution == 'loguniform': + elif self.distribution == "loguniform": self.val = 10 ** np.random.uniform(np.log10(self.p_1), np.log10(self.p_2)) - elif self.distribution == 'normal': + elif self.distribution == "normal": self.val = np.random.normal(self.p_1, self.p_2) - elif self.distribution == 'lognormal': + elif self.distribution == "lognormal": self.val = 10 ** np.random.normal(self.p_1, self.p_2) else: raise RuntimeError(f"Unsupported distribution type: {self.distribution}.") @@ -122,9 +126,10 @@ class ReferenceHyperparameter(Hyperparameter): Hyperparameter that references only a different one. Thus, this parameter is not sampled but set after sampling. """ + def __init__(self, name: str, config: dict): super().__init__(name) - self.reference = config.get('reference', None) + self.reference = config.get("reference", None) def _ensure_step(self): """Make sure that the hyperparameter sticks to the discrete grid""" @@ -145,12 +150,15 @@ class CategoricalHyperparameter(Hyperparameter): A sampled hyperparameter, which is constraint to fixed, user given values and datatype """ + def __init__(self, name: str, config: dict): super().__init__(name) - self.allowed_values = config['values'] - if 'datatype' not in config: - raise RuntimeError("Please specifiy datatype for all categorical hyper-parameters!, e.g. datatype=str") - self.type = locate(config['datatype']) + self.allowed_values = config["values"] + if "datatype" not in config: + raise RuntimeError( + "Please specifiy datatype for all categorical hyper-parameters!, e.g. 
datatype=str" + ) + self.type = locate(config["datatype"]) self.allowed_values = [self.type(v) for v in self.allowed_values] def _ensure_step(self): @@ -170,10 +178,10 @@ def datatype(self): def get_hyperparameter(name: str, config: dict) -> Hyperparameter: """Factory function. Instantiates the correct Hyperparameter""" - if 'reference' in config.keys(): + if "reference" in config.keys(): return ReferenceHyperparameter(name, config) - dist = config.get('distribution', None) - if dist == 'categorical': + dist = config.get("distribution", None) + if dist == "categorical": return CategoricalHyperparameter(name, config) return SampledHyperparameter(name, config) @@ -189,7 +197,7 @@ def check_constraints(params: List[Hyperparameter], constraints) -> bool: for par in params: if isinstance(par, ReferenceHyperparameter): try: - setattr(par, 'val', eval(par.reference)) + setattr(par, "val", eval(par.reference)) # NOTE: literal_eval will cause ValueError: malformed node or string except Exception as ex: logger = Logger.get_logger() @@ -198,7 +206,7 @@ def check_constraints(params: List[Hyperparameter], constraints) -> bool: locals().update({par.name: par.val}) if constraints is None: - return True # shortcut + return True # shortcut # check all constraints for constr in constraints: @@ -213,8 +221,12 @@ def check_constraints(params: List[Hyperparameter], constraints) -> bool: return True -def sample_parameters(init_params: List[Hyperparameter], constraints, - shared_config=None, shared_samples=None) -> dict: +def sample_parameters( + init_params: List[Hyperparameter], + constraints, + shared_config=None, + shared_samples=None, +) -> dict: """ Tries to sample from the hyperparameter list. 
@@ -228,7 +240,7 @@ def sample_parameters(init_params: List[Hyperparameter], constraints, # add a random hyperparameter from the shared hyperparameter dataframe if shared_samples is not None: # sample one line from the pandas dataframe - shared_samp = shared_samples.sample(1).iloc[0]['params'] + shared_samp = shared_samples.sample(1).iloc[0]["params"] for key in shared_samp.keys(): par = Hyperparameter(key) par.val = shared_samp[key] @@ -245,13 +257,15 @@ def sample_parameters(init_params: List[Hyperparameter], constraints, # this may be due to the shared hyperparameters. # If so, new samples are generated for the shared hyperparameters logger = Logger.get_logger() - logger.warning("The constrainds coundn't be met with the shared Hyperparameters, " - "shared dataframe pool will be ignored for now.") + logger.warning( + "The constrainds coundn't be met with the shared Hyperparameters, " + "shared dataframe pool will be ignored for now." + ) for _ in range(10_000): params = copy.deepcopy(init_params) # add the shared hyperparameter as a sampled hyperparameter if shared_samples is not None: - shared_samp = shared_samples.sample(1).iloc[0]['params'] + shared_samp = shared_samples.sample(1).iloc[0]["params"] for key in shared_samp.keys(): par = SampledHyperparameter(key, shared_config[key]) par.sample() @@ -266,39 +280,44 @@ def sample_parameters(init_params: List[Hyperparameter], constraints, samples[par.name] = par.val return samples - raise RuntimeError("Could not find an acceptable sample in 10,000 runs." - "Are the bounds and constraints reasonable?") + raise RuntimeError( + "Could not find an acceptable sample in 10,000 runs." + "Are the bounds and constraints reasonable?" 
+ ) -def create_samples_from_shared_samples(shared_samples: pd.DataFrame, - config: dict, - task_name: str): - ''' +def create_samples_from_shared_samples( + shared_samples: pd.DataFrame, config: dict, task_name: str +): + """ add informations like task, G_MODEL_NA and constrainds to the shared samples Parameters: shared_samples: pd Dataframe with columns [G_METHOD_NA, G_MODEL_NA, 'params'] config: dataframe with yaml configuration of the current task task_name: name of the current task - ''' + """ shared_samp = shared_samples.copy() - shared_samp[G_MODEL_NA] = config['model'] + shared_samp[G_MODEL_NA] = config["model"] shared_samp[G_METHOD_NA] = task_name # respect the constraints if specified in the task - if 'constraints' in config.keys(): + if "constraints" in config.keys(): for idx in range(shared_samp.shape[0] - 1, -1, -1): - name = list(shared_samp['params'].iloc[idx].keys())[0] - value = shared_samp['params'].iloc[idx][name] + name = list(shared_samp["params"].iloc[idx].keys())[0] + value = shared_samp["params"].iloc[idx][name] par = Hyperparameter(name) par.val = value - if not check_constraints([par], config['constraints']): + if not check_constraints([par], config["constraints"]): shared_samp = shared_samp.drop(idx) return shared_samp -def sample_task_only_shared(num_samples, task_name, sample_df, config, shared_conf_samp): - ''' + +def sample_task_only_shared( + num_samples, task_name, sample_df, config, shared_conf_samp +): + """ sample one task and add it to the dataframe for task descriptions which only contain shared hyperparameters - ''' + """ shared_config, shared_samples = shared_conf_samp # copy the shared samples dataframe and add the corrct G_MODEL_NA and taks names shared_samp = create_samples_from_shared_samples(shared_samples, config, task_name) @@ -310,16 +329,20 @@ def sample_task_only_shared(num_samples, task_name, sample_df, config, shared_co s_config = shared_config.copy() s_dict = {} for keys in s_config.keys(): - if keys != 
'num_shared_param_samples': + if keys != "num_shared_param_samples": s_dict[keys] = s_config[keys] - if 'constraints' in config.keys(): - s_dict['constraints'] = config['constraints'] - s_config['model'] = config['model'] - s_config['hyperparameters'] = s_dict + if "constraints" in config.keys(): + s_dict["constraints"] = config["constraints"] + s_config["model"] = config["model"] + s_config["hyperparameters"] = s_dict # sample new shared hyperparameters - sample_df = sample_task(num_samples - shared_samp.shape[0], - task_name, (s_config, sample_df), (None, None)) + sample_df = sample_task( + num_samples - shared_samp.shape[0], + task_name, + (s_config, sample_df), + (None, None), + ) # add previously sampled shared hyperparameters sample_df = sample_df.append(shared_samp, ignore_index=True) # for the case that the number of shared samples is >= the expected number of @@ -330,32 +353,35 @@ def sample_task_only_shared(num_samples, task_name, sample_df, config, shared_co return sample_df -def sample_task(num_samples: int, - task_name: str, - conf_samp: tuple, - shared_conf_samp: tuple): + +def sample_task( + num_samples: int, task_name: str, conf_samp: tuple, shared_conf_samp: tuple +): """Sample one task and add it to the dataframe""" config, sample_df = conf_samp shared_config, shared_samples = shared_conf_samp - if 'hyperparameters' in config.keys(): + if "hyperparameters" in config.keys(): # in benchmark configuration file, sub-section hyperparameters # means changing hyper-parameters params = [] - for key, val in config['hyperparameters'].items(): - if key in ('constraints', 'num_shared_param_samples'): + for key, val in config["hyperparameters"].items(): + if key in ("constraints", "num_shared_param_samples"): continue params += [get_hyperparameter(key, val)] - constraints = config['hyperparameters'].get('constraints', None) + constraints = config["hyperparameters"].get("constraints", None) for _ in range(num_samples): - sample = sample_parameters(params, 
constraints, shared_config, shared_samples) - sample_df.loc[len(sample_df.index)] = [task_name, config['model'], sample] - elif 'shared' in config.keys(): - sample_df = sample_task_only_shared(num_samples, task_name, sample_df, - config, (shared_config, shared_samples)) + sample = sample_parameters( + params, constraints, shared_config, shared_samples + ) + sample_df.loc[len(sample_df.index)] = [task_name, config["model"], sample] + elif "shared" in config.keys(): + sample_df = sample_task_only_shared( + num_samples, task_name, sample_df, config, (shared_config, shared_samples) + ) else: # add single line if no varying hyperparameters are specified. - sample_df.loc[len(sample_df.index)] = [task_name, config['model'], {}] + sample_df.loc[len(sample_df.index)] = [task_name, config["model"], {}] return sample_df @@ -363,41 +389,43 @@ def is_dict_with_key(input_dict, key) -> bool: """Determines if the input argument is a dictionary and it has key""" return isinstance(input_dict, dict) and key in input_dict.keys() -def get_shared_samples(shared_samples_full: pd.DataFrame, - shared_config_full: dict, - task_config: dict): - ''' + +def get_shared_samples( + shared_samples_full: pd.DataFrame, shared_config_full: dict, task_config: dict +): + """ - creates a dataframe with columns [task, G_MODEL_NA, params], task and G_MODEL_NA are all for all rows, but params is filled with the shared parameters of shared_samples_full requested by task_config. 
- creates a shared config containing only information about the shared hyperparameters requested by the task_config - ''' + """ shared_samples = shared_samples_full.copy(deep=True) shared_config = shared_config_full.copy() - if 'shared' in task_config.keys(): - shared = task_config['shared'] + if "shared" in task_config.keys(): + shared = task_config["shared"] else: shared = [] for line_num in range(shared_samples.shape[0]): - hyper_p_dict = shared_samples.iloc[line_num]['params'].copy() + hyper_p_dict = shared_samples.iloc[line_num]["params"].copy() key_list = copy.deepcopy(list(hyper_p_dict.keys())) for key_ in key_list: if key_ not in shared: del hyper_p_dict[key_] - shared_samples.iloc[line_num]['params'] = hyper_p_dict + shared_samples.iloc[line_num]["params"] = hyper_p_dict for key_ in key_list: - if not key_ == 'num_shared_param_samples': + if not key_ == "num_shared_param_samples": if key_ not in shared: del shared_config[key_] # remove all duplicates - shared_samples = shared_samples.drop_duplicates(subset='params') + shared_samples = shared_samples.drop_duplicates(subset="params") return shared_samples, shared_config -def sample_hyperparameters(config: dict, - dest: str = None, - sampling_seed: int = None) -> pd.DataFrame: + +def sample_hyperparameters( + config: dict, dest: str = None, sampling_seed: int = None +) -> pd.DataFrame: """ Samples the hyperparameters according to the given config, which should be the dictionary of the full @@ -409,42 +437,52 @@ def sample_hyperparameters(config: dict, only with trusted config files. 
""" if dest is None: - dest = config['output_dir'] + os.sep + 'hyperparameters.csv' + dest = config["output_dir"] + os.sep + "hyperparameters.csv" if sampling_seed is not None: np.random.seed(sampling_seed) - num_samples = config['num_param_samples'] - samples = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, 'params']) - if 'Shared params' in config.keys(): - shared_config_full = config['Shared params'] - shared_samples_full = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, 'params']) - shared_val = {'model': 'all', 'hyperparameters': config['Shared params']} + num_samples = config["num_param_samples"] + samples = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, "params"]) + if "Shared params" in config.keys(): + shared_config_full = config["Shared params"] + shared_samples_full = pd.DataFrame(columns=[G_METHOD_NA, G_MODEL_NA, "params"]) + shared_val = {"model": "all", "hyperparameters": config["Shared params"]} # fill up the dataframe shared samples - shared_samples_full = sample_task(shared_config_full['num_shared_param_samples'], - 'all', (shared_val, shared_samples_full), (None, None)) + shared_samples_full = sample_task( + shared_config_full["num_shared_param_samples"], + "all", + (shared_val, shared_samples_full), + (None, None), + ) else: shared_samples_full = None for key, val in config.items(): if is_dict_with_key(val, "model"): if shared_samples_full is not None: shared_samples, shared_config = get_shared_samples( - shared_samples_full, shared_config_full, val) + shared_samples_full, shared_config_full, val + ) else: shared_config = None shared_samples = None - samples = sample_task(num_samples, key, (val, samples), - (shared_config, shared_samples)) + samples = sample_task( + num_samples, key, (val, samples), (shared_config, shared_samples) + ) os.makedirs(os.path.dirname(dest), exist_ok=True) # create a txt file with the commit information - with open(config["output_dir"] + os.sep + 'commit.txt', 'w', encoding="utf8") as file: + with open( + 
config["output_dir"] + os.sep + "commit.txt", "w", encoding="utf8" + ) as file: file.writelines("use git log |grep \n") file.writelines("consider remove leading b in the line below \n") file.write(get_git_tag()) - with open(config["output_dir"] + os.sep + 'config.txt', 'w', encoding="utf8") as file: + with open( + config["output_dir"] + os.sep + "config.txt", "w", encoding="utf8" + ) as file: json.dump(config, file) samples.to_csv(dest) diff --git a/domainlab/utils/logger.py b/domainlab/utils/logger.py index d04283e5c..31260e796 100644 --- a/domainlab/utils/logger.py +++ b/domainlab/utils/logger.py @@ -1,36 +1,41 @@ -''' +""" A logger for our software -''' -import os +""" import logging import multiprocessing +import os class Logger: - ''' + """ static logger class - ''' + """ + logger = None @staticmethod - def get_logger(logger_name='logger_' + str(multiprocessing.current_process().pid), loglevel='INFO'): - ''' + def get_logger( + logger_name="logger_" + str(multiprocessing.current_process().pid), + loglevel="INFO", + ): + """ returns a logger if no logger was created yet, it will create a logger with the name specified in logger_name with the level specified in loglevel. 
If the logger was created for the first time the arguments do not change anything at the behaviour anymore - ''' + """ if Logger.logger is None: Logger.logger = logging.getLogger(logger_name) Logger.logger.setLevel(loglevel) # Create handlers and set their logging level - logfolder = 'zoutput/logs' + logfolder = "zoutput/logs" os.makedirs(logfolder, exist_ok=True) # Create handlers and set their logging level - filehandler = logging.FileHandler(logfolder + '/' + Logger.logger.name + '.log', - mode='w') + filehandler = logging.FileHandler( + logfolder + "/" + Logger.logger.name + ".log", mode="w" + ) filehandler.setLevel(loglevel) console_handler = logging.StreamHandler() diff --git a/domainlab/utils/override_interface.py b/domainlab/utils/override_interface.py index 9eaef4c83..8c0c8025e 100644 --- a/domainlab/utils/override_interface.py +++ b/domainlab/utils/override_interface.py @@ -13,6 +13,7 @@ class Child(BaseClass): def fun(self): pass """ + def overrider(method2override): """overrider. 
diff --git a/domainlab/utils/perf.py b/domainlab/utils/perf.py index 218003bdf..05c77cfd2 100644 --- a/domainlab/utils/perf.py +++ b/domainlab/utils/perf.py @@ -2,13 +2,15 @@ import torch -class PerfClassif(): +class PerfClassif: """Classification Performance""" + @classmethod def gen_fun_acc(cls, dim_target): """ :param dim_target: class/domain label embeding dimension """ + def fun_acc(list_vec_preds, list_vec_labels): """ :param list_vec_preds: list of batches @@ -17,7 +19,9 @@ def fun_acc(list_vec_preds, list_vec_labels): correct_count = 0 obs_count = 0 for pred, label in zip(list_vec_preds, list_vec_labels): - correct_count += torch.sum(torch.sum(pred == label, dim=1) == dim_target) + correct_count += torch.sum( + torch.sum(pred == label, dim=1) == dim_target + ) obs_count += pred.shape[0] # batch size if isinstance(correct_count, int): acc = (correct_count) / obs_count @@ -26,6 +30,7 @@ def fun_acc(list_vec_preds, list_vec_labels): # AttributeError: 'int' object has no attribute 'float' # reason: batchsize is too big return acc + return fun_acc @classmethod @@ -39,7 +44,8 @@ def cal_acc(cls, model, loader_te, device): model_local = model.to(device) fun_acc = cls.gen_fun_acc(model_local.dim_y) list_vec_preds, list_vec_labels = cls.get_list_pred_target( - model_local, loader_te, device) + model_local, loader_te, device + ) accuracy_y = fun_acc(list_vec_preds, list_vec_labels) acc_y = accuracy_y.cpu().numpy().item() return acc_y diff --git a/domainlab/utils/perf_metrics.py b/domainlab/utils/perf_metrics.py index 7f5b69bba..932d1f9ab 100644 --- a/domainlab/utils/perf_metrics.py +++ b/domainlab/utils/perf_metrics.py @@ -1,26 +1,39 @@ """Classification Performance""" import numpy as np import torch -from torchmetrics.classification import (AUC, AUROC, Accuracy, ConfusionMatrix, - F1Score, Precision, Recall, - Specificity) +from torchmetrics.classification import ( + AUC, + AUROC, + Accuracy, + ConfusionMatrix, + F1Score, + Precision, + Recall, + Specificity, +) 
-class PerfMetricClassif(): +class PerfMetricClassif: """Classification Performance metrics""" - def __init__(self, num_classes, agg_precision_recall_f1='macro'): + + def __init__(self, num_classes, agg_precision_recall_f1="macro"): super().__init__() - self.acc = Accuracy(num_classes=num_classes, average='micro') + self.acc = Accuracy(num_classes=num_classes, average="micro") # NOTE: only micro aggregation make sense for acc - self.precision = Precision(num_classes=num_classes, average=agg_precision_recall_f1) + self.precision = Precision( + num_classes=num_classes, average=agg_precision_recall_f1 + ) # Calculate the metric for each class separately, and average the # metrics across classes (with equal weights for each class). self.recall = Recall(num_classes=num_classes, average=agg_precision_recall_f1) - self.f1_score = F1Score(num_classes=num_classes, average=agg_precision_recall_f1) + self.f1_score = F1Score( + num_classes=num_classes, average=agg_precision_recall_f1 + ) # NOTE: auroc does nto support "micro" as aggregation self.auroc = AUROC(num_classes=num_classes, average=agg_precision_recall_f1) - self.specificity = Specificity(num_classes=num_classes, - average=agg_precision_recall_f1) + self.specificity = Specificity( + num_classes=num_classes, average=agg_precision_recall_f1 + ) self.confmat = ConfusionMatrix(num_classes=num_classes) def cal_metrics(self, model, loader_te, device): @@ -68,13 +81,15 @@ def cal_metrics(self, model, loader_te, device): f1_score_y = self.f1_score.compute() auroc_y = self.auroc.compute() confmat_y = self.confmat.compute() - dict_metric = {"acc": acc_y, - "precision": precision_y, - "recall": recall_y, - "specificity": specificity_y, - "f1": f1_score_y, - "auroc": auroc_y, - "confmat": confmat_y} + dict_metric = { + "acc": acc_y, + "precision": precision_y, + "recall": recall_y, + "specificity": specificity_y, + "f1": f1_score_y, + "auroc": auroc_y, + "confmat": confmat_y, + } keys = list(dict_metric) keys.remove("confmat") 
for key in keys: diff --git a/domainlab/utils/sanity_check.py b/domainlab/utils/sanity_check.py index 82732f40e..6994c496d 100644 --- a/domainlab/utils/sanity_check.py +++ b/domainlab/utils/sanity_check.py @@ -1,22 +1,24 @@ -''' +""" This class is used to perform the sanity check on a task description -''' +""" import datetime import os import shutil + import numpy as np -from torch.utils.data import Subset import torch.utils.data as data_utils +from torch.utils.data import Subset from domainlab.dsets.utils_data import plot_ds -class SanityCheck(): +class SanityCheck: """ Performs a sanity check on the given args and the task when running dataset_sanity_check(self) """ + def __init__(self, args, task): self.args = args self.task = task @@ -29,13 +31,16 @@ def dataset_sanity_check(self): """ # self.task.init_business(self.args) - list_domain_tr, list_domain_te = self.task.get_list_domains_tr_te(self.args.tr_d, - self.args.te_d) - + list_domain_tr, list_domain_te = self.task.get_list_domains_tr_te( + self.args.tr_d, self.args.te_d + ) time_stamp = datetime.datetime.now() - f_name = os.path.join(self.args.out, 'Dset_extraction', - self.task.task_name + ' ' + str(time_stamp)) + f_name = os.path.join( + self.args.out, + "Dset_extraction", + self.task.task_name + " " + str(time_stamp), + ) # remove previous sanity checks with the same name shutil.rmtree(f_name, ignore_errors=True) @@ -45,7 +50,7 @@ def dataset_sanity_check(self): d_dataset = self.task.dict_dset_tr[domain] else: d_dataset = self.task.get_dset_by_domain(self.args, domain)[0] - folder_name = f_name + '/train_domain/' + str(domain) + folder_name = f_name + "/train_domain/" + str(domain) self.save_san_check_for_domain(self.args.san_num, folder_name, d_dataset) # for each testing domain do... 
@@ -54,22 +59,23 @@ def dataset_sanity_check(self): d_dataset = self.task.dict_dset_te[domain] else: d_dataset = self.task.get_dset_by_domain(self.args, domain)[0] - folder_name = f_name + '/test_domain/' + str(domain) + folder_name = f_name + "/test_domain/" + str(domain) self.save_san_check_for_domain(self.args.san_num, folder_name, d_dataset) - def save_san_check_for_domain(self, sample_num, folder_name, d_dataset): - ''' + """ saves a extraction of the dataset (d_dataset) into folder (folder_name) sample_num: int, number of images which are extracted from the dataset folder_name: string, destination for the saved images d_dataset: dataset - ''' + """ # for each class do... for class_num in range(len(self.task.list_str_y)): num_of_samples = 0 - loader_domain = data_utils.DataLoader(d_dataset, batch_size=1, shuffle=False) + loader_domain = data_utils.DataLoader( + d_dataset, batch_size=1, shuffle=False + ) domain_targets = [] for num, (_, lab, *_) in enumerate(loader_domain): if int(np.argmax(lab[0])) == class_num: @@ -82,7 +88,6 @@ def save_san_check_for_domain(self, sample_num, folder_name, d_dataset): os.makedirs(folder_name, exist_ok=True) plot_ds( class_dataset, - folder_name + '/' + - str(self.task.list_str_y[class_num]) + '.jpg', - batchsize=sample_num + folder_name + "/" + str(self.task.list_str_y[class_num]) + ".jpg", + batchsize=sample_num, ) diff --git a/domainlab/utils/test_img.py b/domainlab/utils/test_img.py index 16902a46c..f401aa337 100644 --- a/domainlab/utils/test_img.py +++ b/domainlab/utils/test_img.py @@ -5,18 +5,22 @@ def mk_img(i_h, i_ch=3, batch_size=5): img = torch.rand(i_h, i_h) # uniform distribution [0,1] # x = torch.clamp(x, 0, 1) img.unsqueeze_(0) - img = img.repeat(i_ch, 1, 1) # RGB image + img = img.repeat(i_ch, 1, 1) # RGB image img.unsqueeze_(0) img = img.repeat(batch_size, 1, 1, 1) return img + def mk_rand_label_onehot(target_dim=10, batch_size=5): - label_scalar = torch.randint(high=target_dim, size=(batch_size, )) + 
label_scalar = torch.randint(high=target_dim, size=(batch_size,)) label_scalar2 = label_scalar.unsqueeze(1) label_zeros = torch.zeros(batch_size, target_dim) - label_onehot = torch.scatter(input=label_zeros, dim=1, index=label_scalar2, value=1.0) + label_onehot = torch.scatter( + input=label_zeros, dim=1, index=label_scalar2, value=1.0 + ) return label_onehot + def mk_rand_xyd(ims, y_dim, d_dim, batch_size): imgs = mk_img(i_h=ims, batch_size=batch_size) ys = mk_rand_label_onehot(target_dim=y_dim, batch_size=batch_size) diff --git a/domainlab/utils/u_import_net_module.py b/domainlab/utils/u_import_net_module.py index 825ed1315..14df0fcc5 100644 --- a/domainlab/utils/u_import_net_module.py +++ b/domainlab/utils/u_import_net_module.py @@ -2,13 +2,12 @@ import external neural network implementation """ -from domainlab.utils.u_import import import_path from domainlab.utils.logger import Logger +from domainlab.utils.u_import import import_path -def build_external_obj_net_module_feat_extract(mpath, dim_y, - remove_last_layer): - """ The user provide a function to initiate an object of the neural network, +def build_external_obj_net_module_feat_extract(mpath, dim_y, remove_last_layer): + """The user provide a function to initiate an object of the neural network, which is fine for training but problematic for persistence of the trained model since it is created externally. 
:param mpath: path of external python file where the neural network @@ -29,20 +28,25 @@ def build_external_obj_net_module_feat_extract(mpath, dim_y, name_signature = "build_feat_extract_net(dim_y, \ remove_last_layer)" # @FIXME: hard coded, move to top level __init__ definition in domainlab - name_fun = name_signature[:name_signature.index("(")] + name_fun = name_signature[: name_signature.index("(")] if hasattr(net_module, name_fun): try: net = getattr(net_module, name_fun)(dim_y, remove_last_layer) except Exception: logger = Logger.get_logger() - logger.error(f"function {name_signature} should return a neural network " - f"(pytorch module) that that extract features from an image") + logger.error( + f"function {name_signature} should return a neural network " + f"(pytorch module) that that extract features from an image" + ) raise if net is None: - raise RuntimeError("the pytorch module returned by %s is None" - % (name_signature)) + raise RuntimeError( + "the pytorch module returned by %s is None" % (name_signature) + ) else: - raise RuntimeError("Please implement a function %s \ + raise RuntimeError( + "Please implement a function %s \ in your external python file" - % (name_signature)) + % (name_signature) + ) return net diff --git a/domainlab/utils/utils_class.py b/domainlab/utils/utils_class.py index 96963689d..3e0e9895e 100644 --- a/domainlab/utils/utils_class.py +++ b/domainlab/utils/utils_class.py @@ -3,13 +3,11 @@ def store_args(method): - """Stores provided method args as instance attributes. 
- """ + """Stores provided method args as instance attributes.""" argspec = inspect.getfullargspec(method) defaults = {} if argspec.defaults is not None: - defaults = dict( - zip(argspec.args[-len(argspec.defaults):], argspec.defaults)) + defaults = dict(zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)) if argspec.kwonlydefaults is not None: defaults.update(argspec.kwonlydefaults) arg_names = argspec.args[1:] diff --git a/domainlab/utils/utils_classif.py b/domainlab/utils/utils_classif.py index 51f858dec..2948f3289 100644 --- a/domainlab/utils/utils_classif.py +++ b/domainlab/utils/utils_classif.py @@ -26,12 +26,15 @@ def logit2preds_vpic(logit): one_hot = one_hot.scatter_(dim=1, index=max_ind, value=1.0) return one_hot, mat_prob, max_ind, max_prob + def get_label_na(tensor_ind, list_str_na): """ given list of label names in strings, map tensor of index to label names """ arr_ind_np = tensor_ind.cpu().numpy() - arr_ind = np.squeeze(arr_ind_np, axis=1) # explicitly use axis=1 to deal with edge case of only + arr_ind = np.squeeze( + arr_ind_np, axis=1 + ) # explicitly use axis=1 to deal with edge case of only # instance left # list_ind = list(arr_ind): if there is only dimension 1 tensor_ind, then there is a problem list_ind = arr_ind.tolist() diff --git a/domainlab/utils/utils_cuda.py b/domainlab/utils/utils_cuda.py index e068c6052..bf72776a3 100644 --- a/domainlab/utils/utils_cuda.py +++ b/domainlab/utils/utils_cuda.py @@ -2,6 +2,7 @@ choose devices """ import torch + from domainlab.utils.logger import Logger @@ -14,7 +15,7 @@ def get_device(args): if args.device is None: device = torch.device("cuda" if flag_cuda else "cpu") else: - device = torch.device("cuda:"+args.device if flag_cuda else "cpu") + device = torch.device("cuda:" + args.device if flag_cuda else "cpu") logger = Logger.get_logger() logger.info("") logger.info(f"using device: {str(device)}") diff --git a/domainlab/utils/utils_img_sav.py b/domainlab/utils/utils_img_sav.py index 
8042b6bf7..d5ae50c4f 100644 --- a/domainlab/utils/utils_img_sav.py +++ b/domainlab/utils/utils_img_sav.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt from torchvision.utils import make_grid, save_image + from domainlab.utils.logger import Logger @@ -24,6 +25,7 @@ def my_sav_img(comparison_tensor_stack, name, title=None): else: img_grid = make_grid(tensor=tensor, nrow=nrow) sav_add_title(img_grid, path=f_p, title="hi") + return my_sav_img @@ -31,7 +33,7 @@ def sav_add_title(grid_img, path, title): """ add title and save image as matplotlib.pyplot """ - fig = plt.gcf() # get current figure + fig = plt.gcf() # get current figure plt.imshow(grid_img.permute(1, 2, 0)) plt.title(title) fig.savefig(path) diff --git a/domainlab/zdata/mixed_codec/caltech/auto/text.txt b/domainlab/zdata/mixed_codec/caltech/auto/text.txt index 5e1c309da..557db03de 100644 --- a/domainlab/zdata/mixed_codec/caltech/auto/text.txt +++ b/domainlab/zdata/mixed_codec/caltech/auto/text.txt @@ -1 +1 @@ -Hello World \ No newline at end of file +Hello World diff --git a/domainlab/zdata/script/download_pacs.py b/domainlab/zdata/script/download_pacs.py index b05c7a4de..51c346f24 100644 --- a/domainlab/zdata/script/download_pacs.py +++ b/domainlab/zdata/script/download_pacs.py @@ -1,14 +1,16 @@ -'this script can be used to download the pacs dataset' +"this script can be used to download the pacs dataset" import os import tarfile from zipfile import ZipFile + import gdown + def stage_path(data_dir, name): - ''' + """ creates the path to data_dir/name if it does not exist already - ''' + """ full_path = os.path.join(data_dir, name) if not os.path.exists(full_path): @@ -16,11 +18,12 @@ def stage_path(data_dir, name): return full_path + def download_and_extract(url, dst, remove=True): - ''' + """ downloads and extracts the data behind the url and saves it at dst - ''' + """ gdown.download(url, dst, quiet=False) if dst.endswith(".tar.gz"): @@ -43,17 +46,19 @@ def download_and_extract(url, dst, remove=True): 
def download_pacs(data_dir): - ''' + """ download and extract dataset pacs. Dataset is saved at location data_dir - ''' + """ full_path = stage_path(data_dir, "PACS") - download_and_extract("https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd", - os.path.join(data_dir, "PACS.zip")) + download_and_extract( + "https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd", + os.path.join(data_dir, "PACS.zip"), + ) + + os.rename(os.path.join(data_dir, "kfold"), full_path) - os.rename(os.path.join(data_dir, "kfold"), - full_path) -if __name__ == '__main__': - download_pacs('../pacs') +if __name__ == "__main__": + download_pacs("../pacs") diff --git a/domainlab/zdata/ztest_files/dummy_file.py b/domainlab/zdata/ztest_files/dummy_file.py index ee817a687..e0c7faa27 100644 --- a/domainlab/zdata/ztest_files/dummy_file.py +++ b/domainlab/zdata/ztest_files/dummy_file.py @@ -1,4 +1,4 @@ -''' +""" I am a dummy file used in tests/test_git_tag.py to produce a file which is not commited -''' +""" diff --git a/examples/api/jigen_dann_transformer.py b/examples/api/jigen_dann_transformer.py index d06aa0dbf..63bba2c9a 100644 --- a/examples/api/jigen_dann_transformer.py +++ b/examples/api/jigen_dann_transformer.py @@ -7,18 +7,19 @@ from torchvision.models.feature_extraction import create_feature_extractor from domainlab.mk_exp import mk_exp -from domainlab.tasks import get_task from domainlab.models.model_dann import mk_dann from domainlab.models.model_jigen import mk_jigen +from domainlab.tasks import get_task class VIT(nn.Module): """ Vision transformer as feature extractor """ - def __init__(self, freeze=True, - list_str_last_layer=['getitem_5'], - len_last_layer=768): + + def __init__( + self, freeze=True, list_str_last_layer=["getitem_5"], len_last_layer=768 + ): super().__init__() self.nets = vit_b_16(pretrained=True) if freeze: @@ -27,15 +28,15 @@ def __init__(self, freeze=True, # in case of enough computation resources for param in self.nets.parameters(): 
param.requires_grad = False - self.features_vit_flatten = \ - create_feature_extractor(self.nets, - return_nodes=list_str_last_layer) + self.features_vit_flatten = create_feature_extractor( + self.nets, return_nodes=list_str_last_layer + ) def forward(self, tensor_x): """ compute logits predicts """ - out = self.features_vit_flatten(tensor_x)['getitem_5'] + out = self.features_vit_flatten(tensor_x)["getitem_5"] return out @@ -51,26 +52,32 @@ def test_transformer(): net_classifier = nn.Linear(768, task.dim_y) # see documentation for each arguments below - model_dann = mk_dann()(net_encoder=net_feature, - net_classifier=net_classifier, - net_discriminator=nn.Linear(768, 2), - list_str_y=task.list_str_y, - list_d_tr=["labelme", "sun"], - alpha=1.0) + model_dann = mk_dann()( + net_encoder=net_feature, + net_classifier=net_classifier, + net_discriminator=nn.Linear(768, 2), + list_str_y=task.list_str_y, + list_d_tr=["labelme", "sun"], + alpha=1.0, + ) # see documentation for each argument below - model_jigen = mk_jigen()(net_encoder=net_feature, - net_classifier_class=net_classifier, - net_classifier_permutation=nn.Linear(768, 32), - list_str_y=task.list_str_y, - coeff_reg=1.0, n_perm=31) + model_jigen = mk_jigen()( + net_encoder=net_feature, + net_classifier_class=net_classifier, + net_classifier_permutation=nn.Linear(768, 32), + list_str_y=task.list_str_y, + coeff_reg=1.0, + n_perm=31, + ) model_jigen.extend(model_dann) # let Jigen decorate DANN model = model_jigen # make trainer for model, here we decorate trainer mldg with dial - exp = mk_exp(task, model, trainer="mldg_dial", - test_domain="caltech", batchsize=2, nocu=True) + exp = mk_exp( + task, model, trainer="mldg_dial", test_domain="caltech", batchsize=2, nocu=True + ) exp.execute(num_epochs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_transformer() diff --git a/examples/benchmark/benchmark_blood_resnet.yaml b/examples/benchmark/benchmark_blood_resnet.yaml index ac1cb7213..b10fd9364 100644 --- 
a/examples/benchmark/benchmark_blood_resnet.yaml +++ b/examples/benchmark/benchmark_blood_resnet.yaml @@ -55,7 +55,7 @@ Shared params: zd_dim: reference: zy_dim - + gamma_reg: min: 0.01 max: 10 diff --git a/examples/benchmark/benchmark_pacs_resnet_trainer.yaml b/examples/benchmark/benchmark_pacs_resnet_trainer.yaml index ea6cfe056..d11dffef6 100644 --- a/examples/benchmark/benchmark_pacs_resnet_trainer.yaml +++ b/examples/benchmark/benchmark_pacs_resnet_trainer.yaml @@ -69,10 +69,10 @@ matchdg: # name max: 20 step: 1 distribution: uniform - + # The only transformation for JiGen allowed is normalization and image resize, no random flip, as the original code shows: # https://github.com/fmcarlucci/JigenDG/blob/master/data/JigsawLoader.py -# adding random flip here will cause jigen to confuse with the image tile reshuffling. +# adding random flip here will cause jigen to confuse with the image tile reshuffling. dial: # name model: erm diff --git a/examples/conf/vlcs_diva_mldg_dial.yaml b/examples/conf/vlcs_diva_mldg_dial.yaml index f5cdba9bc..0701766ff 100644 --- a/examples/conf/vlcs_diva_mldg_dial.yaml +++ b/examples/conf/vlcs_diva_mldg_dial.yaml @@ -1,5 +1,5 @@ te_d: caltech # domain name of test domain -tpath: examples/tasks/task_vlcs.py # python file path to specify the task +tpath: examples/tasks/task_vlcs.py # python file path to specify the task bs: 2 # batch size model: dann_diva # combine model DANN with DIVA epos: 1 # number of epochs diff --git a/examples/models/demo_custom_model.py b/examples/models/demo_custom_model.py index 515d9457b..83fc4e8e7 100644 --- a/examples/models/demo_custom_model.py +++ b/examples/models/demo_custom_model.py @@ -4,14 +4,15 @@ import torch from torch.nn import functional as F -from domainlab.models.model_custom import AModelCustom from domainlab.algos.builder_custom import make_basic_trainer +from domainlab.models.model_custom import AModelCustom class ModelCustom(AModelCustom): """ Template class to inherit from if user need 
custom neural network """ + @property def dict_net_module_na2arg_na(self): """ diff --git a/examples/nets/resnet.py b/examples/nets/resnet.py index 2dd5074f9..2afb3e119 100644 --- a/examples/nets/resnet.py +++ b/examples/nets/resnet.py @@ -10,6 +10,7 @@ class ResNetBase(NetTorchVisionBase): """ Since ResNet can be fetched from torchvision """ + def fetch_net(self, flag_pretrain): """fetch_net. @@ -17,10 +18,10 @@ def fetch_net(self, flag_pretrain): """ if flag_pretrain: self.net_torchvision = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + weights=ResNet50_Weights.IMAGENET1K_V2 + ) else: - self.net_torchvision = torchvisionmodels.resnet.resnet50( - weights='None') + self.net_torchvision = torchvisionmodels.resnet.resnet50(weights="None") # CHANGEME: user can modify this line to choose other neural # network architectures from 'torchvision.models' @@ -29,6 +30,7 @@ class ResNet4DeepAll(ResNetBase): """ change the size of the last layer """ + def __init__(self, flag_pretrain, dim_y): """__init__. diff --git a/examples/nets/resnet50domainbed.py b/examples/nets/resnet50domainbed.py index d54c98f27..3428b8b58 100644 --- a/examples/nets/resnet50domainbed.py +++ b/examples/nets/resnet50domainbed.py @@ -1,7 +1,7 @@ -''' +""" resnet50 modified as described in https://arxiv.org/pdf/2007.01434.pdf appendix D -''' +""" from torch import nn from torchvision import models as torchvisionmodels from torchvision.models import ResNet50_Weights @@ -11,7 +11,7 @@ class CostumResNet(nn.Module): - ''' + """ this costum resnet includes the modification described in https://arxiv.org/pdf/2007.01434.pdf appendix D @@ -20,21 +20,22 @@ class CostumResNet(nn.Module): generalization algorithms (as different minibatches follow different distributions), we freeze all batch normalization layers before fine-tuning. We insert a dropout layer before the final linear layer. 
- ''' + """ + def __init__(self, flag_pretrain): super().__init__() self.flag_pretrain = flag_pretrain if flag_pretrain: resnet50 = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + weights=ResNet50_Weights.IMAGENET1K_V2 + ) else: - resnet50 = torchvisionmodels.resnet.resnet50( - weights='None') + resnet50 = torchvisionmodels.resnet.resnet50(weights="None") # freez all batchnormalisation layers for module in resnet50.modules(): - if module._get_name() == 'BatchNorm2d': + if module._get_name() == "BatchNorm2d": module.requires_grad_(False) self.resnet50_first_part = nn.Sequential(*(list(resnet50.children())[:-1])) @@ -55,6 +56,7 @@ class ResNetBase(NetTorchVisionBase): """ Since ResNet can be fetched from torchvision """ + def fetch_net(self, flag_pretrain): """fetch_net. @@ -69,6 +71,7 @@ class ResNet4DeepAll(ResNetBase): """ change the size of the last layer """ + def __init__(self, flag_pretrain, dim_y): """__init__. diff --git a/examples/nets/resnetDassl.py b/examples/nets/resnetDassl.py index 71e4af546..9006a305f 100644 --- a/examples/nets/resnetDassl.py +++ b/examples/nets/resnetDassl.py @@ -10,25 +10,25 @@ class ResNetBaseDassl(NetTorchVisionBase): """ Since ResNet can be fetched from torchvision """ + def fetch_net(self, flag_pretrain): """fetch_net. :param flag_pretrain: """ - self.net_torchvision = torchvisionmodels.resnet.resnet50( - weights=None) - weights = 'examples/nets/resnet50-19c8e357Dassl.pth' + self.net_torchvision = torchvisionmodels.resnet.resnet50(weights=None) + weights = "examples/nets/resnet50-19c8e357Dassl.pth" # CHANGEME: user can modify this line to choose other neural # network architectures from 'torchvision.models' if flag_pretrain: self.net_torchvision.load_state_dict(torch.load(weights)) - class ResNet4DeepAllDassl(ResNetBaseDassl): """ change the size of the last layer """ + def __init__(self, flag_pretrain, dim_y): """__init__. 
diff --git a/examples/nets/transformer.py b/examples/nets/transformer.py index 557d842c9..6e417f32d 100644 --- a/examples/nets/transformer.py +++ b/examples/nets/transformer.py @@ -7,29 +7,34 @@ from torchvision.models.feature_extraction import create_feature_extractor from domainlab.mk_exp import mk_exp -from domainlab.tasks import get_task from domainlab.models.model_erm import mk_erm +from domainlab.tasks import get_task class VIT(nn.Module): - def __init__(self, num_cls, freeze=True, - list_str_last_layer=['getitem_5'], - len_last_layer=768): + def __init__( + self, + num_cls, + freeze=True, + list_str_last_layer=["getitem_5"], + len_last_layer=768, + ): super().__init__() self.nets = vit_b_16(pretrained=True) if freeze: # freeze all the network except the final layer for param in self.nets.parameters(): param.requires_grad = False - self.features_vit_flatten = create_feature_extractor(self.nets, - return_nodes=list_str_last_layer) + self.features_vit_flatten = create_feature_extractor( + self.nets, return_nodes=list_str_last_layer + ) self.fc = nn.Linear(len_last_layer, num_cls) def forward(self, tensor_x): """ compute logits predicts """ - x = self.features_vit_flatten(tensor_x)['getitem_5'] + x = self.features_vit_flatten(tensor_x)["getitem_5"] out = self.fc(x) return out @@ -44,10 +49,11 @@ def test_transformer(): backbone = VIT(num_cls=task.dim_y, freeze=True) model = mk_erm()(backbone) # make trainer for model - exp = mk_exp(task, model, trainer="mldg,dial", - test_domain="caltech", batchsize=2, nocu=True) + exp = mk_exp( + task, model, trainer="mldg,dial", test_domain="caltech", batchsize=2, nocu=True + ) exp.execute(num_epochs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_transformer() diff --git a/examples/tasks/demo_task_path_list_small.py b/examples/tasks/demo_task_path_list_small.py index 9811777ae..92425774a 100644 --- a/examples/tasks/demo_task_path_list_small.py +++ b/examples/tasks/demo_task_path_list_small.py @@ -2,6 +2,7 @@ 
Toy example on how to use TaskPathList, by subsample a small portion of PACS data """ from torchvision import transforms + from domainlab.tasks.task_pathlist import mk_node_task_path_list from domainlab.tasks.utils_task import ImSize @@ -15,8 +16,7 @@ def get_task(na=None): # ## specify image size, must be consistent with the transformation isize=ImSize(3, 224, 224), # ## specify the names for all classes to classify - list_str_y=["dog", "elephant", "giraffe", "guitar", - "horse", "house", "person"], + list_str_y=["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"], # ## give an index to each target class dict_class_label_ind2name={ "1": "dog", @@ -25,44 +25,44 @@ def get_task(na=None): "4": "guitar", "5": "horse", "6": "house", - "7": "person"}, - + "7": "person", + }, # ## specify where to find the text file containing path for each image # ## the text file below only need to specify the relative path - # ## training dict_d2filepath_list_img_tr={ "art_painting": "data/pacs_split/art_painting_10.txt", "cartoon": "data/pacs_split/cartoon_10.txt", "photo": "data/pacs_split/photo_10.txt", - "sketch": "data/pacs_split/sketch_10.txt"}, - + "sketch": "data/pacs_split/sketch_10.txt", + }, # ## testing dict_d2filepath_list_img_te={ "art_painting": "data/pacs_split/art_painting_10.txt", "cartoon": "data/pacs_split/cartoon_10.txt", "photo": "data/pacs_split/photo_10.txt", - "sketch": "data/pacs_split/sketch_10.txt"}, - + "sketch": "data/pacs_split/sketch_10.txt", + }, # ## validation dict_d2filepath_list_img_val={ "art_painting": "data/pacs_split/art_painting_10.txt", "cartoon": "data/pacs_split/cartoon_10.txt", "photo": "data/pacs_split/photo_10.txt", - "sketch": "data/pacs_split/sketch_10.txt"}, - + "sketch": "data/pacs_split/sketch_10.txt", + }, # ## specify root folder storing the images of each domain: dict_domain2imgroot={ - 'art_painting': "data/pacs_mini_10", - 'cartoon': "data/pacs_mini_10", - 'photo': "data/pacs_mini_10", - 'sketch': 
"data/pacs_mini_10"}, - + "art_painting": "data/pacs_mini_10", + "cartoon": "data/pacs_mini_10", + "photo": "data/pacs_mini_10", + "sketch": "data/pacs_mini_10", + }, # ## specify the pytorch transformation you want to apply to the image img_trans_tr=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()])) + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + ) return node diff --git a/examples/tasks/task_blood2.py b/examples/tasks/task_blood2.py index f4e46d802..5c5571c66 100644 --- a/examples/tasks/task_blood2.py +++ b/examples/tasks/task_blood2.py @@ -1,95 +1,100 @@ from torchvision import transforms + from domainlab.tasks.task_folder_mk import mk_task_folder from domainlab.tasks.utils_task import ImSize - IMG_SIZE = 224 -trans = transforms.Compose([ - transforms.Resize((IMG_SIZE, IMG_SIZE)), - transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), - transforms.RandomGrayscale(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) -]) +trans = transforms.Compose( + [ + transforms.Resize((IMG_SIZE, IMG_SIZE)), + transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), + transforms.RandomGrayscale(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] +) -trans_te = transforms.Compose([ - transforms.Resize((IMG_SIZE, IMG_SIZE)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) -]) +trans_te = transforms.Compose( + [ + transforms.Resize((IMG_SIZE, IMG_SIZE)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] +) -TASK = 
mk_task_folder(extensions={"acevedo": "jpg", "matek": "tiff", "mll": "tif"}, - list_str_y=[ - "basophil", - "erythroblast", - "metamyelocyte", - "myeloblast", - "neutrophil_band", - "promyelocyte", - "eosinophil", - "lymphocyte_typical", - "monocyte", - "myelocyte", - "neutrophil_segmented" - ], - dict_domain_folder_name2class={ - "acevedo": { - "basophil": "basophil", - "erythroblast": "erythroblast", - "metamyelocyte": "metamyelocyte", - "neutrophil_band": "neutrophil_band", - "promyelocyte": "promyelocyte", - "eosinophil": "eosinophil", - "lymphocyte_typical": "lymphocyte_typical", - "monocyte": "monocyte", - "myelocyte": "myelocyte", - "neutrophil_segmented": "neutrophil_segmented", - }, - "matek": { - "basophil": "basophil", - "erythroblast": "erythroblast", - "metamyelocyte": "metamyelocyte", - "myeloblast": "myeloblast", - "neutrophil_band": "neutrophil_band", - "promyelocyte": "promyelocyte", - "eosinophil": "eosinophil", - "lymphocyte_typical": "lymphocyte_typical", - "monocyte": "monocyte", - "myelocyte": "myelocyte", - "neutrophil_segmented": "neutrophil_segmented", - }, - "mll": { - "basophil": "basophil", - "erythroblast": "erythroblast", - "metamyelocyte": "metamyelocyte", - "myeloblast": "myeloblast", - "neutrophil_band": "neutrophil_band", - "promyelocyte": "promyelocyte", - "eosinophil": "eosinophil", - "lymphocyte_typical": "lymphocyte_typical", - "monocyte": "monocyte", - "myelocyte": "myelocyte", - "neutrophil_segmented": "neutrophil_segmented", - }, - }, - dict_domain_img_trans={ - "acevedo": trans, - "mll": trans, - "matek": trans, - }, - img_trans_te=trans_te, - isize=ImSize(3, IMG_SIZE, IMG_SIZE), - dict_domain2imgroot={ - "matek": "/lustre/groups/labs/marr/qscd01/datasets/armingruber/_Domains/Matek_cropped", - "mll": "/lustre/groups/labs/marr/qscd01/datasets/armingruber/_Domains/MLL_20221220", - "acevedo": "/lustre/groups/labs/marr/qscd01/datasets/armingruber/_Domains/Acevedo_cropped"}, - taskna="blood_mon_eos_bas") +TASK = mk_task_folder( 
+ extensions={"acevedo": "jpg", "matek": "tiff", "mll": "tif"}, + list_str_y=[ + "basophil", + "erythroblast", + "metamyelocyte", + "myeloblast", + "neutrophil_band", + "promyelocyte", + "eosinophil", + "lymphocyte_typical", + "monocyte", + "myelocyte", + "neutrophil_segmented", + ], + dict_domain_folder_name2class={ + "acevedo": { + "basophil": "basophil", + "erythroblast": "erythroblast", + "metamyelocyte": "metamyelocyte", + "neutrophil_band": "neutrophil_band", + "promyelocyte": "promyelocyte", + "eosinophil": "eosinophil", + "lymphocyte_typical": "lymphocyte_typical", + "monocyte": "monocyte", + "myelocyte": "myelocyte", + "neutrophil_segmented": "neutrophil_segmented", + }, + "matek": { + "basophil": "basophil", + "erythroblast": "erythroblast", + "metamyelocyte": "metamyelocyte", + "myeloblast": "myeloblast", + "neutrophil_band": "neutrophil_band", + "promyelocyte": "promyelocyte", + "eosinophil": "eosinophil", + "lymphocyte_typical": "lymphocyte_typical", + "monocyte": "monocyte", + "myelocyte": "myelocyte", + "neutrophil_segmented": "neutrophil_segmented", + }, + "mll": { + "basophil": "basophil", + "erythroblast": "erythroblast", + "metamyelocyte": "metamyelocyte", + "myeloblast": "myeloblast", + "neutrophil_band": "neutrophil_band", + "promyelocyte": "promyelocyte", + "eosinophil": "eosinophil", + "lymphocyte_typical": "lymphocyte_typical", + "monocyte": "monocyte", + "myelocyte": "myelocyte", + "neutrophil_segmented": "neutrophil_segmented", + }, + }, + dict_domain_img_trans={ + "acevedo": trans, + "mll": trans, + "matek": trans, + }, + img_trans_te=trans_te, + isize=ImSize(3, IMG_SIZE, IMG_SIZE), + dict_domain2imgroot={ + "matek": "/lustre/groups/labs/marr/qscd01/datasets/armingruber/_Domains/Matek_cropped", + "mll": "/lustre/groups/labs/marr/qscd01/datasets/armingruber/_Domains/MLL_20221220", + "acevedo": "/lustre/groups/labs/marr/qscd01/datasets/armingruber/_Domains/Acevedo_cropped", + }, + taskna="blood_mon_eos_bas", +) def get_task(na=None): diff 
--git a/examples/tasks/task_dset_custom.py b/examples/tasks/task_dset_custom.py index d3c40f4bf..a5328cff0 100644 --- a/examples/tasks/task_dset_custom.py +++ b/examples/tasks/task_dset_custom.py @@ -2,22 +2,26 @@ example task construction: Specify each domain by a training set and validation (can be None) """ +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize -from domainlab.dsets.dset_mnist_color_solo_default import \ - DsetMNISTColorSoloDefault - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") -task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) -task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) -task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) +task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), +) +task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), +) +task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), +) def get_task(na=None): diff --git a/examples/tasks/task_pacs_aug.py b/examples/tasks/task_pacs_aug.py index 6cc9a6767..6fc092dca 100644 --- a/examples/tasks/task_pacs_aug.py +++ b/examples/tasks/task_pacs_aug.py @@ -6,6 +6,7 @@ """ from torchvision import transforms + from domainlab.tasks.task_pathlist import mk_node_task_path_list from domainlab.tasks.utils_task import ImSize @@ -13,54 +14,64 @@ G_PACS_RAW_PATH = "data/pacs/PACS" # domainlab repository contain already the file names in data/pacs_split folder of domainlab + def get_task(na=None): node = mk_node_task_path_list( isize=ImSize(3, 224, 224), - list_str_y=["dog", 
"elephant", "giraffe", "guitar", - "horse", "house", "person"], - dict_class_label_ind2name={"1": "dog", - "2": "elephant", - "3": "giraffe", - "4": "guitar", - "5": "horse", - "6": "house", - "7": "person"}, + list_str_y=["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"], + dict_class_label_ind2name={ + "1": "dog", + "2": "elephant", + "3": "giraffe", + "4": "guitar", + "5": "horse", + "6": "house", + "7": "person", + }, dict_d2filepath_list_img_tr={ "art_painting": "data/pacs_split/art_painting_train_kfold.txt", "cartoon": "data/pacs_split/cartoon_train_kfold.txt", "photo": "data/pacs_split/photo_train_kfold.txt", - "sketch": "data/pacs_split/sketch_train_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_train_kfold.txt", + }, dict_d2filepath_list_img_te={ "art_painting": "data/pacs_split/art_painting_test_kfold.txt", "cartoon": "data/pacs_split/cartoon_test_kfold.txt", "photo": "data/pacs_split/photo_test_kfold.txt", - "sketch": "data/pacs_split/sketch_test_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_test_kfold.txt", + }, dict_d2filepath_list_img_val={ "art_painting": "data/pacs_split/art_painting_crossval_kfold.txt", "cartoon": "data/pacs_split/cartoon_crossval_kfold.txt", "photo": "data/pacs_split/photo_crossval_kfold.txt", - "sketch": "data/pacs_split/sketch_crossval_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_crossval_kfold.txt", + }, dict_domain2imgroot={ - 'art_painting': G_PACS_RAW_PATH, - 'cartoon': G_PACS_RAW_PATH, - 'photo': G_PACS_RAW_PATH, - 'sketch': G_PACS_RAW_PATH}, + "art_painting": G_PACS_RAW_PATH, + "cartoon": G_PACS_RAW_PATH, + "photo": G_PACS_RAW_PATH, + "sketch": G_PACS_RAW_PATH, + }, img_trans_tr=transforms.Compose( - [transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), - transforms.RandomGrayscale(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]), + [ + 
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), + transforms.RandomGrayscale(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]) + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), ) return node diff --git a/examples/tasks/task_pacs_aug_noflip.py b/examples/tasks/task_pacs_aug_noflip.py index 0282209b4..b5db010e9 100644 --- a/examples/tasks/task_pacs_aug_noflip.py +++ b/examples/tasks/task_pacs_aug_noflip.py @@ -6,6 +6,7 @@ """ from torchvision import transforms + from domainlab.tasks.task_pathlist import mk_node_task_path_list from domainlab.tasks.utils_task import ImSize @@ -13,53 +14,63 @@ G_PACS_RAW_PATH = "data/pacs/PACS" # domainlab repository contain already the file names in data/pacs_split folder of domainlab + def get_task(na=None): node = mk_node_task_path_list( isize=ImSize(3, 224, 224), - list_str_y=["dog", "elephant", "giraffe", "guitar", - "horse", "house", "person"], - dict_class_label_ind2name={"1": "dog", - "2": "elephant", - "3": "giraffe", - "4": "guitar", - "5": "horse", - "6": "house", - "7": "person"}, + list_str_y=["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"], + dict_class_label_ind2name={ + "1": "dog", + "2": "elephant", + "3": "giraffe", + "4": "guitar", + "5": "horse", + "6": "house", + "7": "person", + }, dict_d2filepath_list_img_tr={ "art_painting": "data/pacs_split/art_painting_train_kfold.txt", "cartoon": "data/pacs_split/cartoon_train_kfold.txt", "photo": "data/pacs_split/photo_train_kfold.txt", - "sketch": "data/pacs_split/sketch_train_kfold.txt"}, - + "sketch": 
"data/pacs_split/sketch_train_kfold.txt", + }, dict_d2filepath_list_img_te={ "art_painting": "data/pacs_split/art_painting_test_kfold.txt", "cartoon": "data/pacs_split/cartoon_test_kfold.txt", "photo": "data/pacs_split/photo_test_kfold.txt", - "sketch": "data/pacs_split/sketch_test_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_test_kfold.txt", + }, dict_d2filepath_list_img_val={ "art_painting": "data/pacs_split/art_painting_crossval_kfold.txt", "cartoon": "data/pacs_split/cartoon_crossval_kfold.txt", "photo": "data/pacs_split/photo_crossval_kfold.txt", - "sketch": "data/pacs_split/sketch_crossval_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_crossval_kfold.txt", + }, dict_domain2imgroot={ - 'art_painting': G_PACS_RAW_PATH, - 'cartoon': G_PACS_RAW_PATH, - 'photo': G_PACS_RAW_PATH, - 'sketch': G_PACS_RAW_PATH}, + "art_painting": G_PACS_RAW_PATH, + "cartoon": G_PACS_RAW_PATH, + "photo": G_PACS_RAW_PATH, + "sketch": G_PACS_RAW_PATH, + }, img_trans_tr=transforms.Compose( - [transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), - transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), - transforms.RandomGrayscale(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]), + [ + transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), + transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), + transforms.RandomGrayscale(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]) + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ), ) return node diff --git a/examples/tasks/task_pacs_path_list.py b/examples/tasks/task_pacs_path_list.py index bc0b14c2c..045232b4a 100644 --- 
a/examples/tasks/task_pacs_path_list.py +++ b/examples/tasks/task_pacs_path_list.py @@ -6,6 +6,7 @@ """ from torchvision import transforms + from domainlab.tasks.task_pathlist import mk_node_task_path_list from domainlab.tasks.utils_task import ImSize @@ -13,46 +14,49 @@ G_PACS_RAW_PATH = "data/pacs/PACS" # domainlab repository contain already the file names in data/pacs_split folder of domainlab + def get_task(na=None): node = mk_node_task_path_list( isize=ImSize(3, 224, 224), - list_str_y=["dog", "elephant", "giraffe", "guitar", - "horse", "house", "person"], - dict_class_label_ind2name={"1": "dog", - "2": "elephant", - "3": "giraffe", - "4": "guitar", - "5": "horse", - "6": "house", - "7": "person"}, + list_str_y=["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"], + dict_class_label_ind2name={ + "1": "dog", + "2": "elephant", + "3": "giraffe", + "4": "guitar", + "5": "horse", + "6": "house", + "7": "person", + }, dict_d2filepath_list_img_tr={ "art_painting": "data/pacs_split/art_painting_train_kfold.txt", "cartoon": "data/pacs_split/cartoon_train_kfold.txt", "photo": "data/pacs_split/photo_train_kfold.txt", - "sketch": "data/pacs_split/sketch_train_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_train_kfold.txt", + }, dict_d2filepath_list_img_te={ "art_painting": "data/pacs_split/art_painting_test_kfold.txt", "cartoon": "data/pacs_split/cartoon_test_kfold.txt", "photo": "data/pacs_split/photo_test_kfold.txt", - "sketch": "data/pacs_split/sketch_test_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_test_kfold.txt", + }, dict_d2filepath_list_img_val={ "art_painting": "data/pacs_split/art_painting_crossval_kfold.txt", "cartoon": "data/pacs_split/cartoon_crossval_kfold.txt", "photo": "data/pacs_split/photo_crossval_kfold.txt", - "sketch": "data/pacs_split/sketch_crossval_kfold.txt"}, - + "sketch": "data/pacs_split/sketch_crossval_kfold.txt", + }, dict_domain2imgroot={ - 'art_painting': G_PACS_RAW_PATH, - 'cartoon': G_PACS_RAW_PATH, - 'photo': 
G_PACS_RAW_PATH, - 'sketch': G_PACS_RAW_PATH}, + "art_painting": G_PACS_RAW_PATH, + "cartoon": G_PACS_RAW_PATH, + "photo": G_PACS_RAW_PATH, + "sketch": G_PACS_RAW_PATH, + }, img_trans_tr=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]) + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), ) return node diff --git a/examples/tasks/task_vlcs.py b/examples/tasks/task_vlcs.py index 5b14a50e9..3deb6b2d6 100644 --- a/examples/tasks/task_vlcs.py +++ b/examples/tasks/task_vlcs.py @@ -1,60 +1,63 @@ import os + from torchvision import transforms + from domainlab.tasks.task_folder_mk import mk_task_folder from domainlab.tasks.utils_task import ImSize # relative path is essential here since this file is used for testing, no absolute directory possible path_this_file = os.path.dirname(os.path.realpath(__file__)) -chain = mk_task_folder(extensions={"caltech": "jpg", "sun": - "jpg", "labelme": "jpg"}, - list_str_y=["chair", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", "stuhl": "chair"}, - "sun": {"vehicle": "car", "sofa": "chair"}, - "labelme": {"drive": "car", "sit": "chair"} - }, - dict_domain_img_trans={ - "caltech": transforms.Compose( - [transforms.Resize((256, 256)), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]), - "sun": transforms.Compose( - [transforms.Resize((256, 256)), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]), - "labelme": transforms.Compose( - [transforms.Resize((256, 256)), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - ]), - }, - img_trans_te=transforms.Compose( - 
[transforms.Resize((256, 256)), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - [0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": os.path.join( - path_this_file, - "../../data/vlcs_mini/caltech/"), - "sun": os.path.join( - path_this_file, - "../../data/vlcs_mini/sun/"), - "labelme": os.path.join( - path_this_file, - "../../data/vlcs_mini/labelme/")}, - taskna="e_mini_vlcs") +chain = mk_task_folder( + extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, + list_str_y=["chair", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "stuhl": "chair"}, + "sun": {"vehicle": "car", "sofa": "chair"}, + "labelme": {"drive": "car", "sit": "chair"}, + }, + dict_domain_img_trans={ + "caltech": transforms.Compose( + [ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ), + "sun": transforms.Compose( + [ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ), + "labelme": transforms.Compose( + [ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ), + }, + img_trans_te=transforms.Compose( + [ + transforms.Resize((256, 256)), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ), + isize=ImSize(3, 224, 224), + dict_domain2imgroot={ + "caltech": os.path.join(path_this_file, "../../data/vlcs_mini/caltech/"), + "sun": os.path.join(path_this_file, "../../data/vlcs_mini/sun/"), + "labelme": os.path.join(path_this_file, "../../data/vlcs_mini/labelme/"), + }, + taskna="e_mini_vlcs", +) def get_task(na=None): diff --git 
a/examples/yaml/demo_config_single_run_diva.yaml b/examples/yaml/demo_config_single_run_diva.yaml index 08db9e725..c97b13775 100644 --- a/examples/yaml/demo_config_single_run_diva.yaml +++ b/examples/yaml/demo_config_single_run_diva.yaml @@ -1,4 +1,4 @@ ---- # yaml demo document +--- # yaml demo document te_d: caltech tpath: examples/tasks/task_vlcs.py diff --git a/examples/yaml/demo_slurm_config_with_comments.yaml b/examples/yaml/demo_slurm_config_with_comments.yaml index 0eec7e49b..d251037dc 100644 --- a/examples/yaml/demo_slurm_config_with_comments.yaml +++ b/examples/yaml/demo_slurm_config_with_comments.yaml @@ -2,28 +2,28 @@ cluster: mkdir -p logs/{rule} && sbatch - --partition=gpu_p + --partition=gpu_p # Put the job into the gpu partition - --qos=gpu + --qos=gpu # Request a quality of service for the job. - --gres=gpu:1 + --gres=gpu:1 # Number of GPUs per node (gres=gpu:N) - --nice=10000 + --nice=10000 # Run the job with an adjusted scheduling priority within Slurm. - -c 2 + -c 2 # Allocating number of processes per task - --mem=60G + --mem=60G # RAM per node - --job-name=smk-{rule}-{wildcards} + --job-name=smk-{rule}-{wildcards} # Specify name for job allocation - --output=logs/{rule}/{rule}-{wildcards}-%j.out - # Output file for logs + --output=logs/{rule}/{rule}-{wildcards}-%j.out + # Output file for logs default-resources: - - partition=gpu_p + - partition=gpu_p # Put the job into the gpu partition - - qos=gpu + - qos=gpu # Request a quality of service for the job. 
- - mem_mb=1000 + - mem_mb=1000 # memory in MB a cluster node must provide restart-times: 3 max-jobs-per-second: 10 diff --git a/examples/yaml/slurm/config.yaml b/examples/yaml/slurm/config.yaml index 12d922c90..e180df74d 100644 --- a/examples/yaml/slurm/config.yaml +++ b/examples/yaml/slurm/config.yaml @@ -2,20 +2,20 @@ cluster: mkdir -p zoutput/slurm_logs/{rule} && sbatch - --partition=gpu_p - --qos=gpu_normal - --gres=gpu:1 - --nice=10000 + --partition=gpu_p + --qos=gpu_normal + --gres=gpu:1 + --nice=10000 -t 48:00:00 - -c 2 - --mem=160G - --job-name=smk-{rule}-{wildcards} - --output=zoutput/slurm_logs/{rule}/{rule}-{wildcards}-%j.out + -c 2 + --mem=160G + --job-name=smk-{rule}-{wildcards} + --output=zoutput/slurm_logs/{rule}/{rule}-{wildcards}-%j.out --error=zoutput/slurm_logs/{rule}/{rule}-{wildcards}-%j.err default-resources: - - partition=gpu_p - - qos=gpu_normal - - mem_mb=100000 + - partition=gpu_p + - qos=gpu_normal + - mem_mb=100000 restart-times: 3 max-jobs-per-second: 10 max-status-checks-per-second: 1 diff --git a/examples/yaml/test_helm_benchmark.yaml b/examples/yaml/test_helm_benchmark.yaml index c7ec914e7..606872f6f 100644 --- a/examples/yaml/test_helm_benchmark.yaml +++ b/examples/yaml/test_helm_benchmark.yaml @@ -44,7 +44,7 @@ Task5: # name min: 0.01 max: 10 distribution: loguniform - + Task6: # name model: erm diff --git a/examples/yaml/test_helm_pacs.yaml b/examples/yaml/test_helm_pacs.yaml index de603c1e3..4ec2d7ef7 100644 --- a/examples/yaml/test_helm_pacs.yaml +++ b/examples/yaml/test_helm_pacs.yaml @@ -41,7 +41,7 @@ Task5: # name min: 0.01 max: 10 distribution: loguniform - + Task6: # name model: erm diff --git a/gen_doc.sh b/gen_doc.sh index 6d5534f10..c37e0faa0 100644 --- a/gen_doc.sh +++ b/gen_doc.sh @@ -16,4 +16,3 @@ cp ./*.md build/html/ mkdir -p build/html/figs cp -r ./figs/* build/html/figs # under docs directory, there is a empty html which point to the html generated by sphinx - diff --git a/main_out.py b/main_out.py index 
76f4d69ec..27d7e8fc4 100644 --- a/main_out.py +++ b/main_out.py @@ -9,8 +9,9 @@ if args.bm_dir: aggregate_results.agg_main(args.bm_dir) elif args.plot_data is not None: - gen_benchmark_plots(args.plot_data, args.outp_dir, - use_param_index=args.param_idx) + gen_benchmark_plots( + args.plot_data, args.outp_dir, use_param_index=args.param_idx + ) else: set_seed(args.seed) exp = Exp(args=args) diff --git a/paper.md b/paper.md index 800ce00e0..1e33b81a5 100644 --- a/paper.md +++ b/paper.md @@ -1,48 +1,48 @@ # DomainLab: A PyTorch library for causal domain generalization ## Summary -Deep learning (DL) models have solved real-world challenges in various areas, such as computer vision, natural language processing, and medical image classification or computational pathology. While generalizing to unseen test domains comes naturally to humans, it’s still a major obstacle for machines. By design, most DL models assume that training and testing distributions are the same, causing them to fail when this is violated. Instead, domain generalization aims at training domain invariant models that are robust to distribution shifts. +Deep learning (DL) models have solved real-world challenges in various areas, such as computer vision, natural language processing, and medical image classification or computational pathology. While generalizing to unseen test domains comes naturally to humans, it’s still a major obstacle for machines. By design, most DL models assume that training and testing distributions are the same, causing them to fail when this is violated. Instead, domain generalization aims at training domain invariant models that are robust to distribution shifts. -We introduce DomainLab, a PyTorch based Python package for domain generalization. DomainLab focuses on causal domain generalization and probabilistic methods, while offering easy extensibility to a wide range of other methods including adversarial methods, self-supervised learning and other training paradigms. 
Compared to existing solutions, DomainLab uncouples the factors that contribute to the performance of a domain generalization method. How the data are split, which neural network architectures and loss functions are used, how the weights are updated, and which evaluation protocol is applied are defined independently. In that way, the user can take any combination and evaluate its impact on generalization performance. +We introduce DomainLab, a PyTorch based Python package for domain generalization. DomainLab focuses on causal domain generalization and probabilistic methods, while offering easy extensibility to a wide range of other methods including adversarial methods, self-supervised learning and other training paradigms. Compared to existing solutions, DomainLab uncouples the factors that contribute to the performance of a domain generalization method. How the data are split, which neural network architectures and loss functions are used, how the weights are updated, and which evaluation protocol is applied are defined independently. In that way, the user can take any combination and evaluate its impact on generalization performance. -DomainLab’s documentation is hosted on https://marrlab.github.io/DomainLab and its source code can be found at https://github.com/marrlab/DomainLab. +DomainLab’s documentation is hosted on https://marrlab.github.io/DomainLab and its source code can be found at https://github.com/marrlab/DomainLab. -## Statement of need +## Statement of need -Over the past years, various methods have been proposed addressing different aspects of domain generalization. However, their implementations are often limited to proof-of-concept code, interspersed with custom code for data access, pre-processing, evaluation, etc. This limits the applicability of these methods, affects reproducibility, and restricts the ability to perform comparisons with other state-of-the-art methods. 
+Over the past years, various methods have been proposed addressing different aspects of domain generalization. However, their implementations are often limited to proof-of-concept code, interspersed with custom code for data access, pre-processing, evaluation, etc. This limits the applicability of these methods, affects reproducibility, and restricts the ability to perform comparisons with other state-of-the-art methods. -DomainBed for the first time provided a common codebase for benchmarking domain generalization methods (Gulrajani and Lopez-Paz 2020), however applying its algorithms to new use-cases requires extensive adaptation of its source code and the neural network backbones are hardcoded. The components of an algorithm have to be all initilalized in the construction function which is not suitable for complicated algorithms which require flexibility and plugin functionality of its components. More recently, we found a concurrent work Dassl, which provides a Python package to benchmark different algorithms such as semi-supervised learning, domain adaptation and domain generalization (Zhou et al. 2021). Its design is more modular than DomainBed. However the documentation does not contain enough details about algorithm implementation and the code base is not well tested. In addition, the authors have not clarified a plan for maintaining the module, while we aim at a long term maintenance of our package. +DomainBed for the first time provided a common codebase for benchmarking domain generalization methods (Gulrajani and Lopez-Paz 2020), however applying its algorithms to new use-cases requires extensive adaptation of its source code and the neural network backbones are hardcoded. The components of an algorithm have to be all initilalized in the construction function which is not suitable for complicated algorithms which require flexibility and plugin functionality of its components. 
More recently, we found a concurrent work Dassl, which provides a Python package to benchmark different algorithms such as semi-supervised learning, domain adaptation and domain generalization (Zhou et al. 2021). Its design is more modular than DomainBed. However the documentation does not contain enough details about algorithm implementation and the code base is not well tested. In addition, the authors have not clarified a plan for maintaining the module, while we aim at a long term maintenance of our package. -With DomainLab, we introduce a fully modular Python package for domain generalization with Pytorch backend that follows best practices in software design and includes extensive documentation to enable the community to understand and contribute to the code. It contains extensive unit tests as well as end-to-end tests to verify the implemented functionality. +With DomainLab, we introduce a fully modular Python package for domain generalization with Pytorch backend that follows best practices in software design and includes extensive documentation to enable the community to understand and contribute to the code. It contains extensive unit tests as well as end-to-end tests to verify the implemented functionality. An ideal package for domain generalization should decouple the factors that affectmodelperformance. This way, the components that contributed most to a promising result can be isolated, allowing for better comparability between methods. For example -Can the results be ascribed to a more appropriate neural network architecture? -Is the performance impacted by the protocol used to estimate the generalization performance, e.g. the dataset split? +Can the results be ascribed to a more appropriate neural network architecture? +Is the performance impacted by the protocol used to estimate the generalization performance, e.g. the dataset split? Does the model benefit from a special loss function, e.g. 
because it offers a better regularization to the training of the neural network? ## Description -### General Design +### General Design To address software design issues of existing code bases like DomainBed (Gulrajani and Lopez-Paz 2020) and Dassl (Zhou et al. 2021), and to maximally decouple factors that might affect the performance of domain generalization algorithms, we designed DomainLab with the following features: First, the package offers the user a standalone application to specify the data, data split protocol , pre-processing, neural network backbone, and model loss function, which will not modify the code base of DomainLab. That is, it connects a user’s data to algorithms. Domain generalization algorithms were implemented with a transparent underlying neural network architecture. The concrete neural network architecture can thus be replaced by plugging in an architecture implemented in a python file or by specifying a string of some existing neural network like AlexNet, via command line arguments. Selection of algorithms, neural network components, as well as other components like training procedure are done via the chain-of-responsibility method. Other design patterns like observer pattern, visitor pattern, etc. are also used to improve the decoupling of different factors contributing to the performance of an algorithm (see also Section Components below). (Gamma book see below) -Instead of modifying code across several python files, the package is closed to modification and open to extension. To simply test an algorithm’s performance on `a user’s data, there is no need to change any code inside this repository, the user only needs to extend this repository to fit their requirement by providing custom python files. +Instead of modifying code across several python files, the package is closed to modification and open to extension. 
To simply test an algorithm’s performance on `a user’s data, there is no need to change any code inside this repository, the user only needs to extend this repository to fit their requirement by providing custom python files. It offers a framework for generating combinations by simply letting the user select elements through command line arguments. (combine tasks, neural network architectures) -With the above design, DomainLab offers users the flexibility to construct custom tasks with their own data, writing custom neural network architectures, and even trying their own algorithms by specifying a python file with custom loss functions. There is no need to change the original code of DomainLab when the user needs to use the domain generalization method to their own application, extend the method with custom neural network and try to discriminate the most significant factor that affects performance. -### Components +With the above design, DomainLab offers users the flexibility to construct custom tasks with their own data, writing custom neural network architectures, and even trying their own algorithms by specifying a python file with custom loss functions. There is no need to change the original code of DomainLab when the user needs to use the domain generalization method to their own application, extend the method with custom neural network and try to discriminate the most significant factor that affects performance. +### Components To achieve the above design goals of decoupling, we used the following components: Models refer to a PyTorch module with a specified loss function containing regularization effect of several domains plus the task-specific loss, which is classification loss for classification task, but stay transparent with respect to the exact neural network architecture, which can be configured by the user via command line arguments. 
There are two types of models implemented models from publications in the field of domain generalization using causality and probabilistic model based methods -custom models, where the user only needs to specify a python file defining the custom loss function, while remain transparent of the exact neural network used for each submodule. - The common classification loss calculation is done via a parent model class, thus the individual models representing different domain regularization can be reused for other tasks like segmentation by simply inheriting another task loss. +custom models, where the user only needs to specify a python file defining the custom loss function, while remain transparent of the exact neural network used for each submodule. + The common classification loss calculation is done via a parent model class, thus the individual models representing different domain regularization can be reused for other tasks like segmentation by simply inheriting another task loss. Tasks refer to a component, where the user specifies different datasets from different domains and preprocessing specified upon them. There are several types of tasks in DomainLab: Built-in tasks like Color-Mnist, subsampled version of PACS, VLCS, as test utility of algorithms. -TaskFolder: If the data is already organized in a root folder, with different subfolders containing data from different domains and a further level of sub-sub-folders containing data from different classes. -TaskPathFile: This allows the user to specify each domain a text file indicating the path and label for each observation. Thus, the user can choose which portion of the sample to use as training, validation and test. +TaskFolder: If the data is already organized in a root folder, with different subfolders containing data from different domains and a further level of sub-sub-folders containing data from different classes. 
+TaskPathFile: This allows the user to specify each domain a text file indicating the path and label for each observation. Thus, the user can choose which portion of the sample to use as training, validation and test. -Trainer is the component that directs data flow to the model to calculate loss and back-propagation to update the parameters; several models can share a common trainer. A specific trainer can also be a visitor to models to update coefficients in models during training to implement techniques like warm-up which follows the visitor design pattern from software engineering. -Following the observer pattern, we use separate classes to conduct operations needed to be done after each epoch (e.g. deciding whether to execute early stopping) and after training finishes. -Following the Builder Pattern, we construct each component needed to conduct a domain generalization experiment, including +Trainer is the component that directs data flow to the model to calculate loss and back-propagation to update the parameters; several models can share a common trainer. A specific trainer can also be a visitor to models to update coefficients in models during training to implement techniques like warm-up which follows the visitor design pattern from software engineering. +Following the observer pattern, we use separate classes to conduct operations needed to be done after each epoch (e.g. deciding whether to execute early stopping) and after training finishes. +Following the Builder Pattern, we construct each component needed to conduct a domain generalization experiment, including constructing a trainer which guides the data flow. constructing a concrete neural network architecture and feeding into the model. constructing the evaluator as a callback of what to do after each epoch. 
diff --git a/run_benchmark_slurm.sh b/run_benchmark_slurm.sh index a9a0be811..7fc77f3c0 100755 --- a/run_benchmark_slurm.sh +++ b/run_benchmark_slurm.sh @@ -14,7 +14,7 @@ then echo "argument 2: DOMAINLAB_CUDA_START_SEED empty, will set to 0" export DOMAINLAB_CUDA_START_SEED=0 # in fact, the smk code will hash empty string to zero, see standalone script, - # but here we just want to ensure the seed is 0 without worrying a different + # but here we just want to ensure the seed is 0 without worrying a different # behavior of the hash function else export DOMAINLAB_CUDA_START_SEED=$2 @@ -32,4 +32,4 @@ echo "Configuration file: $CONFIGFILE" echo "starting seed is: $DOMAINLAB_CUDA_START_SEED" echo "verbose log: $logfile" # Helmholtz -snakemake --profile "examples/yaml/slurm" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile" +snakemake --profile "examples/yaml/slurm" --keep-going --keep-incomplete --notemp --cores 3 -s "domainlab/exp_protocol/benchmark.smk" --configfile "$CONFIGFILE" 2>&1 | tee "$logfile" diff --git a/run_benchmark_standalone.sh b/run_benchmark_standalone.sh index 4d0f5dabc..14f152dc7 100755 --- a/run_benchmark_standalone.sh +++ b/run_benchmark_standalone.sh @@ -78,7 +78,7 @@ snakemake --rerun-incomplete --resources nvidia_gpu=$NUMBER_GPUS --cores 4 -s "d # snakemake --keep-going --keep-incomplete --notemp --cores 5 -s "domainlab/exp_protocol/benchmark.smk" --configfile "examples/yaml/helm_runtime_evaluation.yaml" 2>&1 | tee $logfile # Command used to run in the Helmholtz cluster -# snakemake --profile "examples/yaml/slurm" --keep-going --keep-incomplete --notemp --cores 5 -s "domainlab/exp_protocol/benchmark.smk" --configfile "examples/yaml/test_helm_benchmark.yaml" 2>&1 | tee "$logfile" +# snakemake --profile "examples/yaml/slurm" --keep-going --keep-incomplete --notemp --cores 5 -s "domainlab/exp_protocol/benchmark.smk" --configfile 
"examples/yaml/test_helm_benchmark.yaml" 2>&1 | tee "$logfile" # Command used to run snakemake on a demo benchmark # snakemake -np -s "domainlab/exp_protocol/benchmark.smk" --configfile "examples/yaml/demo_benchmark.yaml" diff --git a/sbatch4submit_slurm_cpu_10days.sh b/sbatch4submit_slurm_cpu_10days.sh index aefacae97..1b59514ae 100755 --- a/sbatch4submit_slurm_cpu_10days.sh +++ b/sbatch4submit_slurm_cpu_10days.sh @@ -2,7 +2,7 @@ VENV="domainlab_py39" BASHRC="~/.bashrc" # source ~/.bash_profile -## +## JOB_NAME="submit10d" PATH_CODE=$1 PATH_OUT_BASE="${PATH_CODE}/job_logs" @@ -16,7 +16,7 @@ echo "#!/bin/bash #SBATCH -J ${JOB_NAME} #SBATCH -o ${PATH_OUT_BASE}/${JOB_NAME}.out #SBATCH -e ${PATH_OUT_BASE}/${JOB_NAME}.err -#SBATCH -p cpu_p +#SBATCH -p cpu_p #SBATCH -t 10-00:00:00 #SBATCH -c 20 #SBATCH --mem=32G diff --git a/sbatch4submit_slurm_cpu_3days.sh b/sbatch4submit_slurm_cpu_3days.sh index cbbd89639..2c827eef7 100755 --- a/sbatch4submit_slurm_cpu_3days.sh +++ b/sbatch4submit_slurm_cpu_3days.sh @@ -2,14 +2,14 @@ VENV="domainlab_py39" BASHRC="~/.bashrc" # source ~/.bash_profile -## +## JOB_NAME="submit" PATH_CODE=$1 PATH_OUT_BASE="${PATH_CODE}/submit_job_logs" mkdir -p $PATH_OUT_BASE PATH_YAML=$2 START_SEED=$3 -ACTIVE_TIME="3-00:00:00" +ACTIVE_TIME="3-00:00:00" job_file="${PATH_OUT_BASE}/${JOB_NAME}.cmd" @@ -19,7 +19,7 @@ echo "#!/bin/bash #SBATCH -J ${JOB_NAME} #SBATCH -o ${PATH_OUT_BASE}/${JOB_NAME}.out #SBATCH -e ${PATH_OUT_BASE}/${JOB_NAME}.err -#SBATCH -p cpu_p +#SBATCH -p cpu_p #SBATCH -t ${ACTIVE_TIME} #SBATCH -c 20 #SBATCH --mem=32G diff --git a/setup.py b/setup.py index aabf11b93..4a68ada45 100644 --- a/setup.py +++ b/setup.py @@ -2,20 +2,23 @@ run python setup.py install to install DomainLab into system """ import os + from setuptools import find_packages, setup + def copy_dir(dir_path="zdata"): # root = os.path.dirname(os.path.abspath(__file__)) root = os.path.normpath(".") base_dir = os.path.join(root, "domainlab", dir_path) - for (dirpath, dirnames, 
files) in os.walk(base_dir): + for dirpath, dirnames, files in os.walk(base_dir): for f in files: - path = os.path.join(dirpath.split('/', 1)[1], f) + path = os.path.join(dirpath.split("/", 1)[1], f) print(path) yield path + setup( - name='domainlab', + name="domainlab", packages=find_packages(), # include_package_data=True, # data_files=[ @@ -32,12 +35,12 @@ def copy_dir(dir_path="zdata"): # data_files = [ # ('../data', f) for f in copy_dir() # ], - package_data = { - 'zdata': [f for f in copy_dir()], - }, - version='0.4.3', - description='Library of modular domain generalization for deep learning', - url='https://github.com/marrlab/DomainLab', - author='Xudong Sun, et.al.', - license='MIT', + package_data={ + "zdata": [f for f in copy_dir()], + }, + version="0.4.3", + description="Library of modular domain generalization for deep learning", + url="https://github.com/marrlab/DomainLab", + author="Xudong Sun, et.al.", + license="MIT", ) diff --git a/setup_install.sh b/setup_install.sh index 9925fa8b5..52afd955b 100644 --- a/setup_install.sh +++ b/setup_install.sh @@ -1,5 +1,5 @@ # source distribution -python setup.py sdist +python setup.py sdist pip install -e . 
#!/bin/bash python setup.py develop diff --git a/sh_list_error.sh b/sh_list_error.sh index a337fb462..5f725e15b 100644 --- a/sh_list_error.sh +++ b/sh_list_error.sh @@ -1,5 +1,5 @@ -# find $1 -type f -print0 | xargs -0 grep -li error -# B means before, A means after, some erros have long stack exception message so we need at least +# find $1 -type f -print0 | xargs -0 grep -li error +# B means before, A means after, some erros have long stack exception message so we need at least # 100 lines before the error, the last line usually indicate the root cause of error grep -B 100 -wnr "error" --group-separator="=========begin_slurm_error===============" $1 > slurm_errors.txt cat slurm_errors.txt diff --git a/sh_publish.sh b/sh_publish.sh index 498635cab..cd07b36fa 100644 --- a/sh_publish.sh +++ b/sh_publish.sh @@ -1,6 +1,6 @@ #!/bin/bash -x -v # step 1: in case for a new download of the repository, get the token via https://pypi.org/manage/project/domainlab/settings/, then do -# poetry config pypi-token.pypi [my-token] +# poetry config pypi-token.pypi [my-token] # step 2: change the version in pyproject.toml poetry build # step 3 poetry publish # step 4 diff --git a/tests/__init__.py b/tests/__init__.py index ee34bf5e0..31c61ca83 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -"""Unit tests for DomainLab""" \ No newline at end of file +"""Unit tests for DomainLab""" diff --git a/tests/dset_mnist_color_solo_default_test.py b/tests/dset_mnist_color_solo_default_test.py index a7be9c65e..76e32bfdb 100644 --- a/tests/dset_mnist_color_solo_default_test.py +++ b/tests/dset_mnist_color_solo_default_test.py @@ -25,22 +25,24 @@ def test_color_mnist2(): ds_list.append(DsetMNISTColorSoloDefault(i, "zout")) plot_ds_list(ds_list, "zout/color_0_9.png") + def test_color_mnist3(): """ test_color_mnist """ - dset = DsetMNISTColorSoloDefault(0, "zout", color_scheme="num", raw_split='test') + dset = DsetMNISTColorSoloDefault(0, "zout", color_scheme="num", 
raw_split="test") plot_ds(dset, "zout/color_solo.png") ds_list = [] for i in range(10): ds_list.append(DsetMNISTColorSoloDefault(i, "zout")) plot_ds_list(ds_list, "zout/color_0_9.png") + def test_color_mnist4(): """ test_color_mnist """ - dset = DsetMNISTColorSoloDefault(0, "zout", color_scheme="back", raw_split='test') + dset = DsetMNISTColorSoloDefault(0, "zout", color_scheme="back", raw_split="test") plot_ds(dset, "zout/color_solo.png") ds_list = [] for i in range(10): diff --git a/tests/dset_poly_domains_mnist_color_default_test.py b/tests/dset_poly_domains_mnist_color_default_test.py index 1847f9df4..1e0661eca 100644 --- a/tests/dset_poly_domains_mnist_color_default_test.py +++ b/tests/dset_poly_domains_mnist_color_default_test.py @@ -2,8 +2,9 @@ merge several solo-color mnist to form a mixed dataset """ -from domainlab.dsets.utils_data import plot_ds from domainlab.dsets.dset_poly_domains_mnist_color_default import DsetMNISTColorMix +from domainlab.dsets.utils_data import plot_ds + def test_color_mnist(): dset = DsetMNISTColorMix(n_domains=3, path="./output/") diff --git a/tests/dset_subfolder_test.py b/tests/dset_subfolder_test.py index d7ded4adc..ccf9bc575 100644 --- a/tests/dset_subfolder_test.py +++ b/tests/dset_subfolder_test.py @@ -4,55 +4,66 @@ """ import pytest -from domainlab.dsets.utils_data import fun_img_path_loader_default from domainlab.dsets.dset_subfolder import DsetSubFolder +from domainlab.dsets.utils_data import fun_img_path_loader_default def test_fun(): - dset = DsetSubFolder(root="data/vlcs_mini/caltech", - list_class_dir=["auto", "vogel"], - loader=fun_img_path_loader_default, - extensions="jpg", - transform=None, - target_transform=None) + dset = DsetSubFolder( + root="data/vlcs_mini/caltech", + list_class_dir=["auto", "vogel"], + loader=fun_img_path_loader_default, + extensions="jpg", + transform=None, + target_transform=None, + ) dset.class_to_idx def test_mixed_codec(): """Check if only images with given extension are loaded.""" - 
dset = DsetSubFolder(root="data/mixed_codec/caltech", - list_class_dir=["auto", "vogel"], - loader=fun_img_path_loader_default, - extensions=None, - transform=None, - target_transform=None) + dset = DsetSubFolder( + root="data/mixed_codec/caltech", + list_class_dir=["auto", "vogel"], + loader=fun_img_path_loader_default, + extensions=None, + transform=None, + target_transform=None, + ) assert len(dset.samples) == 6 - dset = DsetSubFolder(root="data/mixed_codec/caltech", - list_class_dir=["auto", "vogel"], - loader=fun_img_path_loader_default, - extensions="jpg", - transform=None, - target_transform=None) - assert len(dset.samples) == 4,\ - f"data/mixed_codec contains 4 jpg files, but {len(dset.samples)} were loaded." + dset = DsetSubFolder( + root="data/mixed_codec/caltech", + list_class_dir=["auto", "vogel"], + loader=fun_img_path_loader_default, + extensions="jpg", + transform=None, + target_transform=None, + ) + assert ( + len(dset.samples) == 4 + ), f"data/mixed_codec contains 4 jpg files, but {len(dset.samples)} were loaded." 
with pytest.raises(ValueError): - DsetSubFolder(root="data/mixed_codec/caltech", - list_class_dir=["auto", "vogel"], - loader=fun_img_path_loader_default, - extensions="jpg", - transform=None, - target_transform=None, - is_valid_file=True) + DsetSubFolder( + root="data/mixed_codec/caltech", + list_class_dir=["auto", "vogel"], + loader=fun_img_path_loader_default, + extensions="jpg", + transform=None, + target_transform=None, + is_valid_file=True, + ) def test_wrong_class_names(): """Check for error if list_class_dir does not match the subfolders.""" with pytest.raises(RuntimeError): - DsetSubFolder(root="data/mixed_codec/caltech", - list_class_dir=["auto", "haus"], - loader=fun_img_path_loader_default, - extensions=None, - transform=None, - target_transform=None) + DsetSubFolder( + root="data/mixed_codec/caltech", + list_class_dir=["auto", "haus"], + loader=fun_img_path_loader_default, + extensions=None, + transform=None, + target_transform=None, + ) diff --git a/tests/dset_utils_data_test.py b/tests/dset_utils_data_test.py index 1e2beac73..3b618146c 100644 --- a/tests/dset_utils_data_test.py +++ b/tests/dset_utils_data_test.py @@ -3,7 +3,7 @@ def test_dset_in_mem_decorator(): - dset = DsetMNISTColorSoloDefault(path ="../data", ind_color=1) + dset = DsetMNISTColorSoloDefault(path="../data", ind_color=1) dset_in_memory = DsetInMemDecorator(dset=dset) dset_in_memory.__len__() - dset_in_memory.__getitem__(0) \ No newline at end of file + dset_in_memory.__getitem__(0) diff --git a/tests/test_argparse.py b/tests/test_argparse.py index 9b64335be..683f66911 100644 --- a/tests/test_argparse.py +++ b/tests/test_argparse.py @@ -7,25 +7,23 @@ import pytest -from domainlab.arg_parser import parse_cmd_args, mk_parser_main, apply_dict_to_args +from domainlab.arg_parser import apply_dict_to_args, mk_parser_main, parse_cmd_args def test_parse_cmd_args_warning(): - """Call argparser for command line - """ - sys.argv = ['main.py'] - with pytest.warns(Warning, match='no algorithm 
conf specified'): + """Call argparser for command line""" + sys.argv = ["main.py"] + with pytest.warns(Warning, match="no algorithm conf specified"): parse_cmd_args() def test_parse_yml_args(): - """Test argparser with yaml file - """ + """Test argparser with yaml file""" testdir = os.path.dirname(os.path.realpath(__file__)) rootdir = os.path.join(testdir, "..") rootdir = os.path.abspath(rootdir) file_path = os.path.join(rootdir, "examples/yaml/demo_config_single_run_diva.yaml") - sys.argv = ['main.py', '--config=' + file_path] + sys.argv = ["main.py", "--config=" + file_path] args = parse_cmd_args() # Checking if arguments are from demo.yaml @@ -37,13 +35,12 @@ def test_parse_yml_args(): def test_parse_invalid_yml_args(): - """Test argparser with yaml file - """ + """Test argparser with yaml file""" testdir = os.path.dirname(os.path.realpath(__file__)) rootdir = os.path.join(testdir, "..") rootdir = os.path.abspath(rootdir) file_path = os.path.join(rootdir, "examples/yaml/demo_invalid_parameter.yaml") - sys.argv = ['main.py', '--config=' + file_path] + sys.argv = ["main.py", "--config=" + file_path] with pytest.raises(ValueError): parse_cmd_args() @@ -53,7 +50,7 @@ def test_apply_dict_to_args(): """Testing apply_dict_to_args""" parser = mk_parser_main() args = parser.parse_args(args=[]) - data = {'a': 1, 'b': [1, 2], 'model': 'diva'} + data = {"a": 1, "b": [1, 2], "model": "diva"} apply_dict_to_args(args, data, extend=True) assert args.a == 1 - assert args.model == 'diva' + assert args.model == "diva" diff --git a/tests/test_benchmark_plots.py b/tests/test_benchmark_plots.py index a103923e8..f8018dd5f 100644 --- a/tests/test_benchmark_plots.py +++ b/tests/test_benchmark_plots.py @@ -1,15 +1,18 @@ -''' +""" Test the benchmark plots using some dummy results saved in .csv files -''' +""" from domainlab.utils.generate_benchmark_plots import gen_benchmark_plots def test_benchm_plots(): - ''' + """ test benchmark plots - ''' - 
gen_benchmark_plots('data/ztest_files/aggret_res_test1', - 'zoutput/benchmark_plots_test/outp1', - use_param_index=False) - gen_benchmark_plots('data/ztest_files/aggret_res_test2', - 'zoutput/benchmark_plots_test/outp2') + """ + gen_benchmark_plots( + "data/ztest_files/aggret_res_test1", + "zoutput/benchmark_plots_test/outp1", + use_param_index=False, + ) + gen_benchmark_plots( + "data/ztest_files/aggret_res_test2", "zoutput/benchmark_plots_test/outp2" + ) diff --git a/tests/test_custom_model.py b/tests/test_custom_model.py index 3a477cc27..875d3e092 100644 --- a/tests/test_custom_model.py +++ b/tests/test_custom_model.py @@ -1,10 +1,12 @@ +import gc import os + import pytest import torch -import gc + +from domainlab.arg_parser import mk_parser_main from domainlab.exp.exp_main import Exp from domainlab.models.model_custom import AModelCustom -from domainlab.arg_parser import mk_parser_main def test_custom(): @@ -13,9 +15,12 @@ def test_custom(): rootdir = os.path.abspath(rootdir) mpath = os.path.join(rootdir, "examples/models/demo_custom_model.py") parser = mk_parser_main() - argsstr = "--te_d=caltech --task=mini_vlcs --model=custom --bs=2 --debug \ + argsstr = ( + "--te_d=caltech --task=mini_vlcs --model=custom --bs=2 --debug \ --apath=%s --nname_argna2val my_custom_arg_name \ - --nname_argna2val alexnet" % (mpath) + --nname_argna2val alexnet" + % (mpath) + ) margs = parser.parse_args(argsstr.split()) exp = Exp(margs) exp.trainer.before_tr() @@ -33,9 +38,12 @@ def test_custom2(): mpath = os.path.join(rootdir, "examples/models/demo_custom_model.py") path_net = os.path.join(rootdir, "examples/nets/resnet.py") parser = mk_parser_main() - argsstr = "--te_d=caltech --task=mini_vlcs --model=custom --bs=2 --debug \ + argsstr = ( + "--te_d=caltech --task=mini_vlcs --model=custom --bs=2 --debug \ --apath=%s --npath_argna2val my_custom_arg_name \ - --npath_argna2val %s" % (mpath, path_net) + --npath_argna2val %s" + % (mpath, path_net) + ) margs = 
parser.parse_args(argsstr.split()) exp = Exp(margs) exp.trainer.before_tr() @@ -47,24 +55,24 @@ def test_custom2(): def test_no_network_exeption(): - ''' + """ test if we can acess the exeption wen using a costum network which is not a network - ''' + """ parser = mk_parser_main() argsstr = "--te_d=caltech --task=mini_vlcs --debug \ --bs=8 --model=erm --npath=tests/this_is_not_a_network.py" margs = parser.parse_args(argsstr.split()) - with pytest.raises(RuntimeError, match='the pytorch module returned by'): + with pytest.raises(RuntimeError, match="the pytorch module returned by"): Exp(margs) def test_amodelcustom(): - """Test that AModelCustom raises correct NotImplementedErrors - """ + """Test that AModelCustom raises correct NotImplementedErrors""" + class Custom(AModelCustom): - """Dummy class to create an instance of the abstract AModelCustom - """ + """Dummy class to create an instance of the abstract AModelCustom""" + @property def dict_net_module_na2arg_na(self): pass @@ -80,4 +88,3 @@ def dict_net_module_na2arg_na(self): del mod torch.cuda.empty_cache() gc.collect() - diff --git a/tests/test_dann_jigen_transformer.py b/tests/test_dann_jigen_transformer.py index 305a3a69d..80994b142 100644 --- a/tests/test_dann_jigen_transformer.py +++ b/tests/test_dann_jigen_transformer.py @@ -8,18 +8,19 @@ from torchvision.models.feature_extraction import create_feature_extractor from domainlab.mk_exp import mk_exp -from domainlab.tasks import get_task from domainlab.models.model_dann import mk_dann from domainlab.models.model_jigen import mk_jigen +from domainlab.tasks import get_task class VIT(nn.Module): """ Vision transformer as feature extractor """ - def __init__(self, freeze=True, - list_str_last_layer=['getitem_5'], - len_last_layer=768): + + def __init__( + self, freeze=True, list_str_last_layer=["getitem_5"], len_last_layer=768 + ): super().__init__() self.nets = vit_b_16(pretrained=True) if freeze: @@ -28,15 +29,15 @@ def __init__(self, freeze=True, # in case 
of enough computation resources for param in self.nets.parameters(): param.requires_grad = False - self.features_vit_flatten = \ - create_feature_extractor(self.nets, - return_nodes=list_str_last_layer) + self.features_vit_flatten = create_feature_extractor( + self.nets, return_nodes=list_str_last_layer + ) def forward(self, tensor_x): """ compute logits predicts """ - out = self.features_vit_flatten(tensor_x)['getitem_5'] + out = self.features_vit_flatten(tensor_x)["getitem_5"] return out @@ -52,22 +53,28 @@ def test_transformer(): net_classifier = nn.Linear(768, task.dim_y) # see documentation for each arguments below - model_dann = mk_dann()(net_encoder=net_feature, - net_classifier=net_classifier, - net_discriminator=nn.Linear(768,2), - list_str_y=task.list_str_y, - list_d_tr=["labelme", "sun"], - alpha=1.0) + model_dann = mk_dann()( + net_encoder=net_feature, + net_classifier=net_classifier, + net_discriminator=nn.Linear(768, 2), + list_str_y=task.list_str_y, + list_d_tr=["labelme", "sun"], + alpha=1.0, + ) # see documentation for each argument below - model_jigen = mk_jigen()(net_encoder=net_feature, - net_classifier_class=net_classifier, - net_classifier_permutation=nn.Linear(768, 32), - list_str_y=task.list_str_y, - coeff_reg=1.0, n_perm=31) + model_jigen = mk_jigen()( + net_encoder=net_feature, + net_classifier_class=net_classifier, + net_classifier_permutation=nn.Linear(768, 32), + list_str_y=task.list_str_y, + coeff_reg=1.0, + n_perm=31, + ) - model_dann.extend(model_jigen) # let Jigen decorate DANN + model_dann.extend(model_jigen) # let Jigen decorate DANN model = model_dann # make trainer for model, here we decorate trainer mldg with dial - exp = mk_exp(task, model, trainer="mldg_dial", - test_domain="caltech", batchsize=2, nocu=True) + exp = mk_exp( + task, model, trainer="mldg_dial", test_domain="caltech", batchsize=2, nocu=True + ) exp.execute(num_epochs=2) diff --git a/tests/test_decorate_model.py b/tests/test_decorate_model.py index 
b947263f3..463fbb4ad 100644 --- a/tests/test_decorate_model.py +++ b/tests/test_decorate_model.py @@ -6,12 +6,12 @@ from torchvision import models as torchvisionmodels from torchvision.models import ResNet50_Weights -from domainlab.mk_exp import mk_exp from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.tasks.task_dset import mk_task_dset +from domainlab.mk_exp import mk_exp +from domainlab.models.model_dann import mk_dann from domainlab.models.model_jigen import mk_jigen +from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize -from domainlab.models.model_dann import mk_dann def test_mk_exp_jigen(): @@ -27,16 +27,22 @@ def mk_exp_jigen(trainer="mldg"): """ # specify domain generalization task: - task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify parameters num_output_net_classifier = task.dim_y @@ -46,26 +52,45 @@ def mk_exp_jigen(trainer="mldg"): coeff_reg = 1e-3 # specify net encoder - net_encoder = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + net_encoder = torchvisionmodels.resnet.resnet50( + weights=ResNet50_Weights.IMAGENET1K_V2 + 
) num_output_net_encoder = net_encoder.fc.out_features # specify permutation classifier as linear network - net_permutation_classifier = nn.Linear(num_output_net_encoder, num_output_net_permutation) + net_permutation_classifier = nn.Linear( + num_output_net_encoder, num_output_net_permutation + ) # specify label classifier as linear network net_classifier = nn.Linear(num_output_net_encoder, num_output_net_classifier) # specify model to use - model = mk_jigen()(list_str_y, net_encoder, - net_classifier, net_permutation_classifier, coeff_reg, meta_info={"nperm":num_output_net_permutation}) + model = mk_jigen()( + list_str_y, + net_encoder, + net_classifier, + net_permutation_classifier, + coeff_reg, + meta_info={"nperm": num_output_net_permutation}, + ) num_output_net_discriminator = 2 net_discriminator = nn.Linear(num_output_net_encoder, num_output_net_discriminator) alpha = 0.3 - model2 = mk_dann()(list_str_y, ["domain2", "domain3"], alpha, net_encoder, net_classifier, net_discriminator) + model2 = mk_dann()( + list_str_y, + ["domain2", "domain3"], + alpha, + net_encoder, + net_classifier, + net_discriminator, + ) model.extend(model2) # make trainer for model - exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32, nocu=True) + exp = mk_exp( + task, model, trainer=trainer, test_domain="domain1", batchsize=32, nocu=True + ) exp.execute(num_epochs=2) diff --git a/tests/test_decorate_model_cmd.py b/tests/test_decorate_model_cmd.py index 05e3b9a0a..399b90181 100644 --- a/tests/test_decorate_model_cmd.py +++ b/tests/test_decorate_model_cmd.py @@ -13,6 +13,7 @@ def test_cmd_model_erm_decorator_diva(): --gamma_y=10e5 --gamma_d=1e5" utils_test_algo(args) + def test_cmd_model_dann_decorator_diva(): """ trainer decorator diff --git a/tests/test_deepalldann.py b/tests/test_deepalldann.py index 251584dc6..962808452 100644 --- a/tests/test_deepalldann.py +++ b/tests/test_deepalldann.py @@ -1,11 +1,13 @@ """ unit and end-end test for deep all, dann """ 
-import os import gc +import os + import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp from tests.utils_test import utils_test_algo @@ -14,11 +16,20 @@ def test_erm(): unit deep all """ parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "erm", "--bs", "2", - "--nname", "conv_bn_pool_2" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "erm", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) @@ -37,11 +48,20 @@ def test_erm_res(): path = os.path.join(rootdir, "examples/nets/resnet.py") parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "erm", "--bs", "2", - "--npath", f"{path}" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "erm", + "--bs", + "2", + "--npath", + f"{path}", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) @@ -60,12 +80,20 @@ def test_erm_resdombed(): path = os.path.join(rootdir, "examples/nets/resnet50domainbed.py") parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "erm", - "--bs", "2", - "--npath", f"{path}" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "erm", + "--bs", + "2", + "--npath", + f"{path}", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) @@ -79,12 +107,22 @@ def test_dann(): domain adversarial neural network """ parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "dann", "--bs", "2", - "--nname", "conv_bn_pool_2", - "--gamma_reg", "1.0" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + 
"mini_vlcs", + "--model", + "dann", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + "--gamma_reg", + "1.0", + ] + ) exp = Exp(margs) exp.execute() del exp @@ -103,14 +141,25 @@ def test_dann_dial(): def test_sanity_check(): """Sanity check of the dataset""" parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "dann", "--bs", "2", - "--nname", "conv_bn_pool_2", - "--gamma_reg", "1.0", - "--san_check", - "--san_num", "4" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "dann", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + "--gamma_reg", + "1.0", + "--san_check", + "--san_num", + "4", + ] + ) exp = Exp(margs) exp.execute() del exp diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 66e3ea4ef..4b70044ac 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -1,13 +1,15 @@ -''' +""" Code coverage issues: https://app.codecov.io/gh/marrlab/DomainLab/blob/master/domainlab/compos/nn_zoo/net_conv_conv_bn_pool_2.py - lines 66-67, 69-71, 73 - lines 79-82 -''' +""" import gc + import torch -from domainlab.compos.nn_zoo.nn import DenseNet + from domainlab.compos.nn_zoo.net_conv_conv_bn_pool_2 import NetConvDense +from domainlab.compos.nn_zoo.nn import DenseNet def test_netconvdense1(): @@ -15,9 +17,9 @@ def test_netconvdense1(): test convdensenet """ inpu = torch.randn(1, 3, 28, 28) - model = NetConvDense(isize=(3, 28, 28),\ - conv_stride=1, dim_out_h=32,\ - args=None, dense_layer=None) + model = NetConvDense( + isize=(3, 28, 28), conv_stride=1, dim_out_h=32, args=None, dense_layer=None + ) model(inpu) del model torch.cuda.empty_cache() @@ -30,9 +32,13 @@ def test_netconvdense2(): """ inpu = torch.randn(1, 3, 28, 28) dense_layers = DenseNet(1024, out_hidden_size=32) - model = NetConvDense(isize=(3, 28, 28),\ - conv_stride=1, dim_out_h=32,\ - args=None, dense_layer=dense_layers) + model = NetConvDense( + isize=(3, 28, 28), + 
conv_stride=1, + dim_out_h=32, + args=None, + dense_layer=dense_layers, + ) model(inpu) del model torch.cuda.empty_cache() diff --git a/tests/test_dial.py b/tests/test_dial.py index 54d8d2acc..5a5a770d8 100644 --- a/tests/test_dial.py +++ b/tests/test_dial.py @@ -3,9 +3,11 @@ so it is easier to identify which algorithm has a problem """ import gc + import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp def test_trainer_dial(): @@ -13,15 +15,25 @@ def test_trainer_dial(): end to end test """ parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "0", - "--task", "mnistcolor10", - "--model", "erm", "--bs", "2", - "--trainer", "dial", - "--nname", "conv_bn_pool_2"]) + margs = parser.parse_args( + [ + "--te_d", + "0", + "--task", + "mnistcolor10", + "--model", + "erm", + "--bs", + "2", + "--trainer", + "dial", + "--nname", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) del exp torch.cuda.empty_cache() gc.collect() - diff --git a/tests/test_diva.py b/tests/test_diva.py index 40f79bbcc..b53bc50e5 100644 --- a/tests/test_diva.py +++ b/tests/test_diva.py @@ -1,8 +1,10 @@ -import os import gc +import os + import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp from tests.utils_test import utils_test_algo @@ -10,9 +12,12 @@ def test_dial_diva(): """ the combination of dial and diva: use dial trainer to train diva model """ - utils_test_algo("--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=diva \ + utils_test_algo( + "--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=diva \ --nname=conv_bn_pool_2 --nname_dom=conv_bn_pool_2 \ - --gamma_y=7e5 --gamma_d=1e5 --trainer=dial") + --gamma_y=7e5 --gamma_d=1e5 --trainer=dial" + ) + def test_diva(): parser = mk_parser_main() @@ -29,17 +34,28 @@ def test_diva(): gc.collect() - def test_trainer_diva(): 
parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "diva", "--bs", "2", - "--nname", "conv_bn_pool_2", - "--gamma_y", "7e5", - "--gamma_d", "7e5", - "--nname_dom", "conv_bn_pool_2" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "diva", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + "--gamma_y", + "7e5", + "--gamma_d", + "7e5", + "--nname_dom", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) @@ -48,21 +64,32 @@ def test_trainer_diva(): gc.collect() - def test_trainer_diva_folder(): testdir = os.path.dirname(os.path.realpath(__file__)) rootdir = os.path.join(testdir, "..") rootdir = os.path.abspath(rootdir) path = os.path.join(rootdir, "examples/tasks/task_vlcs.py") parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--tpath", "%s" % (path), - "--model", "diva", "--bs", "2", - "--nname", "conv_bn_pool_2", - "--gamma_y", "7e5", - "--gamma_d", "7e5", - "--nname_dom", "conv_bn_pool_2" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--tpath", + "%s" % (path), + "--model", + "diva", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + "--gamma_y", + "7e5", + "--gamma_d", + "7e5", + "--nname_dom", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) @@ -71,25 +98,35 @@ def test_trainer_diva_folder(): gc.collect() - def test_trainer_diva_pathlist(): testdir = os.path.dirname(os.path.realpath(__file__)) rootdir = os.path.join(testdir, "..") rootdir = os.path.abspath(rootdir) path = os.path.join(rootdir, "examples/tasks/demo_task_path_list_small.py") parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "sketch", - "--tpath", "%s" % (path), - "--model", "diva", "--bs", "2", - "--nname", "conv_bn_pool_2", - "--gamma_y", "7e5", - "--gamma_d", "7e5", - "--nname_dom", "conv_bn_pool_2" - ]) + margs = parser.parse_args( 
+ [ + "--te_d", + "sketch", + "--tpath", + "%s" % (path), + "--model", + "diva", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + "--gamma_y", + "7e5", + "--gamma_d", + "7e5", + "--nname_dom", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) del exp torch.cuda.empty_cache() gc.collect() - diff --git a/tests/test_encoder.py b/tests/test_encoder.py index eca4e1354..790f84d0b 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -1,6 +1,9 @@ import torch -from domainlab.compos.vae.compos.encoder_xyd_parallel import XYDEncoderParallelConvBnReluPool + from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool +from domainlab.compos.vae.compos.encoder_xyd_parallel import ( + XYDEncoderParallelConvBnReluPool, +) def test_XYDEncoderConvBnReluPool(): @@ -13,10 +16,12 @@ def test_XYDEncoderConvBnReluPool(): def test_LSEncoderConvStride1BnReluPool(): """test""" from domainlab.utils.test_img import mk_img + img_size = 28 img = mk_img(img_size) - model = LSEncoderConvBnReluPool(z_dim=8, i_channel=3, i_h=img_size, i_w=img_size, - conv_stride=1) + model = LSEncoderConvBnReluPool( + z_dim=8, i_channel=3, i_h=img_size, i_w=img_size, conv_stride=1 + ) q_zd, zd_q = model(img) q_zd.mean q_zd.scale diff --git a/tests/test_encoder_domain_topic.py b/tests/test_encoder_domain_topic.py index 43749ecbf..fd1376ea7 100644 --- a/tests/test_encoder_domain_topic.py +++ b/tests/test_encoder_domain_topic.py @@ -1,18 +1,16 @@ import torch -from domainlab.compos.vae.compos.encoder_domain_topic import EncoderSandwichTopicImg2Zd + from domainlab.arg_parser import mk_parser_main +from domainlab.compos.vae.compos.encoder_domain_topic import EncoderSandwichTopicImg2Zd def test_TopicImg2Zd(): parser = mk_parser_main() - args = parser.parse_args([ - "--te_d", "9", "--dpath", - "zout", "--split", "0.8"]) + args = parser.parse_args(["--te_d", "9", "--dpath", "zout", "--split", "0.8"]) args.nname_encoder_sandwich_x2h4zd = "conv_bn_pool_2" 
model = EncoderSandwichTopicImg2Zd( - zd_dim=64, isize=(3,64,64), - num_topics=5, img_h_dim=1024, - args=args) + zd_dim=64, isize=(3, 64, 64), num_topics=5, img_h_dim=1024, args=args + ) x = torch.rand(20, 3, 64, 64) topic = torch.rand(20, 5) _, _ = model(x, topic) diff --git a/tests/test_exp_protocol.py b/tests/test_exp_protocol.py index 5d81d0e86..2f9ca03ea 100644 --- a/tests/test_exp_protocol.py +++ b/tests/test_exp_protocol.py @@ -10,12 +10,18 @@ import yaml from domainlab.arg_parser import mk_parser_main -from domainlab.exp_protocol.aggregate_results import agg_results, agg_main -from domainlab.exp_protocol.run_experiment import run_experiment, apply_dict_to_args +from domainlab.exp_protocol.aggregate_results import agg_main, agg_results +from domainlab.exp_protocol.run_experiment import apply_dict_to_args, run_experiment + def test_run_experiment(): - utils_run_experiment("examples/benchmark/demo_benchmark.yaml", list_test_domains=['caltech']) - utils_run_experiment("examples/benchmark/demo_benchmark_mnist4test.yaml", ['0'], no_run=False) + utils_run_experiment( + "examples/benchmark/demo_benchmark.yaml", list_test_domains=["caltech"] + ) + utils_run_experiment( + "examples/benchmark/demo_benchmark_mnist4test.yaml", ["0"], no_run=False + ) + def utils_run_experiment(yaml_name, list_test_domains, no_run=True): """Checks the run_experiment function on a minimal basis""" @@ -23,31 +29,30 @@ def utils_run_experiment(yaml_name, list_test_domains, no_run=True): config = yaml.safe_load(stream) if torch.cuda.is_available(): torch.cuda.init() - config['epos'] = 1 - config['startseed'] = 1 - config['endseed'] = 1 - config['test_domains'] = list_test_domains + config["epos"] = 1 + config["startseed"] = 1 + config["endseed"] = 1 + config["test_domains"] = list_test_domains param_file = "data/ztest_files/test_parameter_samples.csv" param_index = 0 out_file = "zoutput/benchmarks/demo_benchmark/rule_results/0.csv" # setting misc={'testing': True} will disable experiment 
being executed - run_experiment(config, param_file, param_index, out_file, misc={'testing': True}) + run_experiment(config, param_file, param_index, out_file, misc={"testing": True}) # setting test_domain equals zero will also not execute the experiment if no_run: - config['test_domains'] = [] + config["test_domains"] = [] run_experiment(config, param_file, param_index, out_file) - def test_apply_dict_to_args(): """Testing apply_dict_to_args""" parser = mk_parser_main() args = parser.parse_args(args=[]) - data = {'a': 1, 'b': [1, 2], 'model': 'diva'} + data = {"a": 1, "b": [1, 2], "model": "diva"} apply_dict_to_args(args, data, extend=True) assert args.a == 1 - assert args.model == 'diva' + assert args.model == "diva" def create_agg_input_files() -> List[str]: @@ -57,7 +62,7 @@ def create_agg_input_files() -> List[str]: f_0 = test_dir + "/0.csv" f_1 = test_dir + "/1.csv" - with open(f_0, 'w') as stream: + with open(f_0, "w") as stream: stream.write( "param_index, method, algo, epos, te_d, seed, params, acc," " precision, recall, specificity, f1, aurocy\n" @@ -66,7 +71,7 @@ def create_agg_input_files() -> List[str]: " 0.80833334, 0.80833334, 0.82705104, 0.98333335\n" ) - with open(f_1, 'w') as stream: + with open(f_1, "w") as stream: stream.write( "param_index, method, algo, epos, te_d, seed, params, acc," " precision, recall, specificity, f1, aurocy\n" @@ -103,14 +108,16 @@ def agg_output_file() -> str: @pytest.fixture def agg_expected_output() -> str: """Expected result file content for the agg tests.""" - return "param_index, method, algo, epos, te_d, seed, params, acc," \ - " precision, recall, specificity, f1, aurocy\n" \ - "0, diva, diva, 2, caltech, 1, \"{'gamma_y': 682408," \ - " 'gamma_d': 275835}\", 0.88461536, 0.852381,"\ - " 0.80833334, 0.80833334, 0.82705104, 0.98333335\n"\ - "1, hduva, hduva, 2, caltech, 1, \"{'gamma_y': 70037," \ - " 'zy_dim': 48}\", 0.7307692, 0.557971,"\ - " 0.5333333, 0.5333333, 0.5297158, 0.73333335" + return ( + "param_index, 
method, algo, epos, te_d, seed, params, acc," + " precision, recall, specificity, f1, aurocy\n" + "0, diva, diva, 2, caltech, 1, \"{'gamma_y': 682408," + " 'gamma_d': 275835}\", 0.88461536, 0.852381," + " 0.80833334, 0.80833334, 0.82705104, 0.98333335\n" + "1, hduva, hduva, 2, caltech, 1, \"{'gamma_y': 70037," + " 'zy_dim': 48}\", 0.7307692, 0.557971," + " 0.5333333, 0.5333333, 0.5297158, 0.73333335" + ) @pytest.fixture @@ -125,9 +132,9 @@ def bm_config(): def compare_file_content(filename: str, expected: str) -> bool: """Returns true if the given file contains the given string.""" - with open(filename, 'r') as stream: + with open(filename, "r") as stream: content = stream.readlines() - return ''.join(content) == expected + return "".join(content) == expected def test_agg_results(agg_input_files, agg_output_file, agg_expected_output): diff --git a/tests/test_git_tag.py b/tests/test_git_tag.py index fd9c2ca61..2bd96e146 100644 --- a/tests/test_git_tag.py +++ b/tests/test_git_tag.py @@ -1,21 +1,23 @@ -''' +""" Code coverage issues: https://app.codecov.io/gh/marrlab/DomainLab/blob/master/domainlab/utils/get_git_tag.py - lines 10-20 - lines 28, 30-32 -''' +""" from domainlab.utils.get_git_tag import get_git_tag + def test_git_tag(): """ test git_tag """ get_git_tag(print_diff=True) + def test_git_tag_error(): - ''' + """ test git_tag error - ''' + """ # add one line to the file with open("data/ztest_files/dummy_file.py", "a") as f: f.write("\n# I am a dummy command") diff --git a/tests/test_hduva.py b/tests/test_hduva.py index c93aab9f0..1e904ab23 100644 --- a/tests/test_hduva.py +++ b/tests/test_hduva.py @@ -3,20 +3,25 @@ """ import gc + import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp from tests.utils_test import utils_test_algo + def test_hduva_zx_nonzero(): """ the combination of dial and diva: use dial trainer to train diva model """ - utils_test_algo("--te_d 0 1 2 --tr_d 
3 7 --task=mnistcolor10 --zx_dim=8 \ + utils_test_algo( + "--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --zx_dim=8 \ --model=hduva --nname=conv_bn_pool_2 \ --nname_encoder_x2topic_h=conv_bn_pool_2 \ --gamma_y=7e5 \ - --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2") + --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2" + ) def test_trainer_hduva(): @@ -24,14 +29,26 @@ def test_trainer_hduva(): end to end test """ parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "hduva", "--bs", "2", - "--nname", "alexnet", - "--gamma_y", "7e5", - "--nname_encoder_x2topic_h", "conv_bn_pool_2", - "--nname_encoder_sandwich_x2h4zd", "conv_bn_pool_2" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "hduva", + "--bs", + "2", + "--nname", + "alexnet", + "--gamma_y", + "7e5", + "--nname_encoder_x2topic_h", + "conv_bn_pool_2", + "--nname_encoder_sandwich_x2h4zd", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) diff --git a/tests/test_hduva_dirichlet_encoder.py b/tests/test_hduva_dirichlet_encoder.py index b166cfb02..749b0afa6 100644 --- a/tests/test_hduva_dirichlet_encoder.py +++ b/tests/test_hduva_dirichlet_encoder.py @@ -2,12 +2,12 @@ end to end test """ import torch + from domainlab.compos.vae.compos.encoder_dirichlet import EncoderH2Dirichlet def test_unit_encoder_dirichlet(): - encoder_dirichlet = EncoderH2Dirichlet( - dim_topic=3, device=torch.device("cpu")) + encoder_dirichlet = EncoderH2Dirichlet(dim_topic=3, device=torch.device("cpu")) feat_hidden_uniform01 = torch.rand(32, 3) # batchsize 32 encoder_dirichlet(feat_hidden_uniform01) feat_hidden_normal = torch.normal(0, 1, size=(32, 3)) diff --git a/tests/test_hyperparameter_sampling.py b/tests/test_hyperparameter_sampling.py index 790625af5..5142cadd6 100644 --- a/tests/test_hyperparameter_sampling.py +++ b/tests/test_hyperparameter_sampling.py @@ -6,11 +6,16 @@ import pytest 
import yaml -from domainlab.utils.hyperparameter_sampling import \ - sample_hyperparameters, sample_parameters, get_hyperparameter, \ - Hyperparameter, SampledHyperparameter, G_MODEL_NA, G_METHOD_NA -from domainlab.utils.hyperparameter_gridsearch import \ - sample_gridsearch +from domainlab.utils.hyperparameter_gridsearch import sample_gridsearch +from domainlab.utils.hyperparameter_sampling import ( + G_METHOD_NA, + G_MODEL_NA, + Hyperparameter, + SampledHyperparameter, + get_hyperparameter, + sample_hyperparameters, + sample_parameters, +) from tests.utils_test import assert_frame_not_equal @@ -21,189 +26,246 @@ def test_hyperparameter_sampling(): samples = sample_hyperparameters(config) - a1samples = samples[samples[G_MODEL_NA] == 'Algo1'] - for par in a1samples['params']: - assert par['p1_shared'] < par['p1'] - assert par['p1'] < par['p2'] - assert par['p3'] < par['p2'] - assert par['p2'] % 1 == pytest.approx(0) - assert par['p4'] == par['p3'] - assert par['p5'] == 2 * par['p3'] / par['p1'] - - a2samples = samples[samples[G_MODEL_NA] == 'Algo2'] - for par in a2samples['params']: - assert par['p1'] % 2 == pytest.approx(1) - assert par['p2'] % 1 == pytest.approx(0) - assert par['p3'] == 2 * par['p2'] - p_4 = par['p4'] + a1samples = samples[samples[G_MODEL_NA] == "Algo1"] + for par in a1samples["params"]: + assert par["p1_shared"] < par["p1"] + assert par["p1"] < par["p2"] + assert par["p3"] < par["p2"] + assert par["p2"] % 1 == pytest.approx(0) + assert par["p4"] == par["p3"] + assert par["p5"] == 2 * par["p3"] / par["p1"] + + a2samples = samples[samples[G_MODEL_NA] == "Algo2"] + for par in a2samples["params"]: + assert par["p1"] % 2 == pytest.approx(1) + assert par["p2"] % 1 == pytest.approx(0) + assert par["p3"] == 2 * par["p2"] + p_4 = par["p4"] assert p_4 == 30 or p_4 == 31 or p_4 == 100 assert np.issubdtype(type(p_4), np.integer) - a3samples = samples[samples[G_MODEL_NA] == 'Algo3'] + a3samples = samples[samples[G_MODEL_NA] == "Algo3"] assert not 
a3samples.empty # test the case with less parameter samples than shared samples - config['num_param_samples'] = 3 + config["num_param_samples"] = 3 sample_hyperparameters(config) def test_fallback_solution_of_sample_parameters(): - ''' + """ trying to meet the constrainds with the pool of presampled shared hyperparameters may not be possible, in this case the shared hyperparameters are sampled accoring to their config. This case is tested in this function - ''' + """ # define a task specific hyperparameter - config = {'distribution': 'uniform', 'min': 0, 'max': 1, 'step': 0} - par = SampledHyperparameter('p1', config) + config = {"distribution": "uniform", "min": 0, "max": 1, "step": 0} + par = SampledHyperparameter("p1", config) init_params = [par] # set a constrained with a shared hyperparameter - constraints = ['p1 > p1_shared'] + constraints = ["p1 > p1_shared"] # set config for shared hyperparameter - shared_config = {'num_shared_param_samples': 2, - 'p1_shared': {'distribution': 'uniform', - 'min': 0, 'max': 10, 'step': 0}} + shared_config = { + "num_shared_param_samples": 2, + "p1_shared": {"distribution": "uniform", "min": 0, "max": 10, "step": 0}, + } # set the shared samples to values which do never meet the # constrained with the task specific hyperparameter shared_samples = pd.DataFrame( - [['all', 'all', {'p1_shared': 5}], - ['all', 'all', {'p1_shared': 6}]], - columns=[G_METHOD_NA, G_MODEL_NA, 'params'] + [["all", "all", {"p1_shared": 5}], ["all", "all", {"p1_shared": 6}]], + columns=[G_METHOD_NA, G_MODEL_NA, "params"], + ) + sample_parameters( + init_params, + constraints, + shared_config=shared_config, + shared_samples=shared_samples, ) - sample_parameters(init_params, constraints, - shared_config=shared_config, - shared_samples=shared_samples) def test_hyperparameter_gridsearch(): """Test sampling from yaml, including constraints""" - with open("examples/yaml/demo_hyperparameter_gridsearch.yml", "r", encoding="utf-8") \ - as stream: + with open( + 
"examples/yaml/demo_hyperparameter_gridsearch.yml", "r", encoding="utf-8" + ) as stream: config = yaml.safe_load(stream) samples = sample_gridsearch(config) - a1samples = samples[samples[G_MODEL_NA] == 'Algo1'] - for par in a1samples['params']: - assert par['p1'] < par['p2'] - assert par['p3'] < par['p2'] - assert par['p2'] % 1 == pytest.approx(0) - assert np.issubdtype(type(par['p2']), np.integer) - assert par['p4'] == par['p3'] - assert par['p5'] == 2 * par['p3'] / par['p1'] - assert par['p1_shared'] == par['p1'] - assert np.issubdtype(type(par['p9']), np.integer) - assert par['p10'] % 1 == 0.5 - - a2samples = samples[samples[G_MODEL_NA] == 'Algo2'] - for par in a2samples['params']: - assert par['p1'] % 2 == pytest.approx(1) - assert par['p2'] % 1 == pytest.approx(0) - assert par['p3'] == 2 * par['p2'] - p_4 = par['p4'] + a1samples = samples[samples[G_MODEL_NA] == "Algo1"] + for par in a1samples["params"]: + assert par["p1"] < par["p2"] + assert par["p3"] < par["p2"] + assert par["p2"] % 1 == pytest.approx(0) + assert np.issubdtype(type(par["p2"]), np.integer) + assert par["p4"] == par["p3"] + assert par["p5"] == 2 * par["p3"] / par["p1"] + assert par["p1_shared"] == par["p1"] + assert np.issubdtype(type(par["p9"]), np.integer) + assert par["p10"] % 1 == 0.5 + + a2samples = samples[samples[G_MODEL_NA] == "Algo2"] + for par in a2samples["params"]: + assert par["p1"] % 2 == pytest.approx(1) + assert par["p2"] % 1 == pytest.approx(0) + assert par["p3"] == 2 * par["p2"] + p_4 = par["p4"] assert p_4 == 30 or p_4 == 31 or p_4 == 100 - assert np.issubdtype(type(par['p4']), np.integer) - assert 'p2_shared' not in par.keys() + assert np.issubdtype(type(par["p4"]), np.integer) + assert "p2_shared" not in par.keys() - a3samples = samples[samples[G_MODEL_NA] == 'Algo3'] + a3samples = samples[samples[G_MODEL_NA] == "Algo3"] assert not a3samples.empty - assert 'p1_shared' not in a3samples.keys() - assert 'p2_shared' not in a3samples.keys() + assert "p1_shared" not in 
a3samples.keys() + assert "p2_shared" not in a3samples.keys() def test_gridhyperparameter_errors(): """Test for the errors which may occour in the sampling of the grid""" with pytest.raises(RuntimeError, match="distance between max and min to small"): - sample_gridsearch({'output_dir': "zoutput/benchmarks/test", - 'Task1': {'model': 'Algo1', - 'hyperparameters': - {'p1':{'min': 0, 'max': 1, 'step': 5, - 'distribution': 'uniform', 'num': 2}}}}) - - with pytest.raises(RuntimeError, match="distribution \"random\" not implemented"): - sample_gridsearch({'output_dir': "zoutput/benchmarks/test", - 'Task1': {'model': 'Algo1', - 'hyperparameters': - {'p1':{'min': 0, 'max': 1, 'step': 0, - 'distribution': 'random', 'num': 2}}}}) + sample_gridsearch( + { + "output_dir": "zoutput/benchmarks/test", + "Task1": { + "model": "Algo1", + "hyperparameters": { + "p1": { + "min": 0, + "max": 1, + "step": 5, + "distribution": "uniform", + "num": 2, + } + }, + }, + } + ) + + with pytest.raises(RuntimeError, match='distribution "random" not implemented'): + sample_gridsearch( + { + "output_dir": "zoutput/benchmarks/test", + "Task1": { + "model": "Algo1", + "hyperparameters": { + "p1": { + "min": 0, + "max": 1, + "step": 0, + "distribution": "random", + "num": 2, + } + }, + }, + } + ) with pytest.raises(RuntimeError, match="No valid value found"): - sample_gridsearch({'output_dir': "zoutput/benchmarks/test", - 'Task1': {'model': 'Algo1', - 'hyperparameters': - {'p1':{'min': 2, 'max': 3.5, 'step': 1, - 'distribution': 'uniform', 'num': 2}, - 'p2':{'min': 0, 'max': 1.5, 'step': 1, - 'distribution': 'uniform', 'num': 2}, - 'constraints': ['p1 < p2'] - }}}) - - with pytest.raises(RuntimeError, match="the number of parameters in the grid " - "direction of p1 needs to be specified"): - sample_gridsearch({'output_dir': "zoutput/benchmarks/test", - 'Task1': {'model': 'Algo1', - 'hyperparameters': - {'p1': {'min': 0, 'max': 1, 'step': 0, - 'distribution': 'uniform'}}}}) + sample_gridsearch( + { + 
"output_dir": "zoutput/benchmarks/test", + "Task1": { + "model": "Algo1", + "hyperparameters": { + "p1": { + "min": 2, + "max": 3.5, + "step": 1, + "distribution": "uniform", + "num": 2, + }, + "p2": { + "min": 0, + "max": 1.5, + "step": 1, + "distribution": "uniform", + "num": 2, + }, + "constraints": ["p1 < p2"], + }, + }, + } + ) + + with pytest.raises( + RuntimeError, + match="the number of parameters in the grid " + "direction of p1 needs to be specified", + ): + sample_gridsearch( + { + "output_dir": "zoutput/benchmarks/test", + "Task1": { + "model": "Algo1", + "hyperparameters": { + "p1": {"min": 0, "max": 1, "step": 0, "distribution": "uniform"} + }, + }, + } + ) def test_hyperparameter_errors(): """Test for errors on unknown distribution or missing keys""" with pytest.raises(RuntimeError, match="Datatype unknown"): - par = get_hyperparameter('name', {'reference': 'a'}) + par = get_hyperparameter("name", {"reference": "a"}) par.datatype() - with pytest.raises(RuntimeError, match='Unsupported distribution'): - get_hyperparameter('name', {'distribution': 'unknown'}) + with pytest.raises(RuntimeError, match="Unsupported distribution"): + get_hyperparameter("name", {"distribution": "unknown"}) - with pytest.raises(RuntimeError, match='Missing required key'): - get_hyperparameter('name', {'distribution': 'uniform'}) + with pytest.raises(RuntimeError, match="Missing required key"): + get_hyperparameter("name", {"distribution": "uniform"}) - par = get_hyperparameter('name', {'distribution': 'uniform', 'min': 0, 'max': 1}) - par.distribution = 'unknown' - with pytest.raises(RuntimeError, match='Unsupported distribution'): + par = get_hyperparameter("name", {"distribution": "uniform", "min": 0, "max": 1}) + par.distribution = "unknown" + with pytest.raises(RuntimeError, match="Unsupported distribution"): par.sample() par.get_val() def test_constraint_error(): """Check error on invalid syntax in constraints""" - par = get_hyperparameter('name', {'distribution': 
'uniform', 'min': 0, 'max': 1}) + par = get_hyperparameter("name", {"distribution": "uniform", "min": 0, "max": 1}) constraints = ["hello world"] - with pytest.raises(SyntaxError, match='Invalid syntax in yaml config'): + with pytest.raises(SyntaxError, match="Invalid syntax in yaml config"): sample_parameters([par], constraints) def test_sample_parameters_abort(): """Test for error on infeasible constraints""" - p_1 = get_hyperparameter('p1', {'distribution': 'uniform', 'min': 0, 'max': 1}) - p_2 = get_hyperparameter('p2', {'distribution': 'uniform', 'min': 2, 'max': 3}) - constraints = ['p2 < p1'] # impossible due to the bounds - with pytest.raises(RuntimeError, match='constraints reasonable'): + p_1 = get_hyperparameter("p1", {"distribution": "uniform", "min": 0, "max": 1}) + p_2 = get_hyperparameter("p2", {"distribution": "uniform", "min": 2, "max": 3}) + constraints = ["p2 < p1"] # impossible due to the bounds + with pytest.raises(RuntimeError, match="constraints reasonable"): sample_parameters([p_1, p_2], constraints) def test_sampling_seed(): """Tests if the same hyperparameters are sampled if sampling_seed is set""" - with open("examples/yaml/demo_hyperparameter_sampling.yml", "r", encoding="utf8") as stream: + with open( + "examples/yaml/demo_hyperparameter_sampling.yml", "r", encoding="utf8" + ) as stream: config = yaml.safe_load(stream) - config['sampling_seed'] = 1 + config["sampling_seed"] = 1 - samples1 = sample_hyperparameters(config, sampling_seed=config['sampling_seed']) - samples2 = sample_hyperparameters(config, sampling_seed=config['sampling_seed']) + samples1 = sample_hyperparameters(config, sampling_seed=config["sampling_seed"]) + samples2 = sample_hyperparameters(config, sampling_seed=config["sampling_seed"]) pd.testing.assert_frame_equal(samples1, samples2) def test_sampling_seed_diff(): """Tests if the same hyperparameters are sampled if sampling_seed is set""" - with open("examples/yaml/demo_hyperparameter_sampling.yml", "r", 
encoding="utf8") as stream: + with open( + "examples/yaml/demo_hyperparameter_sampling.yml", "r", encoding="utf8" + ) as stream: config = yaml.safe_load(stream) - config['sampling_seed'] = 1 - samples1 = sample_hyperparameters(config, sampling_seed=config['sampling_seed']) + config["sampling_seed"] = 1 + samples1 = sample_hyperparameters(config, sampling_seed=config["sampling_seed"]) - config['sampling_seed'] = 2 - samples2 = sample_hyperparameters(config, sampling_seed=config['sampling_seed']) + config["sampling_seed"] = 2 + samples2 = sample_hyperparameters(config, sampling_seed=config["sampling_seed"]) assert_frame_not_equal(samples1, samples2) diff --git a/tests/test_jigen.py b/tests/test_jigen.py index ea8d83269..e04dc17fc 100644 --- a/tests/test_jigen.py +++ b/tests/test_jigen.py @@ -9,21 +9,27 @@ def test_mnist_color_jigen(): """ color minst on jigen """ - utils_test_algo("--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=jigen \ - --nname=conv_bn_pool_2") + utils_test_algo( + "--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=jigen \ + --nname=conv_bn_pool_2" + ) def test_jigen30(): """ end to end test """ - utils_test_algo("--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=jigen \ - --nname=conv_bn_pool_2 --nperm=30") + utils_test_algo( + "--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=jigen \ + --nname=conv_bn_pool_2 --nperm=30" + ) def test_trainer_jigen100(): """ end to end test """ - utils_test_algo("--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=jigen \ - --nname=conv_bn_pool_2 --nperm=100") + utils_test_algo( + "--te_d 0 1 2 --tr_d 3 7 --task=mnistcolor10 --model=jigen \ + --nname=conv_bn_pool_2 --nperm=100" + ) diff --git a/tests/test_matchdg.py b/tests/test_matchdg.py index b3c6f6a16..9f95653db 100644 --- a/tests/test_matchdg.py +++ b/tests/test_matchdg.py @@ -1,22 +1,36 @@ import gc + import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp def 
test_trainer_matchdg(): parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--trainer", "matchdg", "--bs", "2", - "--model", "erm", - "--nname", "conv_bn_pool_2", - "--epochs_ctr", "1", - "--epos", "3"]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--trainer", + "matchdg", + "--bs", + "2", + "--model", + "erm", + "--nname", + "conv_bn_pool_2", + "--epochs_ctr", + "1", + "--epos", + "3", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) del exp torch.cuda.empty_cache() gc.collect() - diff --git a/tests/test_mk_exp_dann.py b/tests/test_mk_exp_dann.py index f161f70f7..50f52013d 100644 --- a/tests/test_mk_exp_dann.py +++ b/tests/test_mk_exp_dann.py @@ -5,10 +5,10 @@ from torchvision import models as torchvisionmodels from torchvision.models import ResNet50_Weights -from domainlab.mk_exp import mk_exp from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.tasks.task_dset import mk_task_dset +from domainlab.mk_exp import mk_exp from domainlab.models.model_dann import mk_dann +from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize @@ -28,15 +28,21 @@ def mk_exp_dann(trainer="mldg"): # specify domain generalization task task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + 
) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) task.get_list_domains_tr_te(None, "domain1") # specify task-specific parameters num_output_net_classifier = task.dim_y @@ -45,7 +51,9 @@ def mk_exp_dann(trainer="mldg"): alpha = 1e-3 # specify feature extractor as ResNet - net_encoder = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + net_encoder = torchvisionmodels.resnet.resnet50( + weights=ResNet50_Weights.IMAGENET1K_V2 + ) num_output_net_encoder = net_encoder.fc.out_features # specify discriminator as linear network @@ -55,9 +63,16 @@ def mk_exp_dann(trainer="mldg"): net_classifier = nn.Linear(num_output_net_encoder, num_output_net_classifier) # specify model to use - model = mk_dann()(list_str_y, task.list_domain_tr, alpha, net_encoder, net_classifier, net_discriminator) + model = mk_dann()( + list_str_y, + task.list_domain_tr, + alpha, + net_encoder, + net_classifier, + net_discriminator, + ) # make trainer for model - + exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32) exp.execute(num_epochs=2) diff --git a/tests/test_mk_exp_deepall.py b/tests/test_mk_exp_deepall.py index 6b645e755..f4b37782f 100644 --- a/tests/test_mk_exp_deepall.py +++ b/tests/test_mk_exp_deepall.py @@ -5,10 +5,10 @@ from torchvision import models as torchvisionmodels from torchvision.models import ResNet50_Weights -from domainlab.mk_exp import mk_exp from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.tasks.task_dset import mk_task_dset +from domainlab.mk_exp import mk_exp from domainlab.models.model_erm import mk_erm +from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize @@ -27,16 +27,22 @@ def mk_exp_erm(trainer="mldg"): """ # specify domain generalization task - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - 
task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify backbone to use backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) diff --git a/tests/test_mk_exp_diva.py b/tests/test_mk_exp_diva.py index 59882c071..79fcb86aa 100644 --- a/tests/test_mk_exp_diva.py +++ b/tests/test_mk_exp_diva.py @@ -2,14 +2,14 @@ make an experiment using "diva" model """ -from domainlab.mk_exp import mk_exp +from domainlab.compos.pcr.request import RequestVAEBuilderNN +from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool +from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.tasks.task_dset import mk_task_dset +from domainlab.mk_exp import mk_exp from domainlab.models.model_diva import mk_diva +from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize -from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter -from domainlab.compos.pcr.request import RequestVAEBuilderNN -from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool def test_mk_exp_diva(): @@ -26,16 +26,22 @@ def mk_exp_diva(trainer="mldg"): """ # 
specify domain generalization task - task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify parameters list_str_y = [f"class{i}" for i in range(task.dim_y)] list_d_tr = ["domain2", "domain3"] @@ -48,19 +54,34 @@ def mk_exp_diva(trainer="mldg"): beta_x = 1e3 beta_y = 1e3 net_class_d = LSEncoderConvBnReluPool( - zd_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + zd_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1 + ) net_x = LSEncoderConvBnReluPool( - zx_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + zx_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1 + ) net_class_y = LSEncoderConvBnReluPool( - zy_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + zy_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1 + ) - request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y, - task.isize.c, task.isize.h, task.isize.w) + request = RequestVAEBuilderNN( + net_class_d, net_x, net_class_y, task.isize.c, task.isize.h, task.isize.w + ) chain_node_builder = VAEChainNodeGetter(request)() # specify model to use - model = mk_diva()(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y, list_d_tr, 
gamma_d, - gamma_y, beta_d, beta_x, beta_y) + model = mk_diva()( + chain_node_builder, + zd_dim, + zy_dim, + zx_dim, + list_str_y, + list_d_tr, + gamma_d, + gamma_y, + beta_d, + beta_x, + beta_y, + ) # make trainer for model exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32) diff --git a/tests/test_mk_exp_jigen.py b/tests/test_mk_exp_jigen.py index 7e3804366..6c19d0d76 100644 --- a/tests/test_mk_exp_jigen.py +++ b/tests/test_mk_exp_jigen.py @@ -6,10 +6,10 @@ from torchvision import models as torchvisionmodels from torchvision.models import ResNet50_Weights -from domainlab.mk_exp import mk_exp from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.tasks.task_dset import mk_task_dset +from domainlab.mk_exp import mk_exp from domainlab.models.model_jigen import mk_jigen +from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize @@ -28,25 +28,33 @@ def mk_exp_jigen(trainer="mldg"): """ # specify domain generalization task: - task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify parameters num_output_net_classifier = task.dim_y - 
num_output_net_permutation = 32 # 31+1 + num_output_net_permutation = 32 # 31+1 list_str_y = [f"class{i}" for i in range(num_output_net_classifier)] coeff_reg = 1e-3 # specify net encoder - net_encoder = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + net_encoder = torchvisionmodels.resnet.resnet50( + weights=ResNet50_Weights.IMAGENET1K_V2 + ) num_output_net_encoder = net_encoder.fc.out_features # specify permutation classifier as linear network @@ -56,8 +64,14 @@ def mk_exp_jigen(trainer="mldg"): net_classifier = nn.Linear(num_output_net_encoder, num_output_net_classifier) # specify model to use - model = mk_jigen()(list_str_y, net_encoder, - net_classifier, net_permutation, coeff_reg, meta_info={"nperm":num_output_net_permutation}) + model = mk_jigen()( + list_str_y, + net_encoder, + net_classifier, + net_permutation, + coeff_reg, + meta_info={"nperm": num_output_net_permutation}, + ) # make trainer for model exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32) diff --git a/tests/test_model_diva.py b/tests/test_model_diva.py index 598f45d8d..6346fa5dd 100644 --- a/tests/test_model_diva.py +++ b/tests/test_model_diva.py @@ -1,10 +1,9 @@ - -from domainlab.models.model_diva import mk_diva -from domainlab.utils.utils_classif import mk_dummy_label_list_str -from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter -from domainlab.compos.pcr.request import RequestVAEBuilderCHW from domainlab.arg_parser import mk_parser_main +from domainlab.compos.pcr.request import RequestVAEBuilderCHW +from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter +from domainlab.models.model_diva import mk_diva from domainlab.utils.test_img import mk_rand_xyd +from domainlab.utils.utils_classif import mk_dummy_label_list_str def test_model_diva(): @@ -19,9 +18,19 @@ def test_model_diva(): request = RequestVAEBuilderCHW(3, 28, 28, args=margs) node = VAEChainNodeGetter(request)() - model = 
mk_diva()(node, zd_dim=8, zy_dim=8, zx_dim=8, list_d_tr=list_str_d, - list_str_y=list_str_y, gamma_d=1.0, gamma_y=1.0, - beta_d=1.0, beta_y=1.0, beta_x=1.0) + model = mk_diva()( + node, + zd_dim=8, + zy_dim=8, + zx_dim=8, + list_d_tr=list_str_d, + list_str_y=list_str_y, + gamma_d=1.0, + gamma_y=1.0, + beta_d=1.0, + beta_y=1.0, + beta_x=1.0, + ) imgs, y_s, d_s = mk_rand_xyd(28, y_dim, 2, 2) _, _, _, _, _ = model.infer_y_vpicn(imgs) model(imgs, y_s, d_s) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 245ea2b4d..4a53d6991 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -1,27 +1,32 @@ """ -executing mk_exp multiple times will cause deep copy to be called multiple times, pytest will show process got killed. +executing mk_exp multiple times will cause deep copy to be called multiple times, pytest will show process got killed. """ from torch import nn from torchvision import models as torchvisionmodels from torchvision.models import ResNet50_Weights - from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor from domainlab.algos.msels.c_msel_val import MSelValPerf from domainlab.algos.observers.b_obvisitor import ObVisitor -from domainlab.models.model_erm import mk_erm -from domainlab.utils.utils_cuda import get_device from domainlab.arg_parser import mk_parser_main +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault from domainlab.exp.exp_main import Exp - -from domainlab.dsets.dset_mnist_color_solo_default import \ - DsetMNISTColorSoloDefault +from domainlab.models.model_erm import mk_erm from domainlab.tasks.task_dset import mk_task_dset from domainlab.tasks.utils_task import ImSize +from domainlab.utils.utils_cuda import get_device -def mk_exp(task, model, trainer: str, test_domain: str, batchsize: int, - alone=True, force_best_val=False, msel_loss_tr=False): +def mk_exp( + task, + model, + trainer: str, + test_domain: str, + batchsize: int, + alone=True, + 
force_best_val=False, + msel_loss_tr=False, +): """ Creates a custom experiment. The user can specify the input parameters. @@ -69,21 +74,25 @@ def test_msel_oracle(): """ return trainer, model, observer """ - task = mk_task_dset( - isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -91,33 +100,35 @@ def test_msel_oracle(): model = mk_erm()(backbone) # make trainer for model - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", - batchsize=32) + exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) exp.execute(num_epochs=2) del exp - def test_msel_oracle1(): """ return trainer, model, observer """ - task = mk_task_dset( - isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - 
dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -126,9 +137,9 @@ def test_msel_oracle1(): # make trainer for model - - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", - batchsize=32, alone=False) + exp = mk_exp( + task, model, trainer="mldg", test_domain="domain1", batchsize=32, alone=False + ) exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.update(clear_counter=True) @@ -139,21 +150,25 @@ def test_msel_oracle2(): """ return trainer, model, observer """ - task = mk_task_dset( - isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), 
+ ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -161,66 +176,89 @@ def test_msel_oracle2(): model = mk_erm()(backbone) # make trainer for model - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", - batchsize=32) + exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) exp.execute(num_epochs=2) - + + def test_msel_oracle3(): """ return trainer, model, observer """ - task = mk_task_dset( - isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = 
backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) # specify model to use model = mk_erm()(backbone) - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", - batchsize=32, alone=False, force_best_val=True) + exp = mk_exp( + task, + model, + trainer="mldg", + test_domain="domain1", + batchsize=32, + alone=False, + force_best_val=True, + ) exp.execute(num_epochs=2) del exp + def test_msel_oracle4(): """ return trainer, model, observer """ - task = mk_task_dset( - isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain(name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1)) - task.add_domain(name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3)) - task.add_domain(name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5)) + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50( - weights=ResNet50_Weights.IMAGENET1K_V2) + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) # specify model to use model = mk_erm()(backbone) - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", - batchsize=32, alone=False, msel_loss_tr=True) + exp = mk_exp( + task, + model, + trainer="mldg", + test_domain="domain1", + batchsize=32, + alone=False, + msel_loss_tr=True, + ) exp.execute(num_epochs=2) 
exp.trainer.observer.model_sel.msel.best_loss = 0 exp.trainer.observer.model_sel.msel.update(clear_counter=True) diff --git a/tests/test_msel_tr_loss.py b/tests/test_msel_tr_loss.py index 5e94b0d15..f11428497 100644 --- a/tests/test_msel_tr_loss.py +++ b/tests/test_msel_tr_loss.py @@ -8,6 +8,8 @@ def test_erm(): """ unit deep all """ - utils_test_algo("--te_d 0 --tr_d 3 7 --task=mnistcolor10 \ + utils_test_algo( + "--te_d 0 --tr_d 3 7 --task=mnistcolor10 \ --model=erm --nname=conv_bn_pool_2 --bs=2 \ - --msel=loss_tr --epos=2") + --msel=loss_tr --epos=2" + ) diff --git a/tests/test_observer.py b/tests/test_observer.py index fe046456f..1a3f27bfb 100644 --- a/tests/test_observer.py +++ b/tests/test_observer.py @@ -2,9 +2,11 @@ unit and end-end test for deep all, dann """ import gc + import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp def test_erm(): @@ -12,11 +14,20 @@ def test_erm(): unit deep all """ parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "erm", "--bs", "2", - "--nname", "conv_bn_pool_2" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "erm", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) @@ -31,12 +42,22 @@ def test_erm_trloss(): unit deep all """ parser = mk_parser_main() - margs = parser.parse_args(["--te_d", "caltech", - "--task", "mini_vlcs", - "--model", "erm", "--bs", "2", - "--nname", "conv_bn_pool_2", - "--msel", "loss_tr" - ]) + margs = parser.parse_args( + [ + "--te_d", + "caltech", + "--task", + "mini_vlcs", + "--model", + "erm", + "--bs", + "2", + "--nname", + "conv_bn_pool_2", + "--msel", + "loss_tr", + ] + ) exp = Exp(margs) exp.trainer.before_tr() exp.trainer.tr_epoch(0) diff --git a/tests/test_run_experiment.py b/tests/test_run_experiment.py index 
aadbcf63d..b22593a91 100644 --- a/tests/test_run_experiment.py +++ b/tests/test_run_experiment.py @@ -14,54 +14,56 @@ def test_run_experiment(): config = yaml.safe_load(stream) if torch.cuda.is_available(): torch.cuda.init() - config['epos'] = 1 - config['startseed'] = 1 - config['endseed'] = 1 - config['test_domains'] = ['caltech'] + config["epos"] = 1 + config["startseed"] = 1 + config["endseed"] = 1 + config["test_domains"] = ["caltech"] param_file = "data/ztest_files/test_parameter_samples.csv" param_index = 0 out_file = "zoutput/benchmarks/demo_benchmark/rule_results/0.csv" - run_experiment(config, param_file, param_index, out_file, misc={'testing': True}) - config['test_domains'] = [] + run_experiment(config, param_file, param_index, out_file, misc={"testing": True}) + config["test_domains"] = [] run_experiment(config, param_file, param_index, out_file) - config['domainlab_args']['batchsize'] = 16 + config["domainlab_args"]["batchsize"] = 16 with pytest.raises(ValueError): run_experiment(config, param_file, param_index, out_file) + def test_run_experiment_parameter_doubling_error(): - '''checks if a hyperparameter is specified multiple times, - in the sympling section, in the common_args section and in the task section''' + """checks if a hyperparameter is specified multiple times, + in the sympling section, in the common_args section and in the task section""" with open("examples/benchmark/demo_benchmark.yaml", "r", encoding="utf8") as stream: config = yaml.safe_load(stream) - config['epos'] = 1 - config['startseed'] = 1 - config['endseed'] = 1 - config['test_domains'] = ['caltech'] - config['diva']['gamma_y'] = 1e4 + config["epos"] = 1 + config["startseed"] = 1 + config["endseed"] = 1 + config["test_domains"] = ["caltech"] + config["diva"]["gamma_y"] = 1e4 param_file = "data/ztest_files/test_parameter_samples.csv" param_index = 0 out_file = "zoutput/benchmarks/demo_benchmark/rule_results/0.csv" - with pytest.raises(RuntimeError, - match="has already been 
fixed " - "to a value in the algorithm section."): + with pytest.raises( + RuntimeError, + match="has already been fixed " "to a value in the algorithm section.", + ): run_experiment(config, param_file, param_index, out_file) - with open("examples/benchmark/demo_benchmark.yaml", "r", encoding="utf8") as stream: config = yaml.safe_load(stream) - config['epos'] = 1 - config['startseed'] = 1 - config['endseed'] = 1 - config['test_domains'] = ['caltech'] - config['domainlab_args']['gamma_y'] = 1e4 + config["epos"] = 1 + config["startseed"] = 1 + config["endseed"] = 1 + config["test_domains"] = ["caltech"] + config["domainlab_args"]["gamma_y"] = 1e4 param_file = "data/ztest_files/test_parameter_samples.csv" param_index = 0 out_file = "zoutput/benchmarks/demo_benchmark/rule_results/0.csv" - with pytest.raises(RuntimeError, - match="has already been fixed " - "to a value in the domainlab_args section."): + with pytest.raises( + RuntimeError, + match="has already been fixed " "to a value in the domainlab_args section.", + ): run_experiment(config, param_file, param_index, out_file) diff --git a/tests/test_sav_img_title.py b/tests/test_sav_img_title.py index 4f5870659..b6743d476 100644 --- a/tests/test_sav_img_title.py +++ b/tests/test_sav_img_title.py @@ -1,23 +1,27 @@ -''' +""" Code coverage issues: https://app.codecov.io/gh/marrlab/DomainLab/blob/master/domainlab/utils/utils_img_sav.py - lines 22-23 - lines 31-35 -''' +""" import torch + from domainlab.utils.utils_img_sav import mk_fun_sav_img, sav_add_title + def test_save_img(): """ test sav_img function """ imgs = torch.randn(1, 3, 28, 28) tt_sav_img = mk_fun_sav_img() - tt_sav_img(imgs, name='rand_img.png', title='random_img') + tt_sav_img(imgs, name="rand_img.png", title="random_img") + + def test_add_title(): """ test sav_add_title """ img = torch.randn(3, 28, 28) - sav_add_title(img, path='.', title='random_img') + sav_add_title(img, path=".", title="random_img") diff --git a/tests/test_task_folder.py 
b/tests/test_task_folder.py index 5d62c9cb2..3b78b5f2c 100644 --- a/tests/test_task_folder.py +++ b/tests/test_task_folder.py @@ -1,49 +1,48 @@ import os + import pytest from torchvision import transforms from domainlab.arg_parser import mk_parser_main -from domainlab.tasks.task_folder_mk import mk_task_folder from domainlab.tasks.task_folder import NodeTaskFolder +from domainlab.tasks.task_folder_mk import mk_task_folder from domainlab.tasks.utils_task import ImSize + path_this_file = os.path.dirname(os.path.realpath(__file__)) def test_fun(): - node = mk_task_folder(extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, - list_str_y=["chair", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", - "stuhl": "chair"}, - "sun": {"vehicle": "car", - "sofa": "chair"}, - "labelme": {"drive": "car", - "sit": "chair"} - }, - dict_domain_img_trans={ - "caltech": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "sun": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "labelme": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - }, - img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": - "data/vlcs_mini/caltech/", - "sun": - "data/vlcs_mini/sun/", - "labelme": - "data/vlcs_mini/labelme/"}, - taskna="mini_vlcs", - succ=None) + node = mk_task_folder( + extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, + list_str_y=["chair", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "stuhl": "chair"}, + "sun": {"vehicle": "car", "sofa": "chair"}, + "labelme": {"drive": "car", "sit": "chair"}, + }, + dict_domain_img_trans={ + "caltech": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "sun": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "labelme": 
transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + }, + img_trans_te=transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + isize=ImSize(3, 224, 224), + dict_domain2imgroot={ + "caltech": "data/vlcs_mini/caltech/", + "sun": "data/vlcs_mini/sun/", + "labelme": "data/vlcs_mini/labelme/", + }, + taskna="mini_vlcs", + succ=None, + ) parser = mk_parser_main() # batchsize bs=2 ensures it works on small dataset @@ -68,142 +67,153 @@ def test_fun(): # folder=folder_na, # batches=10) + def test_mk_task_folder(): - _ = mk_task_folder(extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, - list_str_y=["chair", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", "stuhl": "chair"}, - "sun": {"viehcle": "car", "sofa": "chair"}, - "labelme": {"drive": "car", "sit": "chair"} - }, - img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), + _ = mk_task_folder( + extensions={"caltech": "jpg", "sun": "jpg", "labelme": "jpg"}, + list_str_y=["chair", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "stuhl": "chair"}, + "sun": {"viehcle": "car", "sofa": "chair"}, + "labelme": {"drive": "car", "sit": "chair"}, + }, + img_trans_te=transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + dict_domain_img_trans={ + "caltech": transforms.Compose( + [ + transforms.Resize((224, 224)), + ] + ), + "sun": transforms.Compose( + [ + transforms.Resize((224, 224)), + ] + ), + "labelme": transforms.Compose( + [ + transforms.Resize((224, 224)), + ] + ), + }, + isize=ImSize(3, 224, 224), + dict_domain2imgroot={ + "caltech": "data/vlcs_mini/caltech/", + "sun": "data/vlcs_mini/sun/", + "labelme": "data/vlcs_mini/labelme/", + }, + taskna="mini_vlcs", + ) - dict_domain_img_trans={ - "caltech": transforms.Compose([transforms.Resize((224, 224)), ]), - "sun": transforms.Compose([transforms.Resize((224, 224)), ]), - "labelme": 
transforms.Compose([transforms.Resize((224, 224)), ]), - }, - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": "data/vlcs_mini/caltech/", - "sun": "data/vlcs_mini/sun/", - "labelme": "data/vlcs_mini/labelme/"}, - taskna="mini_vlcs") def test_none_extensions(): """Check if all different datatypes for the extensions arg work.""" - node = mk_task_folder(extensions={'caltech': None, 'labelme': None, 'sun': None}, - list_str_y=["chair", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", - "stuhl": "chair"}, - "sun": {"vehicle": "car", - "sofa": "chair"}, - "labelme": {"drive": "car", - "sit": "chair"} - }, - dict_domain_img_trans={ - "caltech": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "sun": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "labelme": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - }, - img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": - "data/vlcs_mini/caltech/", - "labelme": - "data/vlcs_mini/labelme/", - "sun": - "data/vlcs_mini/sun/"}, - taskna="mini_vlcs", - succ=None) + node = mk_task_folder( + extensions={"caltech": None, "labelme": None, "sun": None}, + list_str_y=["chair", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "stuhl": "chair"}, + "sun": {"vehicle": "car", "sofa": "chair"}, + "labelme": {"drive": "car", "sit": "chair"}, + }, + dict_domain_img_trans={ + "caltech": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "sun": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "labelme": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + }, + img_trans_te=transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + isize=ImSize(3, 224, 224), + 
dict_domain2imgroot={ + "caltech": "data/vlcs_mini/caltech/", + "labelme": "data/vlcs_mini/labelme/", + "sun": "data/vlcs_mini/sun/", + }, + taskna="mini_vlcs", + succ=None, + ) parser = mk_parser_main() # batchsize bs=2 ensures it works on small dataset args = parser.parse_args(["--te_d", "1", "--bs", "2", "--model", "diva"]) node.init_business(args) - assert node.dict_domain_class_count['caltech']['chair'] == 6 - assert node.dict_domain_class_count['caltech']['car'] == 20 + assert node.dict_domain_class_count["caltech"]["chair"] == 6 + assert node.dict_domain_class_count["caltech"]["car"] == 20 # explicit given extension - node = mk_task_folder(extensions={'caltech': 'jpg', 'sun': 'jpg'}, - list_str_y=["bird", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", - "vogel": "bird"}, - 'sun': {'vehicle': 'car', - 'sofa': 'bird'} - }, - dict_domain_img_trans={ - "caltech": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "sun": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - }, - img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": - "data/mixed_codec/caltech/", - "sun": - "data/mixed_codec/sun/", - }, - taskna="mixed_codec", - succ=None) + node = mk_task_folder( + extensions={"caltech": "jpg", "sun": "jpg"}, + list_str_y=["bird", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "vogel": "bird"}, + "sun": {"vehicle": "car", "sofa": "bird"}, + }, + dict_domain_img_trans={ + "caltech": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "sun": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + }, + img_trans_te=transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + isize=ImSize(3, 224, 224), + dict_domain2imgroot={ + "caltech": "data/mixed_codec/caltech/", + 
"sun": "data/mixed_codec/sun/", + }, + taskna="mixed_codec", + succ=None, + ) parser = mk_parser_main() # batchsize bs=2 ensures it works on small dataset args = parser.parse_args(["--te_d", "1", "--bs", "2", "--model", "diva"]) node.init_business(args) - assert node.dict_domain_class_count['caltech']['bird'] == 2,\ - "mixed_codec/caltech holds 2 jpg birds" - assert node.dict_domain_class_count['caltech']['car'] == 2,\ - "mixed_codec/caltech holds 2 jpg cars" + assert ( + node.dict_domain_class_count["caltech"]["bird"] == 2 + ), "mixed_codec/caltech holds 2 jpg birds" + assert ( + node.dict_domain_class_count["caltech"]["car"] == 2 + ), "mixed_codec/caltech holds 2 jpg cars" # No extensions given - node = mk_task_folder(extensions=None, - list_str_y=["bird", "car"], - dict_domain_folder_name2class={ - "caltech": {"auto": "car", - "vogel": "bird"}, - 'sun': {'vehicle': 'car', - 'sofa': 'bird'} - }, - dict_domain_img_trans={ - "caltech": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - "sun": transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - }, - img_trans_te=transforms.Compose( - [transforms.Resize((224, 224)), - transforms.ToTensor()]), - isize=ImSize(3, 224, 224), - dict_domain2imgroot={ - "caltech": - "data/mixed_codec/caltech/", - "sun": - "data/mixed_codec/sun/", - }, - taskna="mixed_codec", - succ=None) + node = mk_task_folder( + extensions=None, + list_str_y=["bird", "car"], + dict_domain_folder_name2class={ + "caltech": {"auto": "car", "vogel": "bird"}, + "sun": {"vehicle": "car", "sofa": "bird"}, + }, + dict_domain_img_trans={ + "caltech": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + "sun": transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + }, + img_trans_te=transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] + ), + isize=ImSize(3, 224, 224), + dict_domain2imgroot={ + "caltech": 
"data/mixed_codec/caltech/", + "sun": "data/mixed_codec/sun/", + }, + taskna="mixed_codec", + succ=None, + ) parser = mk_parser_main() # batchsize bs=2 ensures it works on small dataset @@ -213,52 +223,55 @@ def test_none_extensions(): @pytest.fixture def pacs_node(): - """Task folder for PACS Mini 10 - """ + """Task folder for PACS Mini 10""" # FIXME: make me work with mk_task_folder node = NodeTaskFolder() node.set_list_domains(["cartoon", "photo"]) # node.extensions = {"cartoon": "jpg", "photo": "jpg"} - node.extensions = ('jpg',) + node.extensions = ("jpg",) node.list_str_y = ["dog", "elephant"] node.dict_domain2imgroot = { "cartoon": "data/pacs_mini_10/cartoon/", - "photo": "data/pacs_mini_10/photo/" + "photo": "data/pacs_mini_10/photo/", } return node @pytest.fixture def folder_args(): - """Test args; batchsize bs=2 ensures it works on small dataset - """ + """Test args; batchsize bs=2 ensures it works on small dataset""" parser = mk_parser_main() args = parser.parse_args(["--te_d", "1", "--bs", "2", "--model", "diva"]) return args + def test_nodetaskfolder(pacs_node, folder_args): - """Test NodeTaskFolder can be initiated without transforms - """ + """Test NodeTaskFolder can be initiated without transforms""" pacs_node.init_business(folder_args) def test_nodetaskfolder_transforms(pacs_node, folder_args): - """Test NodeTaskFolder can be initiated with transforms - """ + """Test NodeTaskFolder can be initiated with transforms""" pacs_node._dict_domain_img_trans = { - "cartoon": transforms.Compose([transforms.Resize((224, 224)), ]), - "photo": transforms.Compose([transforms.Resize((224, 224)), ]) + "cartoon": transforms.Compose( + [ + transforms.Resize((224, 224)), + ] + ), + "photo": transforms.Compose( + [ + transforms.Resize((224, 224)), + ] + ), } - pacs_node.img_trans_te = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor() - ]) + pacs_node.img_trans_te = transforms.Compose( + [transforms.Resize((224, 224)), transforms.ToTensor()] 
+ ) pacs_node.init_business(folder_args) def test_nodetaskfolder_split_error(pacs_node, folder_args): - """Test NodeTaskFolder throws an error when split == True - """ + """Test NodeTaskFolder throws an error when split == True""" folder_args.split = True with pytest.raises(RuntimeError): pacs_node.init_business(folder_args) diff --git a/tests/test_train_diva.py b/tests/test_train_diva.py index 0776e17d0..e22cb32a2 100644 --- a/tests/test_train_diva.py +++ b/tests/test_train_diva.py @@ -1,16 +1,18 @@ import gc + import torch -from domainlab.algos.observers.b_obvisitor import ObVisitor -from domainlab.models.model_diva import mk_diva -from domainlab.utils.utils_classif import mk_dummy_label_list_str -from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter -from domainlab.compos.pcr.request import RequestVAEBuilderCHW -from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler -from domainlab.exp.exp_main import Exp -from domainlab.arg_parser import mk_parser_main + from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor from domainlab.algos.msels.c_msel_tr_loss import MSelTrLoss +from domainlab.algos.observers.b_obvisitor import ObVisitor from domainlab.algos.observers.c_obvisitor_cleanup import ObVisitorCleanUp +from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler +from domainlab.arg_parser import mk_parser_main +from domainlab.compos.pcr.request import RequestVAEBuilderCHW +from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter +from domainlab.exp.exp_main import Exp +from domainlab.models.model_diva import mk_diva +from domainlab.utils.utils_classif import mk_dummy_label_list_str from domainlab.utils.utils_cuda import get_device @@ -28,15 +30,27 @@ def test_trainer_diva(): request = RequestVAEBuilderCHW(3, 28, 28, args=margs) node = VAEChainNodeGetter(request)() - model = mk_diva()(node, zd_dim=8, zy_dim=8, zx_dim=8, list_d_tr=list_str_d, - 
list_str_y=list_str_y, gamma_d=1.0, gamma_y=1.0, - beta_d=1.0, beta_y=1.0, beta_x=1.0) + model = mk_diva()( + node, + zd_dim=8, + zy_dim=8, + zx_dim=8, + list_d_tr=list_str_d, + list_str_y=list_str_y, + gamma_d=1.0, + gamma_y=1.0, + beta_d=1.0, + beta_y=1.0, + beta_x=1.0, + ) model_sel = MSelOracleVisitor(MSelTrLoss(max_es=margs.es)) exp = Exp(margs) device = get_device(margs) observer = ObVisitorCleanUp(ObVisitor(model_sel)) trainer = TrainerHyperScheduler() - trainer.init_business(model, task=exp.task, observer=observer, device=device, aconf=margs) + trainer.init_business( + model, task=exp.task, observer=observer, device=device, aconf=margs + ) trainer.before_tr() trainer.tr_epoch(0) del exp diff --git a/tests/test_unit_utils_task.py b/tests/test_unit_utils_task.py index ffb145b45..723674ece 100644 --- a/tests/test_unit_utils_task.py +++ b/tests/test_unit_utils_task.py @@ -1,5 +1,5 @@ -from domainlab.tasks.utils_task import LoaderDomainLabel from domainlab.dsets.dset_poly_domains_mnist_color_default import DsetMNISTColorMix +from domainlab.tasks.utils_task import LoaderDomainLabel def test_unit_utils_task(): @@ -7,4 +7,3 @@ def test_unit_utils_task(): loader = LoaderDomainLabel(32, 3)(dset, 0, "0") batch = next(iter(loader)) assert batch[0].shape == (32, 3, 28, 28) - diff --git a/tests/test_utils.py b/tests/test_utils.py index 45fa530cb..529552858 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,15 +20,15 @@ def test_exp_protocol_agg_writer(): parser = mk_parser_main() args = parser.parse_args(args=[]) misc = { - 'model': 'diva', - 'nname': 'conv_bn_pool_2', - 'nname_dom': 'conv_bn_pool_2', - 'task': 'mnistcolor10', - 'te_d': 0, - 'result_file': "out_file", - 'params': "hyperparameters", - 'benchmark_task_name': "task", - 'param_index': 0 + "model": "diva", + "nname": "conv_bn_pool_2", + "nname_dom": "conv_bn_pool_2", + "task": "mnistcolor10", + "te_d": 0, + "result_file": "out_file", + "params": "hyperparameters", + "benchmark_task_name": 
"task", + "param_index": 0, } apply_dict_to_args(args, misc, extend=True) diff --git a/tests/this_is_not_a_network.py b/tests/this_is_not_a_network.py index 5cd8a7522..505da90f0 100644 --- a/tests/this_is_not_a_network.py +++ b/tests/this_is_not_a_network.py @@ -1,11 +1,12 @@ -''' +""" for testing one needs a network, which is actually not a network -''' +""" + def build_feat_extract_net(dim_y, remove_last_layer): - ''' + """ I am not a neuronal network - ''' + """ # I had to add these two prints to satisfy the requirements of codacy # to use all arguments given to a function print(str(remove_last_layer)) diff --git a/tests/utils_test.py b/tests/utils_test.py index 88777e28a..a8dd04177 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -3,10 +3,12 @@ helper function to execute as if command line arguments are passed """ import gc + import pandas as pd import torch -from domainlab.exp.exp_main import Exp + from domainlab.arg_parser import mk_parser_main +from domainlab.exp.exp_main import Exp def utils_test_algo(argsstr="--help"): diff --git a/uml/libDG.uml b/uml/libDG.uml index 6ddc086c3..02c29d729 100644 --- a/uml/libDG.uml +++ b/uml/libDG.uml @@ -1,17 +1,17 @@ @startuml -' +' ' *-- composition ' <|-- extension ' o-- aggregation ' -- association (1 to n or 1 to 1 or n to 1) ' ..> -'Dependency is a weaker form of bond which indicates that one class depends on -'another because it uses it at some point in time. One class depends on -'another if the independent class is a parameter variable or local variable of -'a method of the dependent class. This is different from an association, where -'an attribute of the dependent class is an instance of the independent class. -'Sometimes the relationship between two classes is very weak. They are not -'implemented with member variables at all. Rather they might be implemented as +'Dependency is a weaker form of bond which indicates that one class depends on +'another because it uses it at some point in time. 
One class depends on +'another if the independent class is a parameter variable or local variable of +'a method of the dependent class. This is different from an association, where +'an attribute of the dependent class is an instance of the independent class. +'Sometimes the relationship between two classes is very weak. They are not +'implemented with member variables at all. Rather they might be implemented as 'member function arguments. package tasks { abstract class Task { @@ -48,7 +48,7 @@ package algos { } package observer { class Observer { - + trainer.model.calculate_metric() + + trainer.model.calculate_metric() } } package model_selection { @@ -56,7 +56,7 @@ package algos { - early_stop } } -} +} package datasets <>{ class Dataset { diff --git a/zmain/main.py b/zmain/main.py index f2b9a5adc..b106e4348 100644 --- a/zmain/main.py +++ b/zmain/main.py @@ -1,6 +1,6 @@ -from domainlab.exp.exp_main import Exp -from domainlab.exp.exp_cuda_seed import set_seed # reproducibility from domainlab.arg_parser import parse_cmd_args +from domainlab.exp.exp_cuda_seed import set_seed # reproducibility +from domainlab.exp.exp_main import Exp if __name__ == "__main__": args = parse_cmd_args() diff --git a/zmain/main_gen.py b/zmain/main_gen.py index fc435a2a7..ad5da86e8 100644 --- a/zmain/main_gen.py +++ b/zmain/main_gen.py @@ -2,12 +2,14 @@ command line generate images """ import os + import torch + +from domainlab.arg_parser import mk_parser_main from domainlab.exp.exp_cuda_seed import set_seed from domainlab.tasks.zoo_tasks import TaskChainNodeGetter -from domainlab.arg_parser import mk_parser_main -from domainlab.utils.utils_cuda import get_device from domainlab.utils.flows_gen_img_model import fun_gen +from domainlab.utils.utils_cuda import get_device def main_gen(args, task=None, model=None, device=None): @@ -25,8 +27,9 @@ def main_gen(args, task=None, model=None, device=None): if __name__ == "__main__": parser = mk_parser_main() - parser.add_argument('--mpath', type=str, 
default=None, - help="path for persisted model") + parser.add_argument( + "--mpath", type=str, default=None, help="path for persisted model" + ) args = parser.parse_args() set_seed(args.seed) main_gen(args) diff --git a/zmain/main_task.py b/zmain/main_task.py index 956ba4e7d..9f5e3c595 100644 --- a/zmain/main_task.py +++ b/zmain/main_task.py @@ -1,10 +1,9 @@ """ probe task by saving images to folder with class and domain label """ +from domainlab.arg_parser import mk_parser_main from domainlab.exp.exp_cuda_seed import set_seed from domainlab.tasks.zoo_tasks import TaskChainNodeGetter -from domainlab.arg_parser import mk_parser_main - if __name__ == "__main__": parser = mk_parser_main()