diff --git a/pyproject.toml b/pyproject.toml index fe7ecdba..898eae5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "uvicorn", "sqlalchemy", "mysqlclient", + "python_dotenv", ] [project.optional-dependencies] diff --git a/src/config.py b/src/config.py new file mode 100644 index 00000000..ae425f96 --- /dev/null +++ b/src/config.py @@ -0,0 +1,45 @@ +import functools +import os +import tomllib +import typing +from pathlib import Path + +from dotenv import load_dotenv + +TomlTable = dict[str, typing.Any] + + +def _apply_defaults_to_siblings(configuration: TomlTable) -> TomlTable: + defaults = configuration["defaults"] + return { + subtable: (defaults | overrides) if isinstance(overrides, dict) else overrides + for subtable, overrides in configuration.items() + if subtable != "defaults" + } + + +@functools.cache +def load_database_configuration(file: Path = Path(__file__).parent / "config.toml") -> TomlTable: + configuration = tomllib.loads(file.read_text()) + + database_configuration = _apply_defaults_to_siblings( + configuration["databases"], + ) + load_dotenv() + database_configuration["openml"]["username"] = os.environ.get( + "OPENML_DATABASES_OPENML_USERNAME", + "root", + ) + database_configuration["openml"]["password"] = os.environ.get( + "OPENML_DATABASES_OPENML_PASSWORD", + "ok", + ) + database_configuration["expdb"]["username"] = os.environ.get( + "OPENML_DATABASES_EXPDB_USERNAME", + "root", + ) + database_configuration["expdb"]["password"] = os.environ.get( + "OPENML_DATABASES_EXPDB_PASSWORD", + "ok", + ) + return database_configuration diff --git a/src/config.toml b/src/config.toml new file mode 100644 index 00000000..f85d9a06 --- /dev/null +++ b/src/config.toml @@ -0,0 +1,11 @@ +[databases.defaults] +host="127.0.0.1" +port="3306" +# SQLAlchemy `dialect` and `driver`: https://docs.sqlalchemy.org/en/20/dialects/index.html +drivername="mysql" + +[databases.expdb] +database="openml_expdb" + +[databases.openml] +database="openml" diff --git a/src/database/datasets.py b/src/database/datasets.py index f794aeae..1e534396 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -1,17 +1,22 @@ """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707""" from typing import Any +from config import load_database_configuration from sqlalchemy import create_engine, text +from sqlalchemy.engine import URL from database.meta import get_column_names +_database_configuration = load_database_configuration() +expdb_url = URL.create(**_database_configuration["expdb"]) expdb = create_engine( - "mysql://root:ok@127.0.0.1:3306/openml_expdb", + expdb_url, echo=True, pool_recycle=3600, ) +openml_url = URL.create(**_database_configuration["openml"]) openml = create_engine( - "mysql://root:ok@127.0.0.1:3306/openml", + openml_url, echo=True, pool_recycle=3600, ) diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 00000000..412606d1 --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,35 @@ +import os +from pathlib import Path + +from config import _apply_defaults_to_siblings, load_database_configuration + + +def test_apply_defaults_to_siblings_applies_defaults() -> None: + input_ = {"defaults": {1: 1}, "other": {}} + expected = {"other": {1: 1}} + output = _apply_defaults_to_siblings(input_) + assert expected == output + + +def test_apply_defaults_to_siblings_does_not_override() -> None: + input_ = {"defaults": {1: 1}, "other": {1: 2}} + expected = {"other": {1: 2}} + output = _apply_defaults_to_siblings(input_) + assert expected == output + + +def test_apply_defaults_to_siblings_ignores_nontables() -> None: + input_ = {"defaults": {1: 1}, "other": {1: 2}, "not-a-table": 3} + expected = {"other": {1: 2}, "not-a-table": 3} + output = _apply_defaults_to_siblings(input_) + assert expected == output + + +def test_load_configuration_adds_environment_variables(default_configuration_file: Path) -> None: + database_configuration = load_database_configuration(default_configuration_file) + assert database_configuration["openml"]["username"] == "root" + + load_database_configuration.cache_clear() + os.environ["OPENML_DATABASES_OPENML_USERNAME"] = "foo" + database_configuration = load_database_configuration(default_configuration_file) + assert database_configuration["openml"]["username"] == "foo" diff --git a/tests/conftest.py b/tests/conftest.py index fead0652..45c68f1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,27 @@ import json -import pathlib +from pathlib import Path from typing import Any, Generator import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from main import app @pytest.fixture() def api_client() -> Generator[FastAPI, None, None]: + # We want to avoid starting a test client app if tests don't need it. + from main import app + return TestClient(app) @pytest.fixture() def dataset_130() -> Generator[dict[str, Any], None, None]: - json_path = pathlib.Path(__file__).parent / "resources" / "datasets" / "dataset_130.json" + json_path = Path(__file__).parent / "resources" / "datasets" / "dataset_130.json" with json_path.open("r") as dataset_file: yield json.load(dataset_file) + + +@pytest.fixture() +def default_configuration_file() -> Path: + return Path().parent.parent / "src" / "config.toml"