Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added a context manager `spinedb_api.spine_db_client.lock_db` for locking database mappings
under DB server, i.e. when executing Python scripts in Spine Toolbox.

### Changed

### Deprecated
Expand Down
16 changes: 16 additions & 0 deletions docs/source/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,22 @@ This structure has some specialized uses in Spine Toolbox and can usually be ign
:meth:`.DatabaseMapping.commit_session` raises :class:`NothingToCommit` when there are no changes to save.
Other errors raise :class:`SpineDBAPIError`.

Using database mappings in Spine Toolbox
----------------------------------------

When a Python script is executed as part of a Spine Toolbox workflow,
the same script may run in multiple parallel processes at the same.
Using a database mapping as read-only access to a database should not cause any issues in such environment.
However, there is a high chance for conflicts, or even corrupted in-memory data if data is committed to the database.
The database should be explicitly locked to prevent this from happening::

url = sys.argv[1] # Url provided by Spine Toolbox
with api.DatabaseMapping(url) as db_map:
with api.spine_db_client.lock_db(db_map):
# Use db_map here.
db_map.commit("Updated things.")


Performance
-----------

Expand Down
1 change: 1 addition & 0 deletions spinedb_api/db_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def __init__(
"""
super().__init__()
# FIXME: We should also check the server memory property and use it here
self.server_url: str = db_url if isinstance(db_url, str) else db_url.render_as_string(hide_password=False)
db_url = get_db_url_from_server(db_url)
self.db_url = str(db_url)
if isinstance(db_url, str):
Expand Down
2 changes: 1 addition & 1 deletion spinedb_api/db_mapping_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ def cascade_remove(self, source: Optional[object] = None) -> None:
elif self.status in (Status.committed, Status.to_update):
self.status = Status.to_remove
else:
raise RuntimeError("invalid status for item being removed")
raise RuntimeError(f"invalid status '{self.status}' for item being removed")
self._removal_source = source
self._removed = True
self._valid = None
Expand Down
32 changes: 28 additions & 4 deletions spinedb_api/spine_db_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
This module defines the :class:`SpineDBClient` class.
"""

"""This module defines the :class:`SpineDBClient` class."""
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from sqlalchemy.engine.url import URL
from .server_client_helpers import ReceiveAllMixing, decode, encode

if TYPE_CHECKING:
from .db_mapping import DatabaseMapping

client_version = 8


Expand Down Expand Up @@ -107,6 +111,12 @@ def call_method(self, method_name, *args, **kwargs):
def query(self, query_name: str, *args, **kwargs) -> dict:
return self._send("query", args=(query_name, *args), kwargs=kwargs)

def acquire_lock(self) -> None:
return self._send("acquire_lock")

def release_lock(self) -> None:
return self._send("release_lock")

def _send(self, request, args=None, kwargs=None, receive=True):
"""
Sends a request to the server with the given arguments.
Expand Down Expand Up @@ -138,3 +148,17 @@ def get_db_url_from_server(url):
if parsed.scheme != "http":
return url
return SpineDBClient((parsed.hostname, parsed.port)).get_db_url()


@contextmanager
def lock_db(db_map: DatabaseMapping) -> Iterator[None]:
url = urlparse(db_map.server_url)
if url.scheme != "http":
yield
return
client = SpineDBClient((url.hostname, url.port))
client.acquire_lock()
try:
yield
finally:
client.release_lock()
27 changes: 21 additions & 6 deletions spinedb_api/spine_db_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def _import_entity_class(server_url, class_name):
import traceback
from typing import ClassVar, Literal, Optional, TypedDict
from urllib.parse import urlunsplit
import uuid
from sqlalchemy.exc import DBAPIError
from spinedb_api import __version__ as spinedb_api_version
from .db_mapping import DatabaseMapping
Expand All @@ -117,7 +116,6 @@ def _import_entity_class(server_url, class_name):
from .import_functions import import_data
from .parameter_value import dump_db_value
from .server_client_helpers import ReceiveAllMixing, decode, encode
from .spine_db_client import SpineDBClient

_current_server_version = 8

Expand Down Expand Up @@ -176,7 +174,7 @@ def __init__(self):
self._process = mp.Process(target=self._do_work)
self._process.start()

def _get_commit_lock(self, db_url):
def _get_commit_lock(self, db_url: str) -> threading.Lock:
clean_url = clear_filter_configs(db_url)
return self._commit_locks.setdefault(clean_url, threading.Lock())

Expand Down Expand Up @@ -349,9 +347,10 @@ def __init__(self, db_url, upgrade, memory, commit_lock, manager_queue, ordering
self._db_map = None
self._closed = False
self._lock = threading.Lock()
self._commit_lock = commit_lock
self._in_queue = Queue()
self._out_queue = Queue()
self._thread = threading.Thread(target=lambda: self._do_work(db_url, upgrade, memory, commit_lock))
self._thread = threading.Thread(target=lambda: self._do_work(db_url, upgrade, memory))
self._thread.start()
error = self._out_queue.get()
if isinstance(error, Exception):
Expand All @@ -361,16 +360,22 @@ def __init__(self, db_url, upgrade, memory, commit_lock, manager_queue, ordering
def db_url(self):
return str(self._db_map.db_url)

@property
def commit_lock(self):
return self._commit_lock

def close_db_map(self):
if not self._closed:
self._closed = True
self._db_map.close()
self._in_queue.put(self._CLOSE)
self._thread.join()

def _do_work(self, db_url, upgrade, memory, commit_lock):
def _do_work(self, db_url, upgrade, memory):
try:
self._db_map = DatabaseMapping(db_url, upgrade=upgrade, memory=memory, commit_lock=commit_lock, create=True)
self._db_map = DatabaseMapping(
db_url, upgrade=upgrade, memory=memory, commit_lock=self._commit_lock, create=True
)
self._out_queue.put(None)
except Exception as error: # pylint: disable=broad-except
self._out_queue.put(error)
Expand Down Expand Up @@ -553,6 +558,14 @@ def cancel_db_checkout(self):
cancel_db_checkout(self.server_manager_queue, self.server_address)
return {"result": True}

def acquire_lock(self) -> dict:
self.server.commit_lock.acquire()
return {"result": True}

def release_lock(self) -> dict:
self.server.commit_lock.release()
return {"result": True}

def _get_response(self, request):
request, *extras = decode(request)
# NOTE: Clients should always send requests "get_api_version" and "get_db_url" in a format that is compatible
Expand All @@ -578,6 +591,8 @@ def _get_response(self, request):
"db_checkin": self.db_checkin,
"db_checkout": self.db_checkout,
"cancel_db_checkout": self.cancel_db_checkout,
"acquire_lock": self.acquire_lock,
"release_lock": self.release_lock,
}.get(request)
if handler is None:
return {"error": f"invalid request '{request}'"}
Expand Down
59 changes: 59 additions & 0 deletions tests/test_spine_db_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
######################################################################################################################
# Copyright (C) 2017-2022 Spine project consortium
# Copyright Spine Database API contributors
# This file is part of Spine Database API.
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
import multiprocessing
import pytest
from spinedb_api import DatabaseMapping, create_new_spine_database
from spinedb_api.spine_db_client import lock_db
from spinedb_api.spine_db_server import closing_spine_db_server, db_server_manager


@pytest.fixture
def db_url(tmp_path):
url = "sqlite:///" + str(tmp_path / "db.sqlite")
create_new_spine_database(url)
return url


def _do_work(url):
with DatabaseMapping(url) as db_map:
with lock_db(db_map) as lock:
assert lock is None
alternatives = db_map.find_alternatives()
if len(alternatives) == 1:
db_map.add_alternative(name="visited")
db_map.commit_session("Added first alternative.")
else:
db_map.add_alternative(name="visited again")
db_map.commit_session("Added second alternative.")


class TestLockDB:
def test_locking_is_no_operation_when_no_server_is_used(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
with lock_db(db_map) as lock:
assert lock is None

def test_locking_with_server(self, db_url):
with db_server_manager() as manager_queue:
with (
closing_spine_db_server(db_url, server_manager_queue=manager_queue) as server_url1,
closing_spine_db_server(db_url, server_manager_queue=manager_queue) as server_url2,
):
task1 = multiprocessing.Process(target=_do_work, args=(server_url1,))
task2 = multiprocessing.Process(target=_do_work, args=(server_url2,))
task1.start()
task2.start()
task1.join()
task2.join()
with DatabaseMapping(db_url) as db_map:
alternatives = db_map.find_alternatives()
assert {alt["name"] for alt in alternatives} == {"Base", "visited", "visited again"}
Loading