diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6d7bca62..f93d28fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
+- Added a context manager `spinedb_api.spine_db_client.lock_db` for locking database mappings
+ under DB server, i.e. when executing Python scripts in Spine Toolbox.
+
### Changed
### Deprecated
diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst
index a4be9436..422b704b 100644
--- a/docs/source/user_guide.rst
+++ b/docs/source/user_guide.rst
@@ -432,6 +432,22 @@ This structure has some specialized uses in Spine Toolbox and can usually be ign
:meth:`.DatabaseMapping.commit_session` raises :class:`NothingToCommit` when there are no changes to save.
Other errors raise :class:`SpineDBAPIError`.
+Using database mappings in Spine Toolbox
+----------------------------------------
+
+When a Python script is executed as part of a Spine Toolbox workflow,
+the same script may run in multiple parallel processes at the same.
+Using a database mapping as read-only access to a database should not cause any issues in such environment.
+However, there is a high chance for conflicts, or even corrupted in-memory data if data is committed to the database.
+The database should be explicitly locked to prevent this from happening::
+
+ url = sys.argv[1] # Url provided by Spine Toolbox
+ with api.DatabaseMapping(url) as db_map:
+ with api.spine_db_client.lock_db(db_map):
+ # Use db_map here.
+ db_map.commit("Updated things.")
+
+
Performance
-----------
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 4bd00ac9..d6aeb4fb 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -173,6 +173,7 @@ def __init__(
"""
super().__init__()
# FIXME: We should also check the server memory property and use it here
+ self.server_url: str = db_url if isinstance(db_url, str) else db_url.render_as_string(hide_password=False)
db_url = get_db_url_from_server(db_url)
self.db_url = str(db_url)
if isinstance(db_url, str):
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index de654e49..9bf1c2e7 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -1066,7 +1066,7 @@ def cascade_remove(self, source: Optional[object] = None) -> None:
elif self.status in (Status.committed, Status.to_update):
self.status = Status.to_remove
else:
- raise RuntimeError("invalid status for item being removed")
+ raise RuntimeError(f"invalid status '{self.status}' for item being removed")
self._removal_source = source
self._removed = True
self._valid = None
diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py
index 3d9e4107..61fb81a6 100644
--- a/spinedb_api/spine_db_client.py
+++ b/spinedb_api/spine_db_client.py
@@ -10,15 +10,19 @@
# this program. If not, see .
######################################################################################################################
-"""
-This module defines the :class:`SpineDBClient` class.
-"""
-
+"""This module defines the :class:`SpineDBClient` class."""
+from __future__ import annotations
+from collections.abc import Iterator
+from contextlib import contextmanager
import socket
+from typing import TYPE_CHECKING
from urllib.parse import urlparse
from sqlalchemy.engine.url import URL
from .server_client_helpers import ReceiveAllMixing, decode, encode
+if TYPE_CHECKING:
+ from .db_mapping import DatabaseMapping
+
client_version = 8
@@ -107,6 +111,12 @@ def call_method(self, method_name, *args, **kwargs):
def query(self, query_name: str, *args, **kwargs) -> dict:
return self._send("query", args=(query_name, *args), kwargs=kwargs)
+ def acquire_lock(self) -> None:
+ return self._send("acquire_lock")
+
+ def release_lock(self) -> None:
+ return self._send("release_lock")
+
def _send(self, request, args=None, kwargs=None, receive=True):
"""
Sends a request to the server with the given arguments.
@@ -138,3 +148,17 @@ def get_db_url_from_server(url):
if parsed.scheme != "http":
return url
return SpineDBClient((parsed.hostname, parsed.port)).get_db_url()
+
+
+@contextmanager
+def lock_db(db_map: DatabaseMapping) -> Iterator[None]:
+ url = urlparse(db_map.server_url)
+ if url.scheme != "http":
+ yield
+ return
+ client = SpineDBClient((url.hostname, url.port))
+ client.acquire_lock()
+ try:
+ yield
+ finally:
+ client.release_lock()
diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py
index 6e3ddd0f..9edde72e 100644
--- a/spinedb_api/spine_db_server.py
+++ b/spinedb_api/spine_db_server.py
@@ -106,7 +106,6 @@ def _import_entity_class(server_url, class_name):
import traceback
from typing import ClassVar, Literal, Optional, TypedDict
from urllib.parse import urlunsplit
-import uuid
from sqlalchemy.exc import DBAPIError
from spinedb_api import __version__ as spinedb_api_version
from .db_mapping import DatabaseMapping
@@ -117,7 +116,6 @@ def _import_entity_class(server_url, class_name):
from .import_functions import import_data
from .parameter_value import dump_db_value
from .server_client_helpers import ReceiveAllMixing, decode, encode
-from .spine_db_client import SpineDBClient
_current_server_version = 8
@@ -176,7 +174,7 @@ def __init__(self):
self._process = mp.Process(target=self._do_work)
self._process.start()
- def _get_commit_lock(self, db_url):
+ def _get_commit_lock(self, db_url: str) -> threading.Lock:
clean_url = clear_filter_configs(db_url)
return self._commit_locks.setdefault(clean_url, threading.Lock())
@@ -349,9 +347,10 @@ def __init__(self, db_url, upgrade, memory, commit_lock, manager_queue, ordering
self._db_map = None
self._closed = False
self._lock = threading.Lock()
+ self._commit_lock = commit_lock
self._in_queue = Queue()
self._out_queue = Queue()
- self._thread = threading.Thread(target=lambda: self._do_work(db_url, upgrade, memory, commit_lock))
+ self._thread = threading.Thread(target=lambda: self._do_work(db_url, upgrade, memory))
self._thread.start()
error = self._out_queue.get()
if isinstance(error, Exception):
@@ -361,6 +360,10 @@ def __init__(self, db_url, upgrade, memory, commit_lock, manager_queue, ordering
def db_url(self):
return str(self._db_map.db_url)
+ @property
+ def commit_lock(self):
+ return self._commit_lock
+
def close_db_map(self):
if not self._closed:
self._closed = True
@@ -368,9 +371,11 @@ def close_db_map(self):
self._in_queue.put(self._CLOSE)
self._thread.join()
- def _do_work(self, db_url, upgrade, memory, commit_lock):
+ def _do_work(self, db_url, upgrade, memory):
try:
- self._db_map = DatabaseMapping(db_url, upgrade=upgrade, memory=memory, commit_lock=commit_lock, create=True)
+ self._db_map = DatabaseMapping(
+ db_url, upgrade=upgrade, memory=memory, commit_lock=self._commit_lock, create=True
+ )
self._out_queue.put(None)
except Exception as error: # pylint: disable=broad-except
self._out_queue.put(error)
@@ -553,6 +558,14 @@ def cancel_db_checkout(self):
cancel_db_checkout(self.server_manager_queue, self.server_address)
return {"result": True}
+ def acquire_lock(self) -> dict:
+ self.server.commit_lock.acquire()
+ return {"result": True}
+
+ def release_lock(self) -> dict:
+ self.server.commit_lock.release()
+ return {"result": True}
+
def _get_response(self, request):
request, *extras = decode(request)
# NOTE: Clients should always send requests "get_api_version" and "get_db_url" in a format that is compatible
@@ -578,6 +591,8 @@ def _get_response(self, request):
"db_checkin": self.db_checkin,
"db_checkout": self.db_checkout,
"cancel_db_checkout": self.cancel_db_checkout,
+ "acquire_lock": self.acquire_lock,
+ "release_lock": self.release_lock,
}.get(request)
if handler is None:
return {"error": f"invalid request '{request}'"}
diff --git a/tests/test_spine_db_client.py b/tests/test_spine_db_client.py
new file mode 100644
index 00000000..227b7a2e
--- /dev/null
+++ b/tests/test_spine_db_client.py
@@ -0,0 +1,59 @@
+######################################################################################################################
+# Copyright (C) 2017-2022 Spine project consortium
+# Copyright Spine Database API contributors
+# This file is part of Spine Database API.
+# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
+# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
+# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
+# this program. If not, see .
+######################################################################################################################
+import multiprocessing
+import pytest
+from spinedb_api import DatabaseMapping, create_new_spine_database
+from spinedb_api.spine_db_client import lock_db
+from spinedb_api.spine_db_server import closing_spine_db_server, db_server_manager
+
+
+@pytest.fixture
+def db_url(tmp_path):
+ url = "sqlite:///" + str(tmp_path / "db.sqlite")
+ create_new_spine_database(url)
+ return url
+
+
+def _do_work(url):
+ with DatabaseMapping(url) as db_map:
+ with lock_db(db_map) as lock:
+ assert lock is None
+ alternatives = db_map.find_alternatives()
+ if len(alternatives) == 1:
+ db_map.add_alternative(name="visited")
+ db_map.commit_session("Added first alternative.")
+ else:
+ db_map.add_alternative(name="visited again")
+ db_map.commit_session("Added second alternative.")
+
+
+class TestLockDB:
+ def test_locking_is_no_operation_when_no_server_is_used(self):
+ with DatabaseMapping("sqlite://", create=True) as db_map:
+ with lock_db(db_map) as lock:
+ assert lock is None
+
+ def test_locking_with_server(self, db_url):
+ with db_server_manager() as manager_queue:
+ with (
+ closing_spine_db_server(db_url, server_manager_queue=manager_queue) as server_url1,
+ closing_spine_db_server(db_url, server_manager_queue=manager_queue) as server_url2,
+ ):
+ task1 = multiprocessing.Process(target=_do_work, args=(server_url1,))
+ task2 = multiprocessing.Process(target=_do_work, args=(server_url2,))
+ task1.start()
+ task2.start()
+ task1.join()
+ task2.join()
+ with DatabaseMapping(db_url) as db_map:
+ alternatives = db_map.find_alternatives()
+ assert {alt["name"] for alt in alternatives} == {"Base", "visited", "visited again"}