Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ jobs:
pip install -r requirements.txt
pip install pytest pytest-cov pytest-shell
python setup.py install
- name: Install Python Poetry
run: |
python -m pip install --upgrade pip
pip install poetry
poetry install --with dev
- name: Generate coverage report
run: pytest --cov=domainlab tests/ --cov-report=xml
run: poetry run pytest --cov=domainlab tests/ --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
Expand Down
26 changes: 26 additions & 0 deletions domainlab/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
command-line interface (CLI) for the domainlab package
"""

from domainlab.arg_parser import parse_cmd_args
from domainlab.exp.exp_cuda_seed import set_seed # reproducibility
from domainlab.exp.exp_main import Exp
from domainlab.exp_protocol import aggregate_results
from domainlab.utils.generate_benchmark_plots import gen_benchmark_plots


def domainlab_cli():
"""
Function used to run domainlab as a command line tool for the package installed with pip.
"""
args = parse_cmd_args()
if args.bm_dir:
aggregate_results.agg_main(args.bm_dir)
elif args.plot_data is not None:
gen_benchmark_plots(
args.plot_data, args.outp_dir, use_param_index=args.param_idx
)
else:
set_seed(args.seed)
exp = Exp(args=args)
exp.execute()
13 changes: 2 additions & 11 deletions main_out.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
from domainlab.arg_parser import parse_cmd_args
from domainlab.cli import domainlab_cli
from domainlab.exp.exp_cuda_seed import set_seed # reproducibility
from domainlab.exp.exp_main import Exp
from domainlab.exp_protocol import aggregate_results
from domainlab.utils.generate_benchmark_plots import gen_benchmark_plots

if __name__ == "__main__":
args = parse_cmd_args()
if args.bm_dir:
aggregate_results.agg_main(args.bm_dir)
elif args.plot_data is not None:
gen_benchmark_plots(
args.plot_data, args.outp_dir, use_param_index=args.param_idx
)
else:
set_seed(args.seed)
exp = Exp(args=args)
exp.execute()
domainlab_cli()
147 changes: 144 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ scikit-learn = "^1.2.1"
pyyaml = "^6.0"
gdown = "^4.7.1"

[tool.poetry.scripts]
domainlab = 'domainlab.cli:domainlab_cli'

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^6.2.5"
pytest-cov = "^4.1.0"
pytest-shell = "^0.3.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
52 changes: 52 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
test 'domainlab' entry point
"""

from importlib.metadata import entry_points

from domainlab.cli import domainlab_cli


def test_entry_point():
"""
Test the entry point for the 'domainlab' console script.

This function retrieves all entry points and asserts that the 'domainlab'
entry point is correctly associated with the 'domainlab_cli' function.
"""
eps = entry_points()
cli_entry = eps.select(group="console_scripts")["domainlab"]
assert cli_entry.load() is domainlab_cli


def test_domainlab_cli(monkeypatch):
"""
Test the 'domainlab_cli' function by simulating command-line arguments.

This function uses the 'monkeypatch' fixture to set the command-line
arguments for the 'domainlab_cli' function and then calls it to ensure
it processes the arguments correctly. The test arguments simulate a
representative command-line input for the 'domainlab' tool.
"""
test_args = [
"--te_d",
"1",
"2",
"--tr_d",
"0",
"3",
"--task",
"mnistcolor10",
"--epos",
"500",
"--bs",
"16",
"--model",
"erm",
"--nname",
"conv_bn_pool_2",
"--lr",
"1e-3",
]
monkeypatch.setattr("sys.argv", ["domainlab"] + test_args)
domainlab_cli()