diff --git a/docs/changes/newsfragments/6942.breaking b/docs/changes/newsfragments/6942.breaking new file mode 100644 index 000000000000..b5a5bdcab12b --- /dev/null +++ b/docs/changes/newsfragments/6942.breaking @@ -0,0 +1,4 @@ +The QCoDeS dataset sqlite connection class `ConnectionPlus` has been deprecated and replaced with `AtomicConnection`. +Unlike `ConnectionPlus`, `AtomicConnection` is a direct subclass of `sqlite3.Connection`, which enables better type checking +and will allow QCoDeS to drop the dependency on `wrapt`. The function `make_connection_plus_from` is also deprecated, and +converting a plain sqlite3 connection into a QCoDeS-specific connection is no longer supported. diff --git a/pyproject.toml b/pyproject.toml index 1350cbac386d..381f608f3bb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -239,7 +239,9 @@ filterwarnings = [ 'ignore:open_binary is deprecated:DeprecationWarning', # pyvisa-sim 'ignore:Jupyter is migrating its paths to use standard platformdirs:DeprecationWarning', # jupyter 'ignore:Parsing dates involving a day of month without a year specified is ambiguious:DeprecationWarning', # ipykernel 3.13+ - 'ignore:unclosed database in:ResourceWarning' # internal should be fixed + 'ignore:unclosed database in:ResourceWarning', # internal should be fixed + 'ignore:ConnectionPlus is deprecated:qcodes.utils.deprecate.QCoDeSDeprecationWarning', # remove once deprecated ConnectionPlus has been removed + 'ignore:make_connection_plus_from is deprecated:qcodes.utils.deprecate.QCoDeSDeprecationWarning', # remove once deprecated ConnectionPlus has been removed ] [tool.ruff] diff --git a/src/qcodes/dataset/__init__.py b/src/qcodes/dataset/__init__.py index a46957714493..999d5ce4a6ac 100644 --- a/src/qcodes/dataset/__init__.py +++ b/src/qcodes/dataset/__init__.py @@ -44,7 +44,10 @@ ) from .measurements import Measurement from .plotting import plot_by_id, plot_dataset -from .sqlite.connection import ConnectionPlus +from .sqlite.connection import ( + AtomicConnection, + ConnectionPlus, # pyright: ignore[reportDeprecated] +) from .sqlite.database import ( connect, initialise_database, @@ -61,6 +64,7 @@ __all__ = [ "AbstractSweep", "ArraySweep", + "AtomicConnection", "BreakConditionInterrupt", "ConnectionPlus", "DataSetDefinition", diff --git a/src/qcodes/dataset/data_set.py b/src/qcodes/dataset/data_set.py index d67290a687ec..c5efd9772efd 100644 --- a/src/qcodes/dataset/data_set.py +++ b/src/qcodes/dataset/data_set.py @@ -35,7 +35,11 @@ ) from qcodes.dataset.guids import filter_guids_by_parts, generate_guid, parse_guid from qcodes.dataset.linked_datasets.links import Link, links_to_str, str_to_links -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic, atomic_transaction +from qcodes.dataset.sqlite.connection import ( + AtomicConnection, + atomic, + atomic_transaction, +) from qcodes.dataset.sqlite.database import ( conn_from_dbpath_or_conn, connect, @@ -128,7 +132,7 @@ class _BackgroundWriter(Thread): Write the results from the DataSet's dataqueue in a new thread """ - def __init__(self, queue: Queue[Any], conn: ConnectionPlus): + def __init__(self, queue: Queue[Any], conn: AtomicConnection): super().__init__(daemon=True) self.queue = queue self.path = conn.path_to_dbfile @@ -206,7 +210,7 @@ def __init__( self, path_to_db: str | None = None, run_id: int | None = None, - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, exp_id: int | None = None, name: str | None = None, specs: SpecsOrInterDeps | None = None, @@ -1556,7
+1560,7 @@ def load_by_run_spec( sample_id: int | None = None, location: int | None = None, work_station: int | None = None, - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, ) -> DataSetProtocol: """ Load a run from one or more pieces of runs specification. All @@ -1635,7 +1639,7 @@ def get_guids_by_run_spec( sample_id: int | None = None, location: int | None = None, work_station: int | None = None, - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, ) -> list[str]: """ Get a list of matching guids from one or more pieces of runs specification. All @@ -1677,7 +1681,7 @@ def get_guids_by_run_spec( return matched_guids -def load_by_id(run_id: int, conn: ConnectionPlus | None = None) -> DataSetProtocol: +def load_by_id(run_id: int, conn: AtomicConnection | None = None) -> DataSetProtocol: """ Load a dataset by run id @@ -1723,7 +1727,7 @@ def load_by_id(run_id: int, conn: ConnectionPlus | None = None) -> DataSetProtoc return d -def load_by_guid(guid: str, conn: ConnectionPlus | None = None) -> DataSetProtocol: +def load_by_guid(guid: str, conn: AtomicConnection | None = None) -> DataSetProtocol: """ Load a dataset by its GUID @@ -1763,7 +1767,7 @@ def load_by_guid(guid: str, conn: ConnectionPlus | None = None) -> DataSetProtoc def load_by_counter( - counter: int, exp_id: int, conn: ConnectionPlus | None = None + counter: int, exp_id: int, conn: AtomicConnection | None = None ) -> DataSetProtocol: """ Load a dataset given its counter in a given experiment @@ -1807,7 +1811,9 @@ def load_by_counter( return d -def _get_datasetprotocol_from_guid(guid: str, conn: ConnectionPlus) -> DataSetProtocol: +def _get_datasetprotocol_from_guid( + guid: str, conn: AtomicConnection +) -> DataSetProtocol: run_id = get_runid_from_guid(conn, guid) if run_id is None: raise NameError( @@ -1840,7 +1846,7 @@ def _get_datasetprotocol_from_guid(guid: str, conn: ConnectionPlus) -> DataSetPr return d -def _get_datasetprotocol_export_info(run_id: int, conn: ConnectionPlus) -> ExportInfo: +def _get_datasetprotocol_export_info(run_id: int, conn: AtomicConnection) -> ExportInfo: metadata = get_metadata_from_run_id(conn=conn, run_id=run_id) export_info_str = metadata.get("export_info", "") export_info = ExportInfo.from_str(export_info_str) @@ -1853,7 +1859,7 @@ def new_data_set( specs: SPECS | None = None, values: VALUES | None = None, metadata: Any | None = None, - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, in_memory_cache: bool = True, ) -> DataSet: """ @@ -1894,7 +1900,7 @@ def new_data_set( def generate_dataset_table( - guids: Sequence[str], conn: ConnectionPlus | None = None + guids: Sequence[str], conn: AtomicConnection | None = None ) -> str: """ Generate an ASCII art table of information about the runs attached to the @@ -1902,7 +1908,7 @@ def generate_dataset_table( Args: guids: Sequence of one or more guids - conn: A ConnectionPlus object with a connection to the database. + conn: An AtomicConnection object with a connection to the database. Returns: ASCII art table of information about the supplied guids. 
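As a reader aid (not part of the patch): with this change, connect() returns an AtomicConnection directly and the loader functions above accept it as their conn argument, so a raw sqlite3 connection no longer needs to be wrapped. A minimal sketch, assuming a made-up database path and run id:

import sqlite3

from qcodes.dataset import AtomicConnection, connect, load_by_id

# connect() now returns an AtomicConnection; no ConnectionPlus wrapping is needed.
conn = connect("experiments.db")  # hypothetical path
assert isinstance(conn, AtomicConnection)
assert isinstance(conn, sqlite3.Connection)  # direct subclass of sqlite3.Connection

# Loader functions such as load_by_id take the connection directly.
dataset = load_by_id(1, conn=conn)  # assumes a run with run_id 1 exists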
diff --git a/src/qcodes/dataset/data_set_cache.py b/src/qcodes/dataset/data_set_cache.py index b9948c1e4f9b..80d57ad4c0a6 100644 --- a/src/qcodes/dataset/data_set_cache.py +++ b/src/qcodes/dataset/data_set_cache.py @@ -25,7 +25,7 @@ import xarray as xr from qcodes.dataset.descriptions.rundescriber import RunDescriber - from qcodes.dataset.sqlite.connection import ConnectionPlus + from qcodes.dataset.sqlite.connection import AtomicConnection # used in forward refs that cannot be detected from .data_set import DataSet # noqa F401 @@ -201,7 +201,7 @@ def to_xarray_dataset( def load_new_data_from_db_and_append( - conn: ConnectionPlus, + conn: AtomicConnection, table_name: str, rundescriber: RunDescriber, write_status: Mapping[str, int | None], diff --git a/src/qcodes/dataset/data_set_in_memory.py b/src/qcodes/dataset/data_set_in_memory.py index 347cf86fd30c..bed5d709c0c8 100644 --- a/src/qcodes/dataset/data_set_in_memory.py +++ b/src/qcodes/dataset/data_set_in_memory.py @@ -23,7 +23,7 @@ from qcodes.dataset.export_config import DataExportType from qcodes.dataset.guids import generate_guid from qcodes.dataset.linked_datasets.links import Link, links_to_str -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic +from qcodes.dataset.sqlite.connection import AtomicConnection, atomic from qcodes.dataset.sqlite.database import conn_from_dbpath_or_conn from qcodes.dataset.sqlite.queries import ( RUNS_TABLE_COLUMNS, @@ -302,7 +302,7 @@ def _load_from_netcdf( return ds @classmethod - def _load_from_db(cls, conn: ConnectionPlus, guid: str) -> DataSetInMem: + def _load_from_db(cls, conn: AtomicConnection, guid: str) -> DataSetInMem: run_attributes = get_raw_run_attributes(conn, guid) if run_attributes is None: raise RuntimeError( diff --git a/src/qcodes/dataset/data_set_info.py b/src/qcodes/dataset/data_set_info.py index 3483b4631d7a..b79faffef9a5 100644 --- a/src/qcodes/dataset/data_set_info.py +++ b/src/qcodes/dataset/data_set_info.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from qcodes.dataset.descriptions.rundescriber import RunDescriber - from qcodes.dataset.sqlite.connection import ConnectionPlus + from qcodes.dataset.sqlite.connection import AtomicConnection class RunAttributesDict(TypedDict): @@ -34,7 +34,7 @@ class RunAttributesDict(TypedDict): snapshot: dict[str, Any] | None -def get_run_attributes(conn: ConnectionPlus, guid: str) -> RunAttributesDict | None: +def get_run_attributes(conn: AtomicConnection, guid: str) -> RunAttributesDict | None: """ Look up all information and metadata about a given dataset captured in the database. 
diff --git a/src/qcodes/dataset/database_extract_runs.py b/src/qcodes/dataset/database_extract_runs.py index a2a39f9ac4e6..045940185ecf 100644 --- a/src/qcodes/dataset/database_extract_runs.py +++ b/src/qcodes/dataset/database_extract_runs.py @@ -9,7 +9,7 @@ from qcodes.dataset.data_set import DataSet from qcodes.dataset.dataset_helpers import _add_run_to_runs_table from qcodes.dataset.experiment_container import _create_exp_if_needed -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic +from qcodes.dataset.sqlite.connection import AtomicConnection, atomic from qcodes.dataset.sqlite.database import ( connect, get_db_version_and_newest_available_version, @@ -128,7 +128,7 @@ def extract_runs_into_db( def _extract_single_dataset_into_db( - dataset: DataSet, target_conn: ConnectionPlus, target_exp_id: int + dataset: DataSet, target_conn: AtomicConnection, target_exp_id: int ) -> None: """ NB: This function should only be called from within diff --git a/src/qcodes/dataset/database_fix_functions.py b/src/qcodes/dataset/database_fix_functions.py index 404f977e35a2..bbcbb63750c4 100644 --- a/src/qcodes/dataset/database_fix_functions.py +++ b/src/qcodes/dataset/database_fix_functions.py @@ -15,7 +15,11 @@ from qcodes.dataset.descriptions.rundescriber import RunDescriber from qcodes.dataset.descriptions.versioning import v0 from qcodes.dataset.descriptions.versioning.converters import old_to_new -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic, atomic_transaction +from qcodes.dataset.sqlite.connection import ( + AtomicConnection, + atomic, + atomic_transaction, +) from qcodes.dataset.sqlite.db_upgrades.version import get_user_version from qcodes.dataset.sqlite.queries import ( _get_parameters, @@ -35,7 +39,7 @@ log = logging.getLogger(__name__) -def fix_version_4a_run_description_bug(conn: ConnectionPlus) -> dict[str, int]: +def fix_version_4a_run_description_bug(conn: AtomicConnection) -> dict[str, int]: """ Fix function to fix a bug where the RunDescriber accidentally wrote itself to string using the (new) InterDependencies_ object instead of the (old) @@ -124,7 +128,7 @@ def _convert_run_describer_v1_like_dict_to_v0_like_dict( return old_desc_dict -def fix_wrong_run_descriptions(conn: ConnectionPlus, run_ids: Sequence[int]) -> None: +def fix_wrong_run_descriptions(conn: AtomicConnection, run_ids: Sequence[int]) -> None: """ NB: This is a FIX function. Do not use it unless your database has been diagnosed with the problem that this function fixes. 
diff --git a/src/qcodes/dataset/dataset_helpers.py b/src/qcodes/dataset/dataset_helpers.py index 9388abdf8188..5cb11a13f0d4 100644 --- a/src/qcodes/dataset/dataset_helpers.py +++ b/src/qcodes/dataset/dataset_helpers.py @@ -11,12 +11,12 @@ if TYPE_CHECKING: from qcodes.dataset.data_set_protocol import DataSetProtocol - from qcodes.dataset.sqlite.connection import ConnectionPlus + from qcodes.dataset.sqlite.connection import AtomicConnection def _add_run_to_runs_table( dataset: DataSetProtocol, - target_conn: ConnectionPlus, + target_conn: AtomicConnection, target_exp_id: int, create_run_table: bool = True, ) -> tuple[int, int, str | None]: diff --git a/src/qcodes/dataset/experiment_container.py b/src/qcodes/dataset/experiment_container.py index 939e44fb3633..ad1f658e4a3c 100644 --- a/src/qcodes/dataset/experiment_container.py +++ b/src/qcodes/dataset/experiment_container.py @@ -7,7 +7,7 @@ from qcodes.dataset.data_set import DataSet, load_by_id, new_data_set from qcodes.dataset.experiment_settings import _set_default_experiment_id -from qcodes.dataset.sqlite.connection import ConnectionPlus, path_to_dbfile +from qcodes.dataset.sqlite.connection import AtomicConnection, path_to_dbfile from qcodes.dataset.sqlite.database import ( conn_from_dbpath_or_conn, connect, @@ -42,7 +42,7 @@ def __init__( name: str | None = None, sample_name: str | None = None, format_string: str = "{}-{}-{}", - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, ) -> None: """ Create or load an experiment. If exp_id is None, a new experiment is @@ -207,7 +207,7 @@ def __repr__(self) -> str: # public api -def experiments(conn: ConnectionPlus | None = None) -> list[Experiment]: +def experiments(conn: AtomicConnection | None = None) -> list[Experiment]: """ List all the experiments in the container (database file from config) @@ -228,7 +228,7 @@ def new_experiment( name: str, sample_name: str | None, format_string: str = "{}-{}-{}", - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, ) -> Experiment: """ Create a new experiment (in the database file from config) @@ -259,7 +259,7 @@ def new_experiment( return experiment -def load_experiment(exp_id: int, conn: ConnectionPlus | None = None) -> Experiment: +def load_experiment(exp_id: int, conn: AtomicConnection | None = None) -> Experiment: """ Load experiment with the specified id (from database file from config) @@ -304,7 +304,7 @@ def load_last_experiment() -> Experiment: def load_experiment_by_name( name: str, sample: str | None = None, - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, load_last_duplicate: bool = False, ) -> Experiment: """ @@ -365,7 +365,7 @@ def load_experiment_by_name( def load_or_create_experiment( experiment_name: str, sample_name: str | None = None, - conn: ConnectionPlus | None = None, + conn: AtomicConnection | None = None, load_last_duplicate: bool = False, ) -> Experiment: """ @@ -405,7 +405,7 @@ def load_or_create_experiment( def _create_exp_if_needed( - target_conn: ConnectionPlus, + target_conn: AtomicConnection, exp_name: str, sample_name: str, fmt_str: str, diff --git a/src/qcodes/dataset/experiment_settings.py b/src/qcodes/dataset/experiment_settings.py index 00da1c178b83..d1b6c1cb8d06 100644 --- a/src/qcodes/dataset/experiment_settings.py +++ b/src/qcodes/dataset/experiment_settings.py @@ -2,7 +2,7 @@ from __future__ import annotations -from qcodes.dataset.sqlite.connection import ConnectionPlus, path_to_dbfile +from qcodes.dataset.sqlite.connection import 
AtomicConnection, path_to_dbfile from qcodes.dataset.sqlite.queries import get_last_experiment _default_experiment: dict[str, int | None] = {} @@ -38,7 +38,7 @@ def _get_latest_default_experiment_id(db_path: str) -> int | None: return _default_experiment.get(db_path, None) -def reset_default_experiment_id(conn: ConnectionPlus | None = None) -> None: +def reset_default_experiment_id(conn: AtomicConnection | None = None) -> None: """ Resets the default experiment id to to the last experiment in the db. """ @@ -50,7 +50,7 @@ def reset_default_experiment_id(conn: ConnectionPlus | None = None) -> None: _default_experiment[db_path] = None -def get_default_experiment_id(conn: ConnectionPlus) -> int: +def get_default_experiment_id(conn: AtomicConnection) -> int: """ Returns the latest created/ loaded experiment's exp_id as the default experiment. If it is not set the maximum exp_id returned as the default. diff --git a/src/qcodes/dataset/measurements.py b/src/qcodes/dataset/measurements.py index e04a0cceed20..9d0c24effe01 100644 --- a/src/qcodes/dataset/measurements.py +++ b/src/qcodes/dataset/measurements.py @@ -59,7 +59,7 @@ from qcodes.dataset.descriptions.versioning.rundescribertypes import Shapes from qcodes.dataset.experiment_container import Experiment - from qcodes.dataset.sqlite.connection import ConnectionPlus + from qcodes.dataset.sqlite.connection import AtomicConnection from qcodes.dataset.sqlite.query_helpers import VALUE log = logging.getLogger(__name__) @@ -631,7 +631,7 @@ def __enter__(self) -> DataSaver: if self.experiment is not None: exp_id: int | None = self.experiment.exp_id path_to_db: str | None = self.experiment.path_to_db - conn: ConnectionPlus | None = self.experiment.conn + conn: AtomicConnection | None = self.experiment.conn else: exp_id = None path_to_db = None diff --git a/src/qcodes/dataset/sqlite/connection.py b/src/qcodes/dataset/sqlite/connection.py index c469c806d477..2e82099dbfaa 100644 --- a/src/qcodes/dataset/sqlite/connection.py +++ b/src/qcodes/dataset/sqlite/connection.py @@ -7,22 +7,29 @@ from __future__ import annotations import logging +import sqlite3 from contextlib import contextmanager from typing import TYPE_CHECKING, Any import wrapt # type: ignore[import-untyped] +from typing_extensions import deprecated -from qcodes.utils import DelayedKeyboardInterrupt +from qcodes.utils import DelayedKeyboardInterrupt, QCoDeSDeprecationWarning if TYPE_CHECKING: - import sqlite3 from collections.abc import Iterator log = logging.getLogger(__name__) +@deprecated( + "ConnectionPlus is deprecated. Please use connect to create an AtomicConnection.", + category=QCoDeSDeprecationWarning, +) class ConnectionPlus(wrapt.ObjectProxy): # pyright: ignore[reportUntypedBaseClass] """ + Note this is a legacy class. Please refer to :class:`AtomicConnection` + A class to extend the sqlite3.Connection object. Since sqlite3.Connection has no __dict__, we can not directly add attributes to its instance directly. @@ -48,7 +55,7 @@ class ConnectionPlus(wrapt.ObjectProxy): # pyright: ignore[reportUntypedBaseCla def __init__(self, sqlite3_connection: sqlite3.Connection): super().__init__(sqlite3_connection) - if isinstance(sqlite3_connection, ConnectionPlus): + if isinstance(sqlite3_connection, ConnectionPlus): # pyright: ignore[reportDeprecated] raise ValueError( "Attempted to create `ConnectionPlus` from a " "`ConnectionPlus` object which is not allowed." 
@@ -57,9 +64,39 @@ def __init__(self, sqlite3_connection: sqlite3.Connection): self.path_to_dbfile = path_to_dbfile(sqlite3_connection) +class AtomicConnection(sqlite3.Connection): + """ + A class to extend the sqlite3.Connection object. This extends + Connection to allow operations to be performed atomically. + + It is recommended to create an AtomicConnection using the function :func:`connect`. + + """ + + atomic_in_progress: bool = False + """ + A bool describing whether the connection is + currently in the middle of an atomic block of transactions, thus + allowing `atomic` context managers to be nested. + """ + path_to_dbfile: str = "" + """ + Path to the database file of the connection. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self.path_to_dbfile = path_to_dbfile(self) + + +@deprecated( + "make_connection_plus_from is deprecated. Please use connect to create an AtomicConnection", + category=QCoDeSDeprecationWarning, +) def make_connection_plus_from( - conn: sqlite3.Connection | ConnectionPlus, -) -> ConnectionPlus: + conn: sqlite3.Connection | ConnectionPlus, # pyright: ignore[reportDeprecated] +) -> ConnectionPlus: # pyright: ignore[reportDeprecated] """ Makes a ConnectionPlus connection object out of a given argument. @@ -73,15 +110,15 @@ def make_connection_plus_from( the "same" connection but as ConnectionPlus object """ - if not isinstance(conn, ConnectionPlus): - conn_plus = ConnectionPlus(conn) + if not isinstance(conn, ConnectionPlus): # pyright: ignore[reportDeprecated] + conn_plus = ConnectionPlus(conn) # pyright: ignore[reportDeprecated] else: conn_plus = conn return conn_plus @contextmanager -def atomic(conn: ConnectionPlus) -> Iterator[ConnectionPlus]: +def atomic(conn: AtomicConnection) -> Iterator[AtomicConnection]: """ Guard a series of transactions as atomic. @@ -97,10 +134,10 @@ def atomic(conn: ConnectionPlus) -> Iterator[ConnectionPlus]: """ with DelayedKeyboardInterrupt(context={"reason": "sqlite atomic operation"}): - if not isinstance(conn, ConnectionPlus): + if not isinstance(conn, ConnectionPlus | AtomicConnection): # pyright: ignore[reportDeprecated] raise ValueError( "atomic context manager only accepts " - "ConnectionPlus database connection objects." + "AtomicConnection or ConnectionPlus database connection objects." ) is_outmost = not (conn.atomic_in_progress) @@ -135,7 +172,7 @@ def atomic(conn: ConnectionPlus) -> Iterator[ConnectionPlus]: conn.atomic_in_progress = old_atomic_in_progress -def transaction(conn: ConnectionPlus, sql: str, *args: Any) -> sqlite3.Cursor: +def transaction(conn: AtomicConnection, sql: str, *args: Any) -> sqlite3.Cursor: """Perform a transaction. The transaction needs to be committed or rolled back. @@ -157,7 +194,7 @@ def transaction(conn: ConnectionPlus, sql: str, *args: Any) -> sqlite3.Cursor: return c -def atomic_transaction(conn: ConnectionPlus, sql: str, *args: Any) -> sqlite3.Cursor: +def atomic_transaction(conn: AtomicConnection, sql: str, *args: Any) -> sqlite3.Cursor: """Perform an **atomic** transaction. The transaction is committed if there are no exceptions else the transaction is rolled back.
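To show how the re-typed helpers in connection.py compose, here is an illustrative sketch only; the database path and the UPDATE statements are assumptions, while the SELECT query appears in this patch:

from qcodes.dataset.sqlite.connection import atomic, atomic_transaction, transaction
from qcodes.dataset.sqlite.database import connect

conn = connect("experiments.db")  # AtomicConnection, hypothetical path

# atomic() groups several transaction() calls into one commit-or-rollback unit.
with atomic(conn) as atomic_conn:
    transaction(atomic_conn, "UPDATE runs SET is_completed = ? WHERE run_id = ?", True, 1)
    transaction(atomic_conn, "UPDATE runs SET completed_timestamp = ? WHERE run_id = ?", 0.0, 1)

# atomic_transaction() wraps a single statement in its own atomic block.
cursor = atomic_transaction(conn, "SELECT max(run_id) FROM runs")
print(cursor.fetchall())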
@@ -179,7 +216,7 @@ def atomic_transaction(conn: ConnectionPlus, sql: str, *args: Any) -> sqlite3.Cu return c -def path_to_dbfile(conn: ConnectionPlus | sqlite3.Connection) -> str: +def path_to_dbfile(conn: AtomicConnection | sqlite3.Connection) -> str: """ Return the path of the database file that the conn object is connected to """ diff --git a/src/qcodes/dataset/sqlite/database.py b/src/qcodes/dataset/sqlite/database.py index a9644ea383ea..b1ea7fc761ee 100644 --- a/src/qcodes/dataset/sqlite/database.py +++ b/src/qcodes/dataset/sqlite/database.py @@ -18,7 +18,7 @@ import qcodes from qcodes.dataset.experiment_settings import reset_default_experiment_id -from qcodes.dataset.sqlite.connection import ConnectionPlus +from qcodes.dataset.sqlite.connection import AtomicConnection from qcodes.dataset.sqlite.db_upgrades import ( _latest_available_version, perform_db_upgrade, @@ -119,7 +119,9 @@ def _adapt_complex(value: complex | np.complexfloating) -> sqlite3.Binary: return sqlite3.Binary(out.read()) -def connect(name: str | Path, debug: bool = False, version: int = -1) -> ConnectionPlus: +def connect( + name: str | Path, debug: bool = False, version: int = -1 +) -> AtomicConnection: """ Connect or create database. If debug the queries will be echoed back. This function takes care of registering the numpy/sqlite type @@ -133,7 +135,7 @@ def connect(name: str | Path, debug: bool = False, version: int = -1) -> Connect Returns: connection object to the database (note, it is - :class:`ConnectionPlus`, not :class:`sqlite3.Connection`) + :class:`AtomicConnection`, which is a subclass of :class:`sqlite3.Connection`) """ # register numpy->binary(TEXT) adapter @@ -141,10 +143,12 @@ def connect(name: str | Path, debug: bool = False, version: int = -1) -> Connect # register binary(TEXT) -> numpy converter sqlite3.register_converter("array", _convert_array) - sqlite3_conn = sqlite3.connect( - name, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=True + conn = sqlite3.connect( + name, + detect_types=sqlite3.PARSE_DECLTYPES, + check_same_thread=True, + factory=AtomicConnection, ) - conn = ConnectionPlus(sqlite3_conn) latest_supported_version = _latest_available_version() db_version = get_user_version(conn) @@ -232,7 +236,7 @@ def initialise_database(journal_mode: JournalMode | None = "WAL") -> None: del conn -def set_journal_mode(conn: ConnectionPlus, journal_mode: JournalMode) -> None: +def set_journal_mode(conn: AtomicConnection, journal_mode: JournalMode) -> None: """ Set the ``atomic commit and rollback mode`` of the sqlite database. See https://www.sqlite.org/pragma.html#pragma_journal_mode for details. @@ -291,20 +295,20 @@ def initialised_database_at(db_file_with_abs_path: str | Path) -> Iterator[None] def conn_from_dbpath_or_conn( - conn: ConnectionPlus | None, path_to_db: str | Path | None -) -> ConnectionPlus: + conn: AtomicConnection | None, path_to_db: str | Path | None +) -> AtomicConnection: """ A small helper function to abstract the logic needed for functions - that take either a `ConnectionPlus` or the path to a db file. + that take either an `AtomicConnection` or the path to a db file. If neither is given this will fall back to the default db location. It is an error to supply both. Args: - conn: A ConnectionPlus object pointing to a sqlite database + conn: An AtomicConnection object pointing to a sqlite database path_to_db: The path to a db file.
Returns: - A `ConnectionPlus` object + An `AtomicConnection` object """ diff --git a/src/qcodes/dataset/sqlite/db_upgrades/__init__.py b/src/qcodes/dataset/sqlite/db_upgrades/__init__.py index ce831c21cbd5..3c9000f07377 100644 --- a/src/qcodes/dataset/sqlite/db_upgrades/__init__.py +++ b/src/qcodes/dataset/sqlite/db_upgrades/__init__.py @@ -25,7 +25,7 @@ from qcodes.dataset.guids import generate_guid from qcodes.dataset.sqlite.connection import ( - ConnectionPlus, + AtomicConnection, atomic, atomic_transaction, transaction, @@ -41,7 +41,7 @@ # see https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols class TUpgraderFunction(Protocol): def __call__( - self, conn: ConnectionPlus, show_progress_bar: bool = True + self, conn: AtomicConnection, show_progress_bar: bool = True ) -> None: ... @property @@ -60,7 +60,7 @@ def _latest_available_version() -> int: return len(_UPGRADE_ACTIONS) -def _get_no_of_runs(conn: ConnectionPlus) -> int: +def _get_no_of_runs(conn: AtomicConnection) -> int: no_of_runs_query = "SELECT max(run_id) FROM runs" no_of_runs = one(atomic_transaction(conn, no_of_runs_query), "max(run_id)") no_of_runs = no_of_runs or 0 @@ -72,7 +72,7 @@ def upgrader(func: TUpgraderFunction) -> TUpgraderFunction: Decorator for database version upgrade functions. An upgrade function must have the name `perform_db_upgrade_N_to_M` where N = M-1. For simplicity, an upgrade function must take a single argument of type - `ConnectionPlus`. The upgrade function must either perform the upgrade + `AtomicConnection`. The upgrade function must either perform the upgrade and return (no return values allowed) or fail to perform the upgrade, in which case it must raise a RuntimeError. A failed upgrade must be completely rolled back before the RuntimeError is raises. @@ -104,7 +104,7 @@ def upgrader(func: TUpgraderFunction) -> TUpgraderFunction: ) @wraps(func) - def do_upgrade(conn: ConnectionPlus, show_progress_bar: bool = True) -> None: + def do_upgrade(conn: AtomicConnection, show_progress_bar: bool = True) -> None: log.info(f"Starting database upgrade version {from_version} to {to_version}") start_version = get_user_version(conn) @@ -126,7 +126,7 @@ def do_upgrade(conn: ConnectionPlus, show_progress_bar: bool = True) -> None: return do_upgrade -def perform_db_upgrade(conn: ConnectionPlus, version: int = -1) -> None: +def perform_db_upgrade(conn: AtomicConnection, version: int = -1) -> None: """ This is intended to perform all upgrades as needed to bring the db from version 0 to the most current version (or the version specified).
@@ -156,7 +156,7 @@ def perform_db_upgrade(conn: ConnectionPlus, version: int = -1) -> None: @upgrader def perform_db_upgrade_0_to_1( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 0 to version 1 @@ -205,7 +205,7 @@ def perform_db_upgrade_0_to_1( @upgrader def perform_db_upgrade_1_to_2( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 1 to version 2 @@ -243,7 +243,7 @@ def perform_db_upgrade_1_to_2( @upgrader def perform_db_upgrade_2_to_3( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 2 to version 3 @@ -260,7 +260,7 @@ def perform_db_upgrade_2_to_3( @upgrader def perform_db_upgrade_3_to_4( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 3 to version 4. This really @@ -278,7 +278,7 @@ def perform_db_upgrade_3_to_4( @upgrader def perform_db_upgrade_4_to_5( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 4 to version 5. @@ -299,7 +299,7 @@ def perform_db_upgrade_4_to_5( @upgrader def perform_db_upgrade_5_to_6( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 5 to version 6. @@ -315,7 +315,7 @@ def perform_db_upgrade_5_to_6( @upgrader def perform_db_upgrade_6_to_7( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 6 to version 7 @@ -352,7 +352,7 @@ def perform_db_upgrade_6_to_7( @upgrader def perform_db_upgrade_7_to_8( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 7 to version 8. @@ -370,7 +370,7 @@ def perform_db_upgrade_7_to_8( @upgrader def perform_db_upgrade_8_to_9( - conn: ConnectionPlus, show_progress_bar: bool = True + conn: AtomicConnection, show_progress_bar: bool = True ) -> None: """ Perform the upgrade from version 8 to version 9. 
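To make the upgrader contract documented above concrete, here is a hypothetical upgrade function of the required shape; the version numbers, column name, and SQL are placeholders, and this function is neither part of the patch nor registered in the upgrade chain:

from qcodes.dataset.sqlite.connection import AtomicConnection, atomic, transaction
from qcodes.dataset.sqlite.db_upgrades import upgrader


@upgrader
def perform_db_upgrade_9_to_10(
    conn: AtomicConnection, show_progress_bar: bool = True
) -> None:
    """
    Perform the upgrade from version 9 to version 10 (placeholder example).
    """
    with atomic(conn) as atomic_conn:
        # A real upgrade would modify the schema here; on failure it must roll
        # back completely and raise a RuntimeError.
        transaction(atomic_conn, "ALTER TABLE runs ADD COLUMN example_column")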
diff --git a/src/qcodes/dataset/sqlite/db_upgrades/upgrade_2_to_3.py b/src/qcodes/dataset/sqlite/db_upgrades/upgrade_2_to_3.py index c3caf4fed984..e51e76e85943 100644 --- a/src/qcodes/dataset/sqlite/db_upgrades/upgrade_2_to_3.py +++ b/src/qcodes/dataset/sqlite/db_upgrades/upgrade_2_to_3.py @@ -11,7 +11,7 @@ from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.descriptions.versioning.v0 import InterDependencies from qcodes.dataset.sqlite.connection import ( - ConnectionPlus, + AtomicConnection, atomic, atomic_transaction, transaction, @@ -24,7 +24,7 @@ log = logging.getLogger(__name__) -def _2to3_get_result_tables(conn: ConnectionPlus) -> dict[int, str]: +def _2to3_get_result_tables(conn: AtomicConnection) -> dict[int, str]: rst_query = "SELECT run_id, result_table_name FROM runs" cur = conn.cursor() cur.execute(rst_query) @@ -37,7 +37,7 @@ def _2to3_get_result_tables(conn: ConnectionPlus) -> dict[int, str]: return results -def _2to3_get_layout_ids(conn: ConnectionPlus) -> defaultdict[int, list[int]]: +def _2to3_get_layout_ids(conn: AtomicConnection) -> defaultdict[int, list[int]]: query = """ select runs.run_id, layouts.layout_id FROM layouts @@ -56,7 +56,7 @@ def _2to3_get_layout_ids(conn: ConnectionPlus) -> defaultdict[int, list[int]]: return results -def _2to3_get_indeps(conn: ConnectionPlus) -> defaultdict[int, list[int]]: +def _2to3_get_indeps(conn: AtomicConnection) -> defaultdict[int, list[int]]: query = """ SELECT layouts.run_id, layouts.layout_id FROM layouts @@ -75,7 +75,7 @@ def _2to3_get_indeps(conn: ConnectionPlus) -> defaultdict[int, list[int]]: return results -def _2to3_get_deps(conn: ConnectionPlus) -> defaultdict[int, list[int]]: +def _2to3_get_deps(conn: AtomicConnection) -> defaultdict[int, list[int]]: query = """ SELECT layouts.run_id, layouts.layout_id FROM layouts @@ -94,7 +94,7 @@ def _2to3_get_deps(conn: ConnectionPlus) -> defaultdict[int, list[int]]: return results -def _2to3_get_dependencies(conn: ConnectionPlus) -> defaultdict[int, list[int]]: +def _2to3_get_dependencies(conn: AtomicConnection) -> defaultdict[int, list[int]]: query = """ SELECT dependent, independent FROM dependencies @@ -115,7 +115,7 @@ def _2to3_get_dependencies(conn: ConnectionPlus) -> defaultdict[int, list[int]]: return results -def _2to3_get_layouts(conn: ConnectionPlus) -> dict[int, tuple[str, str, str, str]]: +def _2to3_get_layouts(conn: AtomicConnection) -> dict[int, tuple[str, str, str, str]]: query = """ SELECT layout_id, parameter, label, unit, inferred_from FROM layouts @@ -130,7 +130,7 @@ def _2to3_get_layouts(conn: ConnectionPlus) -> dict[int, tuple[str, str, str, st def _2to3_get_paramspecs( - conn: ConnectionPlus, + conn: AtomicConnection, layout_ids: list[int], layouts: Mapping[int, tuple[str, str, str, str]], dependencies: Mapping[int, Sequence[int]], @@ -186,7 +186,7 @@ def _2to3_get_paramspecs( return paramspecs -def upgrade_2_to_3(conn: ConnectionPlus, show_progress_bar: bool = True) -> None: +def upgrade_2_to_3(conn: AtomicConnection, show_progress_bar: bool = True) -> None: """ Perform the upgrade from version 2 to version 3 diff --git a/src/qcodes/dataset/sqlite/db_upgrades/upgrade_3_to_4.py b/src/qcodes/dataset/sqlite/db_upgrades/upgrade_3_to_4.py index 90d3f13cd58e..7ec70e299e6a 100644 --- a/src/qcodes/dataset/sqlite/db_upgrades/upgrade_3_to_4.py +++ b/src/qcodes/dataset/sqlite/db_upgrades/upgrade_3_to_4.py @@ -7,7 +7,11 @@ from tqdm import tqdm from qcodes.dataset.descriptions.versioning.v0 import InterDependencies -from 
qcodes.dataset.sqlite.connection import ConnectionPlus, atomic, atomic_transaction +from qcodes.dataset.sqlite.connection import ( + AtomicConnection, + atomic, + atomic_transaction, +) from qcodes.dataset.sqlite.db_upgrades.upgrade_2_to_3 import ( _2to3_get_dependencies, _2to3_get_deps, @@ -22,7 +26,7 @@ log = logging.getLogger(__name__) -def upgrade_3_to_4(conn: ConnectionPlus, show_progress_bar: bool = True) -> None: +def upgrade_3_to_4(conn: AtomicConnection, show_progress_bar: bool = True) -> None: """ Perform the upgrade from version 3 to version 4. This really repeats the version 3 upgrade as it originally had two bugs in diff --git a/src/qcodes/dataset/sqlite/db_upgrades/upgrade_5_to_6.py b/src/qcodes/dataset/sqlite/db_upgrades/upgrade_5_to_6.py index 9153d7b80bf7..abcd50fbacbd 100644 --- a/src/qcodes/dataset/sqlite/db_upgrades/upgrade_5_to_6.py +++ b/src/qcodes/dataset/sqlite/db_upgrades/upgrade_5_to_6.py @@ -6,12 +6,16 @@ from tqdm import tqdm from qcodes.dataset.descriptions.versioning.v0 import InterDependencies -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic, atomic_transaction +from qcodes.dataset.sqlite.connection import ( + AtomicConnection, + atomic, + atomic_transaction, +) from qcodes.dataset.sqlite.queries import get_run_description, update_run_description from qcodes.dataset.sqlite.query_helpers import one -def upgrade_5_to_6(conn: ConnectionPlus, show_progress_bar: bool = True) -> None: +def upgrade_5_to_6(conn: AtomicConnection, show_progress_bar: bool = True) -> None: """ Perform the upgrade from version 5 to version 6. diff --git a/src/qcodes/dataset/sqlite/db_upgrades/version.py b/src/qcodes/dataset/sqlite/db_upgrades/version.py index b8bb0acafcbc..566a67f5da98 100644 --- a/src/qcodes/dataset/sqlite/db_upgrades/version.py +++ b/src/qcodes/dataset/sqlite/db_upgrades/version.py @@ -1,14 +1,14 @@ from __future__ import annotations -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic_transaction +from qcodes.dataset.sqlite.connection import AtomicConnection, atomic_transaction from qcodes.dataset.sqlite.query_helpers import one -def get_user_version(conn: ConnectionPlus) -> int: +def get_user_version(conn: AtomicConnection) -> int: curr = atomic_transaction(conn, "PRAGMA user_version") res = one(curr, 0) return res -def set_user_version(conn: ConnectionPlus, version: int) -> None: +def set_user_version(conn: AtomicConnection, version: int) -> None: atomic_transaction(conn, f"PRAGMA user_version({version})") diff --git a/src/qcodes/dataset/sqlite/initial_schema.py b/src/qcodes/dataset/sqlite/initial_schema.py index b27a9ef43319..7f93b694f1a8 100644 --- a/src/qcodes/dataset/sqlite/initial_schema.py +++ b/src/qcodes/dataset/sqlite/initial_schema.py @@ -6,10 +6,10 @@ from __future__ import annotations -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic, transaction +from qcodes.dataset.sqlite.connection import AtomicConnection, atomic, transaction -def init_db(conn: ConnectionPlus) -> None: +def init_db(conn: AtomicConnection) -> None: with atomic(conn) as atomic_conn: transaction(atomic_conn, _experiment_table_schema) transaction(atomic_conn, _runs_table_schema) diff --git a/src/qcodes/dataset/sqlite/queries.py b/src/qcodes/dataset/sqlite/queries.py index 831e4cd278ea..c992461c875f 100644 --- a/src/qcodes/dataset/sqlite/queries.py +++ b/src/qcodes/dataset/sqlite/queries.py @@ -27,7 +27,7 @@ from qcodes.dataset.descriptions.versioning.converters import new_to_old, old_to_new from qcodes.dataset.guids import 
build_guid_from_components, parse_guid from qcodes.dataset.sqlite.connection import ( - ConnectionPlus, + AtomicConnection, atomic, atomic_transaction, transaction, @@ -89,7 +89,7 @@ ) -def is_run_id_in_database(conn: ConnectionPlus, *run_ids: int) -> dict[int, bool]: +def is_run_id_in_database(conn: AtomicConnection, *run_ids: int) -> dict[int, bool]: """ Look up run_ids and return a dictionary with the answers to the question "is this run_id in the database?" @@ -120,7 +120,7 @@ def is_run_id_in_database(conn: ConnectionPlus, *run_ids: int) -> dict[int, bool def get_parameter_data( - conn: ConnectionPlus, + conn: AtomicConnection, table_name: str, columns: Sequence[str] = (), start: int | None = None, @@ -170,7 +170,7 @@ def get_parameter_data( def get_shaped_parameter_data_for_one_paramtree( - conn: ConnectionPlus, + conn: AtomicConnection, table_name: str, rundescriber: RunDescriber, output_param: str, @@ -214,7 +214,7 @@ def get_shaped_parameter_data_for_one_paramtree( def get_rundescriber_from_result_table_name( - conn: ConnectionPlus, result_table_name: str + conn: AtomicConnection, result_table_name: str ) -> RunDescriber: sql = """ SELECT run_id FROM runs WHERE result_table_name = ? @@ -226,7 +226,7 @@ def get_rundescriber_from_result_table_name( def get_parameter_data_for_one_paramtree( - conn: ConnectionPlus, + conn: AtomicConnection, table_name: str, rundescriber: RunDescriber, output_param: str, @@ -328,7 +328,7 @@ def _expand_data_to_arrays( def _get_data_for_one_param_tree( - conn: ConnectionPlus, + conn: AtomicConnection, table_name: str, interdeps: InterDependencies_, output_param: str, @@ -355,7 +355,9 @@ def _get_data_for_one_param_tree( return res, paramspecs, n_rows -def get_parameter_db_row(conn: ConnectionPlus, table_name: str, param_name: str) -> int: +def get_parameter_db_row( + conn: AtomicConnection, table_name: str, param_name: str +) -> int: """ Get the total number of not-null values of a parameter @@ -377,7 +379,7 @@ def get_parameter_db_row(conn: ConnectionPlus, table_name: str, param_name: str) return one(c, 0) -def get_table_max_id(conn: ConnectionPlus, table_name: str) -> int: +def get_table_max_id(conn: AtomicConnection, table_name: str) -> int: """ Get the max id of a table @@ -399,7 +401,7 @@ def get_table_max_id(conn: ConnectionPlus, table_name: str) -> int: def _get_offset_limit_for_callback( - conn: ConnectionPlus, table_name: str, param_name: str + conn: AtomicConnection, table_name: str, param_name: str ) -> tuple[np.ndarray, np.ndarray]: """ Since sqlite3 does not allow to keep track of the data loading progress, @@ -450,7 +452,7 @@ def _get_offset_limit_for_callback( def get_parameter_tree_values( - conn: ConnectionPlus, + conn: AtomicConnection, result_table_name: str, toplevel_param_name: str, *other_param_names: str, @@ -545,7 +547,7 @@ def get_parameter_tree_values( def get_runid_from_expid_and_counter( - conn: ConnectionPlus, exp_id: int, counter: int + conn: AtomicConnection, exp_id: int, counter: int ) -> int: """ Get the run_id of a run in the specified experiment with the specified @@ -569,7 +571,7 @@ def get_runid_from_expid_and_counter( def get_guid_from_expid_and_counter( - conn: ConnectionPlus, exp_id: int, counter: int + conn: AtomicConnection, exp_id: int, counter: int ) -> str: """ Get the guid of a run in the specified experiment with the specified @@ -592,7 +594,7 @@ def get_guid_from_expid_and_counter( return run_id -def get_runid_from_guid(conn: ConnectionPlus, guid: str) -> int | None: +def get_runid_from_guid(conn: 
AtomicConnection, guid: str) -> int | None: """ Get the run_id of a run based on the guid @@ -632,7 +634,7 @@ def get_runid_from_guid(conn: ConnectionPlus, guid: str) -> int | None: def _query_guids_from_run_spec( - conn: ConnectionPlus, + conn: AtomicConnection, captured_run_id: int | None = None, captured_counter: int | None = None, experiment_name: str | None = None, @@ -702,7 +704,7 @@ def _query_guids_from_run_spec( def _get_layout_id( - conn: ConnectionPlus, parameter: ParamSpec | str, run_id: int + conn: AtomicConnection, parameter: ParamSpec | str, run_id: int ) -> int: """ Get the layout id of a parameter in a given run @@ -736,7 +738,7 @@ def _get_layout_id( return res -def _get_dependents(conn: ConnectionPlus, run_id: int) -> list[int]: +def _get_dependents(conn: AtomicConnection, run_id: int) -> list[int]: """ Get dependent layout_ids for a certain run_id, i.e. the layout_ids of all the dependent variables @@ -750,7 +752,7 @@ def _get_dependents(conn: ConnectionPlus, run_id: int) -> list[int]: return res -def _get_dependencies(conn: ConnectionPlus, layout_id: int) -> list[tuple[int, int]]: +def _get_dependencies(conn: AtomicConnection, layout_id: int) -> list[tuple[int, int]]: """ Get the dependencies of a certain dependent variable (indexed by its layout_id) @@ -771,7 +773,7 @@ def _get_dependencies(conn: ConnectionPlus, layout_id: int) -> list[tuple[int, i def new_experiment( - conn: ConnectionPlus, + conn: AtomicConnection, name: str, sample_name: str, format_string: str | None = "{}-{}-{}", @@ -818,7 +820,7 @@ def new_experiment( # TODO(WilliamHPNielsen): we should remove the redundant # is_completed def mark_run_complete( - conn: ConnectionPlus, + conn: AtomicConnection, run_id: int, timestamp: float | None = None, override: bool = False, @@ -853,7 +855,7 @@ def mark_run_complete( atomic_transaction(conn, query, timestamp, True, run_id) -def completed(conn: ConnectionPlus, run_id: int) -> bool: +def completed(conn: AtomicConnection, run_id: int) -> bool: """Check if the run is complete Args: @@ -865,7 +867,7 @@ def completed(conn: ConnectionPlus, run_id: int) -> bool: def get_completed_timestamp_from_run_id( - conn: ConnectionPlus, run_id: int + conn: AtomicConnection, run_id: int ) -> float | None: """ Retrieve the timestamp when the given measurement run was completed @@ -889,7 +891,7 @@ def get_completed_timestamp_from_run_id( return ts -def get_guid_from_run_id(conn: ConnectionPlus, run_id: int) -> str | None: +def get_guid_from_run_id(conn: AtomicConnection, run_id: int) -> str | None: """ Get the guid of the given run. Returns None if the run is not found @@ -910,7 +912,7 @@ def get_guid_from_run_id(conn: ConnectionPlus, run_id: int) -> str | None: def get_guids_from_multiple_run_ids( - conn: ConnectionPlus, run_ids: Iterable[int] + conn: AtomicConnection, run_ids: Iterable[int] ) -> list[str]: """ Retrieve guids of runs in the connected database specified by their run ids. 
@@ -937,7 +939,7 @@ def get_guids_from_multiple_run_ids( return guids -def finish_experiment(conn: ConnectionPlus, exp_id: int) -> None: +def finish_experiment(conn: AtomicConnection, exp_id: int) -> None: """Finish experiment Args: @@ -951,7 +953,7 @@ def finish_experiment(conn: ConnectionPlus, exp_id: int) -> None: atomic_transaction(conn, query, time.time(), exp_id) -def get_run_counter(conn: ConnectionPlus, exp_id: int) -> int: +def get_run_counter(conn: AtomicConnection, exp_id: int) -> int: """Get the experiment run counter Args: @@ -971,7 +973,7 @@ def get_run_counter(conn: ConnectionPlus, exp_id: int) -> int: return counter -def get_experiments(conn: ConnectionPlus) -> list[int]: +def get_experiments(conn: AtomicConnection) -> list[int]: """ Get a list of experiments @@ -988,7 +990,7 @@ def get_experiments(conn: ConnectionPlus) -> list[int]: return [exp_id for (exp_id,) in c.fetchall()] -def get_matching_exp_ids(conn: ConnectionPlus, **match_conditions: Any) -> list[int]: +def get_matching_exp_ids(conn: AtomicConnection, **match_conditions: Any) -> list[int]: """ Get exp_ids for experiments matching the match_conditions @@ -1036,7 +1038,9 @@ def get_matching_exp_ids(conn: ConnectionPlus, **match_conditions: Any) -> list[ return [exp_id for (exp_id,) in cursor.fetchall()] -def get_exp_ids_from_run_ids(conn: ConnectionPlus, run_ids: Sequence[int]) -> list[int]: +def get_exp_ids_from_run_ids( + conn: AtomicConnection, run_ids: Sequence[int] +) -> list[int]: """ Get the corresponding exp_id for a sequence of run_ids @@ -1061,7 +1065,7 @@ def get_exp_ids_from_run_ids(conn: ConnectionPlus, run_ids: Sequence[int]) -> li return [exp_id for row in rows for exp_id in row] -def get_last_experiment(conn: ConnectionPlus) -> int | None: +def get_last_experiment(conn: AtomicConnection) -> int | None: """ Return last started experiment id @@ -1072,7 +1076,7 @@ def get_last_experiment(conn: ConnectionPlus) -> int | None: return c.fetchall()[0][0] -def get_runs(conn: ConnectionPlus, exp_id: int | None = None) -> list[int]: +def get_runs(conn: AtomicConnection, exp_id: int | None = None) -> list[int]: """Get a list of runs. 
Args: @@ -1098,7 +1102,7 @@ def get_runs(conn: ConnectionPlus, exp_id: int | None = None) -> list[int]: return [run_id for (run_id,) in c.fetchall()] -def get_last_run(conn: ConnectionPlus, exp_id: int | None = None) -> int | None: +def get_last_run(conn: AtomicConnection, exp_id: int | None = None) -> int | None: """ Get run_id of the last run in experiment with exp_id @@ -1128,7 +1132,7 @@ def get_last_run(conn: ConnectionPlus, exp_id: int | None = None) -> int | None: return one(c, "run_id") -def run_exists(conn: ConnectionPlus, run_id: int) -> bool: +def run_exists(conn: AtomicConnection, run_id: int) -> bool: # the following query always returns a single tuple with an integer # value of `1` or `0` for existing and non-existing run_id in the database query = """ @@ -1160,7 +1164,7 @@ def format_table_name(fmt_str: str, name: str, exp_id: int, run_counter: int) -> def _insert_run( - conn: ConnectionPlus, + conn: AtomicConnection, exp_id: int, name: str, guid: str, @@ -1324,7 +1328,7 @@ def _insert_run( def _update_experiment_run_counter( - conn: ConnectionPlus, exp_id: int, run_counter: int + conn: AtomicConnection, exp_id: int, run_counter: int ) -> None: query = """ UPDATE experiments @@ -1334,7 +1338,7 @@ def _update_experiment_run_counter( atomic_transaction(conn, query, run_counter, exp_id) -def _get_parameters(conn: ConnectionPlus, run_id: int) -> list[ParamSpec]: +def _get_parameters(conn: AtomicConnection, run_id: int) -> list[ParamSpec]: """ Get the list of param specs for run @@ -1358,7 +1362,7 @@ def _get_parameters(conn: ConnectionPlus, run_id: int) -> list[ParamSpec]: ] -def _get_paramspec(conn: ConnectionPlus, run_id: int, param_name: str) -> ParamSpec: +def _get_paramspec(conn: AtomicConnection, run_id: int, param_name: str) -> ParamSpec: """ Get the ParamSpec object for the given parameter name in the given run @@ -1427,7 +1431,9 @@ def _get_paramspec(conn: ConnectionPlus, run_id: int, param_name: str) -> ParamS return parspec -def update_run_description(conn: ConnectionPlus, run_id: int, description: str) -> None: +def update_run_description( + conn: AtomicConnection, run_id: int, description: str +) -> None: """ Update the run_description field for the given run_id. The description string must be a valid JSON string representation of a RunDescriber object @@ -1444,7 +1450,7 @@ def update_run_description(conn: ConnectionPlus, run_id: int, description: str) def _update_run_description( - conn: ConnectionPlus, run_id: int, description: str + conn: AtomicConnection, run_id: int, description: str ) -> None: """ Update the run_description field for the given run_id. The description @@ -1459,7 +1465,7 @@ def _update_run_description( atomic_conn.cursor().execute(sql, (description, run_id)) -def update_parent_datasets(conn: ConnectionPlus, run_id: int, links_str: str) -> None: +def update_parent_datasets(conn: AtomicConnection, run_id: int, links_str: str) -> None: """ Update (i.e. overwrite) the parent_datasets field for the given run_id """ @@ -1476,7 +1482,7 @@ def update_parent_datasets(conn: ConnectionPlus, run_id: int, links_str: str) -> def set_run_timestamp( - conn: ConnectionPlus, run_id: int, timestamp: float | None = None + conn: AtomicConnection, run_id: int, timestamp: float | None = None ) -> None: """ Set the run_timestamp for the run with the given run_id. 
If the @@ -1518,7 +1524,7 @@ def set_run_timestamp( def add_parameter( *parameter: ParamSpec, - conn: ConnectionPlus, + conn: AtomicConnection, run_id: int, insert_into_results_table: bool, ) -> None: @@ -1567,7 +1573,7 @@ def add_parameter( def _add_parameters_to_layout_and_deps( - conn: ConnectionPlus, run_id: int, *parameter: ParamSpec + conn: AtomicConnection, run_id: int, *parameter: ParamSpec ) -> sqlite3.Cursor: layout_args: list[int | str] = [] for p in parameter: @@ -1619,7 +1625,7 @@ def _validate_table_name(table_name: str) -> bool: def _create_run_table( - conn: ConnectionPlus, + conn: AtomicConnection, formatted_name: str, parameters: Sequence[ParamSpecBase] | None = None, values: VALUES | None = None, @@ -1668,7 +1674,7 @@ def _create_run_table( def create_run( - conn: ConnectionPlus, + conn: AtomicConnection, exp_id: int, name: str, guid: str, @@ -1763,7 +1769,7 @@ def create_run( return run_counter, run_id, formatted_name -def get_run_description(conn: ConnectionPlus, run_id: int) -> str: +def get_run_description(conn: AtomicConnection, run_id: int) -> str: """ Return the (JSON string) run description of the specified run """ @@ -1772,7 +1778,7 @@ def get_run_description(conn: ConnectionPlus, run_id: int) -> str: return rds -def get_parent_dataset_links(conn: ConnectionPlus, run_id: int) -> str: +def get_parent_dataset_links(conn: AtomicConnection, run_id: int) -> str: """ Return the (JSON string) of the parent-child dataset links for the specified run @@ -1802,7 +1808,7 @@ def get_parent_dataset_links(conn: ConnectionPlus, run_id: int) -> str: def get_data_by_tag_and_table_name( - conn: ConnectionPlus, tag: str, table_name: str + conn: AtomicConnection, tag: str, table_name: str ) -> VALUE | None: """ Get data from the "tag" column for the row in "runs" table where @@ -1825,7 +1831,7 @@ def get_data_by_tag_and_table_name( return data -def get_metadata_from_run_id(conn: ConnectionPlus, run_id: int) -> dict[str, Any]: +def get_metadata_from_run_id(conn: AtomicConnection, run_id: int) -> dict[str, Any]: """ Get all metadata associated with the specified run """ @@ -1879,7 +1885,7 @@ def validate_dynamic_column_data(data: Mapping[str, Any]) -> None: def insert_data_in_dynamic_columns( - conn: ConnectionPlus, row_id: int, table_name: str, data: Mapping[str, Any] + conn: AtomicConnection, row_id: int, table_name: str, data: Mapping[str, Any] ) -> None: """ Insert new data column and add values. Note that None is not a valid @@ -1900,7 +1906,7 @@ def insert_data_in_dynamic_columns( def update_columns( - conn: ConnectionPlus, row_id: int, table_name: str, data: Mapping[str, Any] + conn: AtomicConnection, row_id: int, table_name: str, data: Mapping[str, Any] ) -> None: """ Updates data in columns matching the given keys (they must exist already) @@ -1917,7 +1923,10 @@ def update_columns( def add_data_to_dynamic_columns( - conn: ConnectionPlus, row_id: int, data: Mapping[str, Any], table_name: str = "runs" + conn: AtomicConnection, + row_id: int, + data: Mapping[str, Any], + table_name: str = "runs", ) -> None: """ Add columns from keys and insert values. 
@@ -1945,13 +1954,13 @@ def add_data_to_dynamic_columns( raise e -def get_experiment_name_from_experiment_id(conn: ConnectionPlus, exp_id: int) -> str: +def get_experiment_name_from_experiment_id(conn: AtomicConnection, exp_id: int) -> str: exp_name = select_one_where(conn, "experiments", "name", "exp_id", exp_id) assert isinstance(exp_name, str) return exp_name -def get_sample_name_from_experiment_id(conn: ConnectionPlus, exp_id: int) -> str: +def get_sample_name_from_experiment_id(conn: AtomicConnection, exp_id: int) -> str: sample_name = select_one_where(conn, "experiments", "sample_name", "exp_id", exp_id) assert isinstance(sample_name, (str, type(None))) # there may be a few cases for very old db where None is returned as a sample name @@ -1960,7 +1969,7 @@ def get_sample_name_from_experiment_id(conn: ConnectionPlus, exp_id: int) -> str return cast("str", sample_name) -def get_run_timestamp_from_run_id(conn: ConnectionPlus, run_id: int) -> float | None: +def get_run_timestamp_from_run_id(conn: AtomicConnection, run_id: int) -> float | None: time_stamp = select_one_where(conn, "runs", "run_timestamp", "run_id", run_id) # sometimes it happens that the timestamp is saved as an integer in the database if isinstance(time_stamp, int): @@ -1969,7 +1978,7 @@ def get_run_timestamp_from_run_id(conn: ConnectionPlus, run_id: int) -> float | return time_stamp -def update_GUIDs(conn: ConnectionPlus) -> None: +def update_GUIDs(conn: AtomicConnection) -> None: """ Update all GUIDs in this database where either the location code or the work_station code is zero to use the location and work_station code from @@ -2023,7 +2032,7 @@ def _workstation_only_zero(run_id: int, *args: Any) -> None: ) def _both_zero( - run_id: int, conn: ConnectionPlus, guid_comps: dict[str, Any] + run_id: int, conn: AtomicConnection, guid_comps: dict[str, Any] ) -> None: guid_str = build_guid_from_components(guid_comps) with atomic(conn) as atomic_conn: @@ -2037,7 +2046,7 @@ def _both_zero( log.info(f"Succesfully updated run number {run_id}.") actions: dict[ - tuple[bool, bool], Callable[[int, ConnectionPlus, dict[str, Any]], None] + tuple[bool, bool], Callable[[int, AtomicConnection, dict[str, Any]], None] ] actions = { (True, True): _both_zero, @@ -2060,7 +2069,7 @@ def _both_zero( actions[(old_loc == 0, old_ws == 0)](run_id, conn, guid_comps) -def remove_trigger(conn: ConnectionPlus, trigger_id: str) -> None: +def remove_trigger(conn: AtomicConnection, trigger_id: str) -> None: """ Removes a trigger with a given id if it exists. @@ -2075,7 +2084,7 @@ def remove_trigger(conn: ConnectionPlus, trigger_id: str) -> None: def load_new_data_for_rundescriber( - conn: ConnectionPlus, + conn: AtomicConnection, table_name: str, rundescriber: RunDescriber, read_status: Mapping[str, int], @@ -2125,7 +2134,7 @@ class ExperimentAttributeDict(TypedDict): def get_experiment_attributes_by_exp_id( - conn: ConnectionPlus, exp_id: int + conn: AtomicConnection, exp_id: int ) -> ExperimentAttributeDict: """ Return a dict of all attributes describing an experiment from the exp_id. 
@@ -2163,8 +2172,8 @@ def get_experiment_attributes_by_exp_id( def _populate_results_table( - source_conn: ConnectionPlus, - target_conn: ConnectionPlus, + source_conn: AtomicConnection, + target_conn: AtomicConnection, source_table_name: str, target_table_name: str, ) -> None: @@ -2194,7 +2203,7 @@ def _populate_results_table( def _rewrite_timestamps( - target_conn: ConnectionPlus, + target_conn: AtomicConnection, target_run_id: int, correct_run_timestamp: float | None, correct_completed_timestamp: float | None, @@ -2235,7 +2244,7 @@ class RawRunAttributesDict(TypedDict): def get_raw_run_attributes( - conn: ConnectionPlus, guid: str + conn: AtomicConnection, guid: str ) -> RawRunAttributesDict | None: run_id = get_runid_from_guid(conn, guid) @@ -2285,13 +2294,13 @@ def raw_time_to_str_time( return time.strftime(fmt, time.localtime(raw_timestamp)) -def _check_if_table_found(conn: ConnectionPlus, table_name: str) -> bool: +def _check_if_table_found(conn: AtomicConnection, table_name: str) -> bool: query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" cursor = conn.cursor() return not many_many(cursor.execute(query, (table_name,)), "name") == [] -def _get_result_table_name_by_guid(conn: ConnectionPlus, guid: str) -> str: +def _get_result_table_name_by_guid(conn: AtomicConnection, guid: str) -> str: sql = "SELECT result_table_name FROM runs WHERE guid=?" formatted_name = one(transaction(conn, sql, guid), "result_table_name") return formatted_name diff --git a/src/qcodes/dataset/sqlite/query_helpers.py b/src/qcodes/dataset/sqlite/query_helpers.py index 60a122529b10..afeb5c15c561 100644 --- a/src/qcodes/dataset/sqlite/query_helpers.py +++ b/src/qcodes/dataset/sqlite/query_helpers.py @@ -14,7 +14,7 @@ from packaging import version from qcodes.dataset.sqlite.connection import ( - ConnectionPlus, + AtomicConnection, atomic, atomic_transaction, transaction, @@ -140,7 +140,11 @@ def many_many(curr: sqlite3.Cursor, *columns: str) -> list[tuple[Any, ...]]: def select_one_where( - conn: ConnectionPlus, table: str, column: str, where_column: str, where_value: VALUE + conn: AtomicConnection, + table: str, + column: str, + where_column: str, + where_value: VALUE, ) -> VALUE: """ Select a value from a given column given a match of a value in a @@ -176,7 +180,7 @@ def select_one_where( def select_many_where( - conn: ConnectionPlus, + conn: AtomicConnection, table: str, *columns: str, where_column: str, @@ -208,7 +212,7 @@ def _massage_dict(metadata: Mapping[str, Any]) -> tuple[str, list[Any]]: def update_where( - conn: ConnectionPlus, + conn: AtomicConnection, table: str, where_column: str, where_value: Any, @@ -227,7 +231,7 @@ def update_where( def insert_values( - conn: ConnectionPlus, + conn: AtomicConnection, formatted_name: str, columns: list[str], values: VALUES, @@ -253,7 +257,7 @@ def insert_values( def insert_many_values( - conn: ConnectionPlus, + conn: AtomicConnection, formatted_name: str, columns: Sequence[str], values: Sequence[VALUES], @@ -331,7 +335,7 @@ def insert_many_values( return return_value -def length(conn: ConnectionPlus, formatted_name: str) -> int: +def length(conn: AtomicConnection, formatted_name: str) -> int: """ Return the length of the table @@ -358,7 +362,7 @@ def length(conn: ConnectionPlus, formatted_name: str) -> int: def insert_column( - conn: ConnectionPlus, table: str, name: str, paramtype: str | None = None + conn: AtomicConnection, table: str, name: str, paramtype: str | None = None ) -> None: """Insert new column to a table @@ -388,7 +392,7 @@ def 
insert_column( transaction(atomic_conn, f'ALTER TABLE "{table}" ADD COLUMN "{name}"') -def is_column_in_table(conn: ConnectionPlus, table: str, column: str) -> bool: +def is_column_in_table(conn: AtomicConnection, table: str, column: str) -> bool: """ A look-before-you-leap function to look up if a table has a certain column. diff --git a/tests/dataset/test_database_creation_and_upgrading.py b/tests/dataset/test_database_creation_and_upgrading.py index e35bf8e7fdef..8679a71867a6 100644 --- a/tests/dataset/test_database_creation_and_upgrading.py +++ b/tests/dataset/test_database_creation_and_upgrading.py @@ -11,7 +11,6 @@ import qcodes.dataset.descriptions.versioning.serialization as serial import tests.dataset from qcodes.dataset import ( - ConnectionPlus, connect, initialise_database, initialise_or_create_database_at, @@ -26,7 +25,7 @@ from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.descriptions.versioning.v0 import InterDependencies from qcodes.dataset.guids import parse_guid -from qcodes.dataset.sqlite.connection import atomic_transaction +from qcodes.dataset.sqlite.connection import AtomicConnection, atomic_transaction from qcodes.dataset.sqlite.database import get_db_version_and_newest_available_version from qcodes.dataset.sqlite.db_upgrades import ( _latest_available_version, @@ -703,7 +702,7 @@ def test_perform_actual_upgrade_6_to_7() -> None: skip_if_no_fixtures(dbname_old) with temporarily_copied_DB(dbname_old, debug=False, version=6) as conn: - assert isinstance(conn, ConnectionPlus) + assert isinstance(conn, AtomicConnection) perform_db_upgrade_6_to_7(conn) assert get_user_version(conn) == 7 @@ -762,7 +761,7 @@ def test_perform_actual_upgrade_6_to_newest_add_new_data() -> None: skip_if_no_fixtures(dbname_old) with temporarily_copied_DB(dbname_old, debug=False, version=6) as conn: - assert isinstance(conn, ConnectionPlus) + assert isinstance(conn, AtomicConnection) perform_db_upgrade(conn) assert get_user_version(conn) >= 7 no_of_runs_query = "SELECT max(run_id) FROM runs" diff --git a/tests/dataset/test_dataset_in_memory.py b/tests/dataset/test_dataset_in_memory.py index f32519495d20..a6741d955ce7 100644 --- a/tests/dataset/test_dataset_in_memory.py +++ b/tests/dataset/test_dataset_in_memory.py @@ -1,7 +1,6 @@ import contextlib import os import shutil -import sqlite3 from pathlib import Path import hypothesis.strategies as hst @@ -15,7 +14,7 @@ from qcodes.dataset import load_by_id, load_by_run_spec from qcodes.dataset.data_set_in_memory import DataSetInMem, load_from_file from qcodes.dataset.data_set_protocol import DataSetType -from qcodes.dataset.sqlite.connection import ConnectionPlus, atomic_transaction +from qcodes.dataset.sqlite.connection import AtomicConnection, atomic_transaction from qcodes.station import Station @@ -362,7 +361,7 @@ def test_dataset_in_memory_does_not_create_runs_table( ds = datasaver.dataset dbfile = datasaver.dataset._path_to_db - conn = ConnectionPlus(sqlite3.connect(dbfile)) + conn = AtomicConnection(dbfile) tables_query = 'SELECT * FROM sqlite_master WHERE TYPE = "table"' tables = list(atomic_transaction(conn, tables_query).fetchall()) diff --git a/tests/dataset/test_guid_helpers.py b/tests/dataset/test_guid_helpers.py index 23b727151c80..8b775a034ad5 100644 --- a/tests/dataset/test_guid_helpers.py +++ b/tests/dataset/test_guid_helpers.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from pathlib import Path - from qcodes.dataset.sqlite.connection import ConnectionPlus + from qcodes.dataset.sqlite.connection import 
AtomicConnection def test_guids_from_dir(tmp_path: "Path") -> None: @@ -92,7 +92,7 @@ def test_many_guids_from_list_str() -> None: def test_get_guids_from_multiple_run_ids(tmp_path: "Path") -> None: - def generate_local_exp(dbpath: "Path") -> tuple[list[str], "ConnectionPlus"]: + def generate_local_exp(dbpath: "Path") -> tuple[list[str], "AtomicConnection"]: with initialised_database_at(str(dbpath)): guids = [] exp = load_or_create_experiment(experiment_name="test_guid") diff --git a/tests/dataset/test_sqlite_connection.py b/tests/dataset/test_sqlite_connection.py index 864abc3450fe..94eacb61fd72 100644 --- a/tests/dataset/test_sqlite_connection.py +++ b/tests/dataset/test_sqlite_connection.py @@ -4,10 +4,11 @@ import pytest from qcodes.dataset.sqlite.connection import ( - ConnectionPlus, + AtomicConnection, + ConnectionPlus, # pyright: ignore[reportDeprecated] atomic, atomic_transaction, - make_connection_plus_from, + make_connection_plus_from, # pyright: ignore[reportDeprecated] ) from qcodes.dataset.sqlite.database import connect from tests.common import error_caused_by @@ -20,8 +21,8 @@ def sqlite_conn_in_transaction(conn: sqlite3.Connection): return True -def conn_plus_in_transaction(conn: ConnectionPlus): - assert isinstance(conn, ConnectionPlus) +def conn_plus_in_transaction(conn: AtomicConnection): + assert isinstance(conn, AtomicConnection | ConnectionPlus) # pyright: ignore[reportDeprecated] assert True is conn.atomic_in_progress assert None is conn.isolation_level assert True is conn.in_transaction @@ -35,8 +36,8 @@ def sqlite_conn_is_idle(conn: sqlite3.Connection, isolation=None): return True -def conn_plus_is_idle(conn: ConnectionPlus, isolation=None): - assert isinstance(conn, ConnectionPlus) +def conn_plus_is_idle(conn: ConnectionPlus | AtomicConnection, isolation=None): # pyright: ignore[reportDeprecated] + assert isinstance(conn, ConnectionPlus | AtomicConnection) # pyright: ignore[reportDeprecated] assert False is conn.atomic_in_progress assert isolation == conn.isolation_level assert False is conn.in_transaction @@ -45,10 +46,10 @@ def conn_plus_is_idle(conn: ConnectionPlus, isolation=None): def test_connection_plus() -> None: sqlite_conn = sqlite3.connect(":memory:") - conn_plus = ConnectionPlus(sqlite_conn) + conn_plus = ConnectionPlus(sqlite_conn) # pyright: ignore[reportDeprecated] assert conn_plus.path_to_dbfile == "" - assert isinstance(conn_plus, ConnectionPlus) + assert isinstance(conn_plus, ConnectionPlus) # pyright: ignore[reportDeprecated] assert isinstance(conn_plus, sqlite3.Connection) assert False is conn_plus.atomic_in_progress @@ -57,41 +58,72 @@ def test_connection_plus() -> None: "`ConnectionPlus` object which is not allowed." 
) with pytest.raises(ValueError, match=match_str): - ConnectionPlus(conn_plus) + ConnectionPlus(conn_plus) # pyright: ignore[reportDeprecated] + + +def test_atomic_connection() -> None: + sqlite_conn = AtomicConnection(":memory:") + + assert sqlite_conn.path_to_dbfile == "" + assert isinstance(sqlite_conn, AtomicConnection) + assert isinstance(sqlite_conn, sqlite3.Connection) + assert False is sqlite_conn.atomic_in_progress def test_make_connection_plus_from_sqlite3_connection() -> None: conn = sqlite3.connect(":memory:") - conn_plus = make_connection_plus_from(conn) + conn_plus = make_connection_plus_from(conn) # pyright: ignore[reportDeprecated] assert conn_plus.path_to_dbfile == "" - assert isinstance(conn_plus, ConnectionPlus) + assert isinstance(conn_plus, ConnectionPlus) # pyright: ignore[reportDeprecated] assert False is conn_plus.atomic_in_progress assert conn_plus is not conn def test_make_connection_plus_from_connecton_plus() -> None: - conn = ConnectionPlus(sqlite3.connect(":memory:")) - conn_plus = make_connection_plus_from(conn) + conn = ConnectionPlus(sqlite3.connect(":memory:")) # pyright: ignore[reportDeprecated] + conn_plus = make_connection_plus_from(conn) # pyright: ignore[reportDeprecated] assert conn_plus.path_to_dbfile == "" - assert isinstance(conn_plus, ConnectionPlus) + assert isinstance(conn_plus, ConnectionPlus) # pyright: ignore[reportDeprecated] assert conn.atomic_in_progress is conn_plus.atomic_in_progress assert conn_plus is conn -def test_atomic() -> None: +def test_atomic_connection_plus() -> None: sqlite_conn = sqlite3.connect(":memory:") match_str = re.escape( - "atomic context manager only accepts ConnectionPlus " + "atomic context manager only accepts AtomicConnection or ConnectionPlus " "database connection objects." 
) with pytest.raises(ValueError, match=match_str): with atomic(sqlite_conn): # type: ignore[arg-type] pass - conn_plus = ConnectionPlus(sqlite_conn) + conn_plus = ConnectionPlus(sqlite_conn) # pyright: ignore[reportDeprecated] + assert False is conn_plus.atomic_in_progress + + atomic_in_progress = conn_plus.atomic_in_progress + isolation_level = conn_plus.isolation_level + + assert False is conn_plus.in_transaction + + with atomic(conn_plus) as atomic_conn: + assert conn_plus_in_transaction(atomic_conn) + assert conn_plus_in_transaction(conn_plus) + + assert isolation_level == conn_plus.isolation_level + assert False is conn_plus.in_transaction + assert atomic_in_progress is conn_plus.atomic_in_progress + + assert isolation_level == conn_plus.isolation_level + assert False is atomic_conn.in_transaction + assert atomic_in_progress is atomic_conn.atomic_in_progress + + +def test_atomic() -> None: + conn_plus = AtomicConnection(":memory:") assert False is conn_plus.atomic_in_progress atomic_in_progress = conn_plus.atomic_in_progress @@ -113,8 +145,7 @@ def test_atomic() -> None: def test_atomic_with_exception() -> None: - sqlite_conn = sqlite3.connect(":memory:") - conn_plus = ConnectionPlus(sqlite_conn) + sqlite_conn = AtomicConnection(":memory:") sqlite_conn.execute("PRAGMA user_version(25)") sqlite_conn.commit() @@ -124,7 +155,7 @@ def test_atomic_with_exception() -> None: with pytest.raises( RuntimeError, match="Rolling back due to unhandled exception" ) as e: - with atomic(conn_plus) as atomic_conn: + with atomic(sqlite_conn) as atomic_conn: atomic_conn.execute("PRAGMA user_version(42)") raise Exception("intended exception") assert error_caused_by(e, "intended exception") @@ -133,7 +164,7 @@ def test_atomic_with_exception() -> None: def test_atomic_on_outmost_connection_that_is_in_transaction() -> None: - conn = ConnectionPlus(sqlite3.connect(":memory:")) + conn = AtomicConnection(":memory:") conn.execute("BEGIN") assert True is conn.in_transaction @@ -150,32 +181,31 @@ def test_atomic_on_outmost_connection_that_is_in_transaction() -> None: @pytest.mark.parametrize("in_transaction", (True, False)) def test_atomic_on_connection_plus_that_is_in_progress(in_transaction) -> None: - sqlite_conn = sqlite3.connect(":memory:") - conn_plus = ConnectionPlus(sqlite_conn) + sqlite_conn = AtomicConnection(":memory:") # explicitly set to True for testing purposes - conn_plus.atomic_in_progress = True + sqlite_conn.atomic_in_progress = True # implement parametrizing over connection's `in_transaction` attribute if in_transaction: - conn_plus.cursor().execute("BEGIN") - assert in_transaction is conn_plus.in_transaction + sqlite_conn.cursor().execute("BEGIN") + assert in_transaction is sqlite_conn.in_transaction - isolation_level = conn_plus.isolation_level - in_transaction = conn_plus.in_transaction + isolation_level = sqlite_conn.isolation_level + in_transaction = sqlite_conn.in_transaction - with atomic(conn_plus) as atomic_conn: - assert True is conn_plus.atomic_in_progress - assert isolation_level == conn_plus.isolation_level - assert in_transaction is conn_plus.in_transaction + with atomic(sqlite_conn) as atomic_conn: + assert True is sqlite_conn.atomic_in_progress + assert isolation_level == sqlite_conn.isolation_level + assert in_transaction is sqlite_conn.in_transaction assert True is atomic_conn.atomic_in_progress assert isolation_level == atomic_conn.isolation_level assert in_transaction is atomic_conn.in_transaction - assert True is conn_plus.atomic_in_progress - assert isolation_level == 
conn_plus.isolation_level - assert in_transaction is conn_plus.in_transaction + assert True is sqlite_conn.atomic_in_progress + assert isolation_level == sqlite_conn.isolation_level + assert in_transaction is sqlite_conn.in_transaction assert True is atomic_conn.atomic_in_progress assert isolation_level == atomic_conn.isolation_level @@ -183,39 +213,38 @@ def test_atomic_on_connection_plus_that_is_in_progress(in_transaction) -> None: def test_two_nested_atomics() -> None: - sqlite_conn = sqlite3.connect(":memory:") - conn_plus = ConnectionPlus(sqlite_conn) + sqlite_conn = AtomicConnection(":memory:") - atomic_in_progress = conn_plus.atomic_in_progress - isolation_level = conn_plus.isolation_level + atomic_in_progress = sqlite_conn.atomic_in_progress + isolation_level = sqlite_conn.isolation_level - assert False is conn_plus.in_transaction + assert False is sqlite_conn.in_transaction - with atomic(conn_plus) as atomic_conn_1: - assert conn_plus_in_transaction(conn_plus) + with atomic(sqlite_conn) as atomic_conn_1: + assert conn_plus_in_transaction(sqlite_conn) assert conn_plus_in_transaction(atomic_conn_1) with atomic(atomic_conn_1) as atomic_conn_2: - assert conn_plus_in_transaction(conn_plus) + assert conn_plus_in_transaction(sqlite_conn) assert conn_plus_in_transaction(atomic_conn_1) assert conn_plus_in_transaction(atomic_conn_2) - assert conn_plus_in_transaction(conn_plus) + assert conn_plus_in_transaction(sqlite_conn) assert conn_plus_in_transaction(atomic_conn_1) assert conn_plus_in_transaction(atomic_conn_2) - assert conn_plus_is_idle(conn_plus, isolation_level) + assert conn_plus_is_idle(sqlite_conn, isolation_level) assert conn_plus_is_idle(atomic_conn_1, isolation_level) assert conn_plus_is_idle(atomic_conn_2, isolation_level) - assert atomic_in_progress == conn_plus.atomic_in_progress + assert atomic_in_progress == sqlite_conn.atomic_in_progress assert atomic_in_progress == atomic_conn_1.atomic_in_progress assert atomic_in_progress == atomic_conn_2.atomic_in_progress @pytest.mark.parametrize( argnames="create_conn_plus", - argvalues=(make_connection_plus_from, ConnectionPlus), + argvalues=(make_connection_plus_from, ConnectionPlus), # pyright: ignore[reportDeprecated] ids=("make_connection_plus_from", "ConnectionPlus"), ) def test_that_use_of_atomic_commits_only_at_outermost_context( @@ -304,19 +333,103 @@ def test_that_use_of_atomic_commits_only_at_outermost_context( assert 3 == len(control_conn.execute(get_all_runs).fetchall()) +def test_that_use_of_atomic_commits_only_at_outermost_context_atomic_connection( + tmp_path, +) -> None: + """ + Test the behavior of `AtomicConnection` with respect to the `atomic` + context manager and commits.
+ """ + dbfile = str(tmp_path / "temp.db") + # just initialize the database file, connection objects needed for + # testing in this test function are created separately, see below + connect(dbfile) + + connection = AtomicConnection(dbfile) + + # this connection is going to be used to test whether changes have been + # committed to the database file + control_conn = connect(dbfile) + + get_all_runs = "SELECT * FROM runs" + insert_run_with_name = "INSERT INTO runs (name) VALUES (?)" + + # assert that at the beginning of the test there are no runs in the + # table; we'll be adding new rows to the runs table below + + assert 0 == len(connection.execute(get_all_runs).fetchall()) + assert 0 == len(control_conn.execute(get_all_runs).fetchall()) + + # add 1 new row, and assert the state of the runs table at every step + # note that control_conn will only detect the change after the `atomic` + # context manager is exited + + with atomic(connection) as atomic_conn: + assert 0 == len(connection.execute(get_all_runs).fetchall()) + assert 0 == len(atomic_conn.execute(get_all_runs).fetchall()) + assert 0 == len(control_conn.execute(get_all_runs).fetchall()) + + atomic_conn.cursor().execute(insert_run_with_name, ["aaa"]) + + assert 1 == len(connection.execute(get_all_runs).fetchall()) + assert 1 == len(atomic_conn.execute(get_all_runs).fetchall()) + assert 0 == len(control_conn.execute(get_all_runs).fetchall()) + + assert 1 == len(connection.execute(get_all_runs).fetchall()) + assert 1 == len(atomic_conn.execute(get_all_runs).fetchall()) + assert 1 == len(control_conn.execute(get_all_runs).fetchall()) + + # let's add two new rows but each inside its own `atomic` context manager + # we expect to see the actual change in the database only after we exit + # the outermost context. 
+ + with atomic(connection) as atomic_conn_1: + assert 1 == len(connection.execute(get_all_runs).fetchall()) + assert 1 == len(atomic_conn_1.execute(get_all_runs).fetchall()) + assert 1 == len(control_conn.execute(get_all_runs).fetchall()) + + atomic_conn_1.cursor().execute(insert_run_with_name, ["bbb"]) + + assert 2 == len(connection.execute(get_all_runs).fetchall()) + assert 2 == len(atomic_conn_1.execute(get_all_runs).fetchall()) + assert 1 == len(control_conn.execute(get_all_runs).fetchall()) + + with atomic(atomic_conn_1) as atomic_conn_2: + assert 2 == len(connection.execute(get_all_runs).fetchall()) + assert 2 == len(atomic_conn_1.execute(get_all_runs).fetchall()) + assert 2 == len(atomic_conn_2.execute(get_all_runs).fetchall()) + assert 1 == len(control_conn.execute(get_all_runs).fetchall()) + + atomic_conn_2.cursor().execute(insert_run_with_name, ["ccc"]) + + assert 3 == len(connection.execute(get_all_runs).fetchall()) + assert 3 == len(atomic_conn_1.execute(get_all_runs).fetchall()) + assert 3 == len(atomic_conn_2.execute(get_all_runs).fetchall()) + assert 1 == len(control_conn.execute(get_all_runs).fetchall()) + + assert 3 == len(connection.execute(get_all_runs).fetchall()) + assert 3 == len(atomic_conn_1.execute(get_all_runs).fetchall()) + assert 3 == len(atomic_conn_2.execute(get_all_runs).fetchall()) + assert 1 == len(control_conn.execute(get_all_runs).fetchall()) + + assert 3 == len(connection.execute(get_all_runs).fetchall()) + assert 3 == len(atomic_conn_1.execute(get_all_runs).fetchall()) + assert 3 == len(atomic_conn_2.execute(get_all_runs).fetchall()) + assert 3 == len(control_conn.execute(get_all_runs).fetchall()) + + def test_atomic_transaction(tmp_path) -> None: """Test that atomic_transaction works for ConnectionPlus""" dbfile = str(tmp_path / "temp.db") - conn = ConnectionPlus(sqlite3.connect(dbfile)) - - ctrl_conn = sqlite3.connect(dbfile) + conn = AtomicConnection(dbfile) sql_create_table = "CREATE TABLE smth (name TEXT)" - sql_table_exists = 'SELECT sql FROM sqlite_master WHERE TYPE = "table"' atomic_transaction(conn, sql_create_table) + ctrl_conn = sqlite3.connect(dbfile) + sql_table_exists = 'SELECT sql FROM sqlite_master WHERE TYPE = "table"' assert sql_create_table in ctrl_conn.execute(sql_table_exists).fetchall()[0] @@ -327,7 +440,7 @@ def test_atomic_transaction_on_sqlite3_connection_raises(tmp_path) -> None: conn = sqlite3.connect(dbfile) match_str = re.escape( - "atomic context manager only accepts ConnectionPlus " + "atomic context manager only accepts AtomicConnection or ConnectionPlus " "database connection objects." ) @@ -339,6 +452,6 @@ def test_connect() -> None: conn = connect(":memory:") assert isinstance(conn, sqlite3.Connection) - assert isinstance(conn, ConnectionPlus) + assert isinstance(conn, AtomicConnection) assert False is conn.atomic_in_progress assert None is conn.row_factory
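For reference, a minimal usage sketch of the connection API these tests exercise. This is an illustrative sketch only, assuming the `AtomicConnection`, `atomic`, and `atomic_transaction` interfaces imported from `qcodes.dataset.sqlite.connection` in the hunks above; the table name `smth` is borrowed from `test_atomic_transaction`:

import sqlite3

from qcodes.dataset.sqlite.connection import (
    AtomicConnection,
    atomic,
    atomic_transaction,
)

# An AtomicConnection is opened directly from a database path (here an
# in-memory database) instead of wrapping an existing sqlite3 connection.
conn = AtomicConnection(":memory:")
assert isinstance(conn, sqlite3.Connection)  # direct subclass, as the tests assert
assert conn.atomic_in_progress is False

# A single statement can be executed and committed atomically ...
atomic_transaction(conn, "CREATE TABLE smth (name TEXT)")

# ... and several statements can be grouped with `atomic`; changes become
# visible to other connections only when the outermost context exits.
with atomic(conn) as atomic_conn:
    atomic_conn.execute("INSERT INTO smth (name) VALUES (?)", ("aaa",))

assert len(conn.execute("SELECT * FROM smth").fetchall()) == 1

The tests above rely on exactly this property: the connection object is itself a `sqlite3.Connection`, so it can be passed straight to `atomic` and `atomic_transaction` without a separate wrapping step.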