From 2578e6bb2cb7183cb1777dcb9a7538c02ed82cf8 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 6 Mar 2026 09:30:28 +0200 Subject: [PATCH] Removed unused fields from exceptions; added typing --- spinedb_api/exception.py | 9 +- spinedb_api/import_mapping/generator.py | 139 ++++++++++-------- spinedb_api/import_mapping/import_mapping.py | 6 +- .../import_mapping/import_mapping_compat.py | 33 +++-- spinedb_api/import_mapping/type_conversion.py | 45 ++++-- 5 files changed, 131 insertions(+), 101 deletions(-) diff --git a/spinedb_api/exception.py b/spinedb_api/exception.py index 9855de9c..96b72963 100644 --- a/spinedb_api/exception.py +++ b/spinedb_api/exception.py @@ -18,7 +18,7 @@ class SpineDBAPIError(Exception): """Basic exception for errors raised by the API.""" - def __init__(self, msg=None): + def __init__(self, msg: str | None = None): super().__init__(msg) self.msg = msg @@ -34,7 +34,7 @@ class SpineIntegrityError(SpineDBAPIError): id (int): the id the instance that caused a unique violation """ - def __init__(self, msg=None, id_=None): + def __init__(self, msg: str | None = None, id_=None): super().__init__(msg) self.id = id_ @@ -72,15 +72,14 @@ def __init__(self, msg): class InvalidMappingComponent(InvalidMapping): - def __init__(self, msg, rank=None, key=None): + def __init__(self, msg: str, rank: int | None = None): super().__init__(msg) self.rank = rank - self.key = key def __eq__(self, other): if not isinstance(other, InvalidMappingComponent): return NotImplemented - return self.msg == other.msg and self.rank == other.rank and self.key == other.key + return self.msg == other.msg and self.rank == other.rank class ReaderError(SpineDBAPIError): diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index 1abe6702..ad8e177d 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -14,10 +14,10 @@ Contains `get_mapped_data()` that converts rows of tabular data into a dictionary for import to a Spine DB, using ``import_functions.import_data()`` """ -from collections.abc import Callable +from collections.abc import Callable, Iterable from copy import deepcopy from itertools import dropwhile -from typing import Any, Optional +from typing import Any, Optional, TypeVar from ..exception import ParameterValueFormatError from ..helpers import string_to_bool from ..import_functions import UnparseCallable @@ -43,48 +43,50 @@ check_validity, ) from .import_mapping_compat import import_mapping_from_dict +from .type_conversion import ConvertSpec _NO_VALUE = object() +T = TypeVar("T") -def identity(x): + +def identity(x: T) -> T: """Returns argument unchanged. Args: - x (Any): value to return + x : value to return Returns: - Any: x + x """ return x def get_mapped_data( - data_source, - mappings, - data_header=None, - table_name="", - column_convert_fns=None, - default_column_convert_fn=None, - row_convert_fns=None, - unparse_value=identity, - mapping_names=None, -): + data_source: Iterable[list], + mappings: list[ImportMapping | list | dict], + data_header: list | None = None, + table_name: str = "", + column_convert_fns: dict[int, ConvertSpec] | None = None, + default_column_convert_fn: ConvertSpec | Callable[[Any], Any] | None = None, + row_convert_fns: ConvertSpec | Callable[[Any], Any] | None = None, + unparse_value: Callable[[Any], Any] = identity, + mapping_names: list[str] | None = None, +) -> tuple[dict, list[str]]: """ Args: - data_source (Iterable): Yields rows (lists) - mappings (list(ImportMapping)): Mappings from data rows into mapped data for ``import_data()`` - data_header (list, optional): table header - table_name (str, optional): table name - column_convert_fns (dict(int,function), optional): mapping from column number to convert function - default_column_convert_fn (Callable, optional): default convert function for surplus columns - row_convert_fns (dict(int,function), optional): mapping from row number to convert function - unparse_value (Callable): a callable that converts values to database format - mapping_names (list, optional): list of mapping names (order corresponds to order of mappings). + data_source: Yields rows (lists) + mappings: Mappings from data rows into mapped data for ``import_data()`` + data_header: table header + table_name: table name + column_convert_fns: mapping from column number to convert function + default_column_convert_fn: default convert function for surplus columns + row_convert_fns: mapping from row number to convert function + unparse_value: a callable that converts values to database format + mapping_names: list of mapping names (order corresponds to order of mappings). Returns: - dict: Mapped data, ready for ``import_data()`` - list: Conversion errors + Mapped data, ready for ``import_data()`` and conversion errors """ # Sanitize mappings for k, mapping in enumerate(mappings): @@ -199,11 +201,17 @@ def _last_column_convert_function(functions: Optional[dict]) -> Callable[[Any], return functions[max(functions)] if functions else identity -def _is_valid_row(row): +def _is_valid_row(row: list | None) -> bool: return row is not None and not all(i is None for i in row) -def _convert_row(row, convert_fns, row_number, errors, default_convert_fn=lambda x: x): +def _convert_row( + row: list, + convert_fns: dict[int, ConvertSpec], + row_number: int, + errors: list[str], + default_convert_fn: ConvertSpec | Callable[[Any], Any] = identity, +) -> list: new_row = [] for j, item in enumerate(row): if item is None: @@ -219,17 +227,17 @@ def _convert_row(row, convert_fns, row_number, errors, default_convert_fn=lambda return new_row -def _split_mapping(mapping): +def _split_mapping( + mapping: ImportMapping, +) -> tuple[list[ImportMapping], list[ImportMapping], list[ImportMapping], ImportMapping]: """Splits the given mapping into pivot components. Args: - mapping (ImportMapping) + mapping: mapping to split Returns: - list(ImportMapping): Pivoted mappings (reading from rows) - list(ImportMapping): Non-pivoted mappings ('regular', reading from columns) - list(ImportMapping): Pivoted from header mappings - ImportMapping: last mapping (typically representing the parameter value) + Pivoted mappings (reading from rows), non-pivoted mappings ('regular', reading from columns), + pivoted from header mappings and last mapping (typically representing the parameter value) """ flattened = mapping.flatten() pivoted = [] @@ -252,25 +260,30 @@ def _split_mapping(mapping): def _unpivot_rows( - rows, data_header, pivoted, non_pivoted, pivoted_from_header, skip_columns, read_start_row, pivoted_by_leaf -): + rows: list[list], + data_header: list[str], + pivoted: list[ImportMapping], + non_pivoted: list[ImportMapping], + pivoted_from_header: list[ImportMapping], + skip_columns: list[int], + read_start_row: int, + pivoted_by_leaf: bool, +) -> tuple[list[list], list[int], list[Position | int | None], list[int]]: """Unpivots rows. Args: - rows (list of list): Source table rows - data_header (list): Source table header - pivoted (list of ImportMapping): Pivoted mappings (reading from rows) - non_pivoted (list of ImportMapping): Non-pivoted mappings ('regular', reading from columns) - pivoted_from_header (list of ImportMapping): Mappings pivoted from header - skip_columns (list of int): columns that should be skipped - read_start_row (int): first row to include - pivoted_by_leaf (bool): whether only the leaf mapping is pivoted + rows: Source table rows + data_header: Source table header + pivoted: Pivoted mappings (reading from rows) + non_pivoted: Non-pivoted mappings ('regular', reading from columns) + pivoted_from_header: Mappings pivoted from header + skip_columns: columns that should be skipped + read_start_row: first row to include + pivoted_by_leaf: whether only the leaf mapping is pivoted Returns: - list of list: Unpivoted rows - int: Position of last pivoted row - int: Position of last non-pivoted row - list of int: Columns positions corresponding to unpivoted rows + Unpivoted rows, positions of pivoted rows, positions of non-pivoted rows + and column positions corresponding to unpivoted rows """ # First we collect pivoted and unpivoted positions pivoted_pos = [-(m.position + 1) for m in pivoted] # (-1) -> (0), (-2) -> (1), (-3) -> (2), etc. @@ -307,7 +320,7 @@ def _unpivot_rows( return unpivoted_rows, pivoted_pos, non_pivoted_pos, unpivoted_column_pos -def _make_entity_classes(mapped_data: dict) -> None: +def _make_entity_classes(mapped_data: SemiMappedData) -> None: try: rows = mapped_data.pop("entity_classes") except KeyError: @@ -322,7 +335,7 @@ def _make_entity_classes(mapped_data: dict) -> None: mapped_data["entity_classes"] = final_rows -def _make_entities(mapped_data): +def _make_entities(mapped_data: SemiMappedData) -> None: try: rows = mapped_data.pop("entities") except KeyError: @@ -337,7 +350,7 @@ def _make_entities(mapped_data): mapped_data["entities"] = final_rows -def _make_entity_alternatives(mapped_data, errors): +def _make_entity_alternatives(mapped_data: SemiMappedData, errors: list[str]) -> None: if "entity_alternatives" not in mapped_data: return rows = [] @@ -382,7 +395,7 @@ def _make_parameter_definitions(mapped_data: SemiMappedData, unparse_value: Unpa mapped_data[key] = final_rows -def _make_parameter_values(mapped_data, unparse_value): +def _make_parameter_values(mapped_data: SemiMappedData, unparse_value: UnparseCallable) -> None: key = "parameter_values" try: rows = mapped_data.pop(key) @@ -410,14 +423,14 @@ def _make_parameter_values(mapped_data, unparse_value): mapped_data[key] = final_rows -def _make_parameter_value_metadata(mapped_data): +def _make_parameter_value_metadata(mapped_data: SemiMappedData) -> None: rows = mapped_data.get("parameter_value_metadata") if rows is None: return mapped_data["parameter_value_metadata"] = list(rows) -def _make_entity_metadata(mapped_data): +def _make_entity_metadata(mapped_data: SemiMappedData) -> None: rows = mapped_data.get("entity_metadata") if rows is None: return @@ -448,7 +461,7 @@ def _make_value(record: ValueRecord) -> IndexedValue: raise RuntimeError(f"logic error: unknown record type '{type(record).__name__}'") -def _table_to_map(table, compress=False): +def _table_to_map(table: Iterable[list], compress: bool = False) -> IndexedValue: d = _table_to_dict(table) m = _dict_to_map_recursive(d) if compress: @@ -456,7 +469,7 @@ def _table_to_map(table, compress=False): return m -def _table_to_dict(table): +def _table_to_dict(table: Iterable[list]) -> dict: map_dict = {} for row in table: row = [item for item in row if item not in (None, "")] @@ -469,7 +482,7 @@ def _table_to_dict(table): return map_dict -def _dict_to_map_recursive(d): +def _dict_to_map_recursive(d: dict) -> Map: indexes = [] values = [] for key, value in d.items(): @@ -480,19 +493,19 @@ def _dict_to_map_recursive(d): return Map(indexes, values) -def _apply_index_names(map_, index_names): - """Applies index names to Map. +def _apply_index_names(indexed_value: IndexedValue, index_names: list[str]) -> None: + """Applies index names to indexed value. Args: - map_ (Map): target Map. - index_names (Sequence of str): index names, one for each Map depth + indexed_value: target value. + index_names: index names, one for each index depth """ name = index_names[0] if name: - map_.index_name = name + indexed_value.index_name = name if len(index_names) == 1: return - for v in map_.values: + for v in indexed_value.values: if isinstance(v, Map): _apply_index_names(v, index_names[1:]) diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index bd58f3e6..827dcd32 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -1282,15 +1282,15 @@ def _default_parameter_value_metadata_mapping() -> EntityClassMapping: return unflatten(mappings) -def from_dict(serialized): +def from_dict(serialized: list[dict]) -> ImportMapping: """ Deserializes mappings. Args: - serialized (list): serialize mappings + serialized: serialize mappings Returns: - Mapping: root mapping + root mapping """ mappings = { klass.MAP_TYPE: klass diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index a629e7e1..329e9b2d 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -22,6 +22,7 @@ EntityMapping, ExpandedParameterDefaultValueMapping, ExpandedParameterValueMapping, + ImportMapping, IndexNameMapping, ParameterDefaultValueIndexMapping, ParameterDefaultValueMapping, @@ -41,7 +42,7 @@ from .import_mapping import from_dict as mapping_from_dict -def parse_named_mapping_spec(named_mapping_spec): +def parse_named_mapping_spec(named_mapping_spec: dict) -> tuple[str, ImportMapping]: if len(named_mapping_spec) == 1: name, mapping_spec = next(iter(named_mapping_spec.items())) mapping = mapping_spec["mapping"] @@ -52,11 +53,11 @@ def parse_named_mapping_spec(named_mapping_spec): return name, import_mapping_from_dict(mapping) -def unparse_named_mapping_spec(name, root_mapping): +def unparse_named_mapping_spec(name: str, root_mapping: ImportMapping) -> dict[str, dict[str, list[dict]]]: return {name: {"mapping": import_mapping_to_dict(root_mapping)}} -def import_mapping_from_dict(map_dict): +def import_mapping_from_dict(map_dict: list | dict) -> ImportMapping: """Creates Mapping object from a dict""" if isinstance(map_dict, list): # New system, flattened mapping as list @@ -84,7 +85,7 @@ def import_mapping_from_dict(map_dict): ) -def _parameter_value_list_mapping_from_dict(map_dict): +def _parameter_value_list_mapping_from_dict(map_dict: dict) -> ParameterValueListMapping: name = map_dict.get("name") value = map_dict.get("value") skip_columns = map_dict.get("skip_columns", []) @@ -96,7 +97,7 @@ def _parameter_value_list_mapping_from_dict(map_dict): return root_mapping -def _alternative_mapping_from_dict(map_dict): +def _alternative_mapping_from_dict(map_dict: dict) -> AlternativeMapping: name = map_dict.get("name") skip_columns = map_dict.get("skip_columns", []) read_start_row = map_dict.get("read_start_row", 0) @@ -104,7 +105,7 @@ def _alternative_mapping_from_dict(map_dict): return root_mapping -def _scenario_mapping_from_dict(map_dict): +def _scenario_mapping_from_dict(map_dict: dict) -> ScenarioMapping: name = map_dict.get("name") skip_columns = map_dict.get("skip_columns", []) read_start_row = map_dict.get("read_start_row", 0) @@ -112,7 +113,7 @@ def _scenario_mapping_from_dict(map_dict): return root_mapping -def _scenario_alternative_mapping_from_dict(map_dict): +def _scenario_alternative_mapping_from_dict(map_dict: dict) -> ScenarioMapping: scenario_name = map_dict.get("scenario_name") alternative_name = map_dict.get("alternative_name") before_alternative_name = map_dict.get("before_alternative_name") @@ -126,7 +127,7 @@ def _scenario_alternative_mapping_from_dict(map_dict): return root_mapping -def _object_class_mapping_from_dict(map_dict): +def _object_class_mapping_from_dict(map_dict: dict) -> EntityClassMapping: name = map_dict.get("name") entities = map_dict.get("objects", map_dict.get("object")) parameters = map_dict.get("parameters") @@ -138,7 +139,7 @@ def _object_class_mapping_from_dict(map_dict): return root_mapping -def _object_group_mapping_from_dict(map_dict): +def _object_group_mapping_from_dict(map_dict: dict) -> EntityClassMapping: name = map_dict.get("name") groups = map_dict.get("groups") members = map_dict.get("members") @@ -151,7 +152,7 @@ def _object_group_mapping_from_dict(map_dict): return root_mapping -def _relationship_class_mapping_from_dict(map_dict): +def _relationship_class_mapping_from_dict(map_dict: dict) -> EntityClassMapping: name = map_dict.get("name") objects = map_dict.get("objects") if objects is None: @@ -179,7 +180,7 @@ def _relationship_class_mapping_from_dict(map_dict): return root_mapping -def parameter_mapping_from_dict(map_dict): +def parameter_mapping_from_dict(map_dict: dict | None) -> ParameterDefinitionMapping | AlternativeMapping | None: if map_dict is None: return None map_type = map_dict.get("map_type") @@ -203,7 +204,9 @@ def parameter_mapping_from_dict(map_dict): return alt_mapping -def parameter_default_value_mapping_from_dict(default_value_dict): +def parameter_default_value_mapping_from_dict( + default_value_dict: dict | None, +) -> ParameterDefaultValueMapping | ParameterDefaultValueTypeMapping: if default_value_dict is None: return ParameterDefaultValueMapping(*_pos_and_val(None)) value_type = default_value_dict["value_type"].replace(" ", "_") @@ -225,7 +228,7 @@ def parameter_default_value_mapping_from_dict(default_value_dict): return root_mapping -def parameter_value_mapping_from_dict(value_dict): +def parameter_value_mapping_from_dict(value_dict: dict | None) -> ParameterValueMapping | ParameterValueTypeMapping: if value_dict is None: return ParameterValueMapping(*_pos_and_val(None)) value_type = value_dict["value_type"].replace(" ", "_") @@ -247,7 +250,7 @@ def parameter_value_mapping_from_dict(value_dict): return root_mapping -def _fix_parameter_mapping_dict(map_dict): +def _fix_parameter_mapping_dict(map_dict: dict) -> None: # Even deeper legacy parameter_type = map_dict.pop("parameter_type", None) if parameter_type == "definition": @@ -261,7 +264,7 @@ def _fix_parameter_mapping_dict(map_dict): map_dict["value"] = value_dict -def _pos_and_val(x): +def _pos_and_val(x: dict | str | int | None) -> tuple[Position | int, str | int | None]: if not isinstance(x, dict): map_type = "constant" if isinstance(x, str) else "column" map_dict = {"map_type": map_type, "reference": x} diff --git a/spinedb_api/import_mapping/type_conversion.py b/spinedb_api/import_mapping/type_conversion.py index daf48742..15f6c78d 100644 --- a/spinedb_api/import_mapping/type_conversion.py +++ b/spinedb_api/import_mapping/type_conversion.py @@ -11,13 +11,26 @@ ###################################################################################################################### """ Type conversion functions. """ - +from __future__ import annotations +from collections.abc import Callable +from datetime import datetime import re +from typing import Any, ClassVar, Generic, Literal, TypeAlias, TypedDict, TypeVar +from dateutil.relativedelta import relativedelta +from typing_extensions import NotRequired from spinedb_api.helpers import string_to_bool from spinedb_api.parameter_value import DateTime, Duration, ParameterValueFormatError -def value_to_convert_spec(value): +class ConvertSpecDict(TypedDict): + name: str + start_datetime: NotRequired[str] + duration: NotRequired[str] + start_int: NotRequired[int] + +ConvertSpecValue: TypeAlias = Literal["datetime", "duration", "float", "string", "boolean"] + +def value_to_convert_spec(value: ConvertSpec | ConvertSpecValue | ConvertSpecDict): if isinstance(value, ConvertSpec): return value if isinstance(value, str): @@ -27,7 +40,7 @@ def value_to_convert_spec(value): "float": FloatConvertSpec, "string": StringConvertSpec, "boolean": BooleanConvertSpec, - }.get(value) + }[value] return spec() if isinstance(value, dict): start_datetime = DateTime(value.get("start_datetime")) @@ -37,11 +50,13 @@ def value_to_convert_spec(value): raise TypeError(f"value must be str or dict instead got {type(value).__name__}") -class ConvertSpec: - DISPLAY_NAME = "" - RETURN_TYPE = str +T = TypeVar("T") - def __call__(self, value): +class ConvertSpec(Generic[T]): + DISPLAY_NAME: ClassVar[str] = NotImplemented + RETURN_TYPE: Callable[[Any], T] = NotImplemented + + def __call__(self, value: Any) -> T | None: try: return self.RETURN_TYPE(value) except ValueError as error: @@ -49,31 +64,31 @@ def __call__(self, value): return None raise error - def to_json_value(self): + def to_json_value(self) -> str | ConvertSpecDict: return self.DISPLAY_NAME -class DateTimeConvertSpec(ConvertSpec): +class DateTimeConvertSpec(ConvertSpec[DateTime]): DISPLAY_NAME = "datetime" RETURN_TYPE = DateTime -class DurationConvertSpec(ConvertSpec): +class DurationConvertSpec(ConvertSpec[Duration]): DISPLAY_NAME = "duration" RETURN_TYPE = Duration -class FloatConvertSpec(ConvertSpec): +class FloatConvertSpec(ConvertSpec[float]): DISPLAY_NAME = "float" RETURN_TYPE = float -class StringConvertSpec(ConvertSpec): +class StringConvertSpec(ConvertSpec[str]): DISPLAY_NAME = "string" RETURN_TYPE = str -class BooleanConvertSpec(ConvertSpec): +class BooleanConvertSpec(ConvertSpec[bool]): DISPLAY_NAME = "boolean" RETURN_TYPE = bool @@ -81,11 +96,11 @@ def __call__(self, value): return self.RETURN_TYPE(string_to_bool(str(value))) -class IntegerSequenceDateTimeConvertSpec(ConvertSpec): +class IntegerSequenceDateTimeConvertSpec(ConvertSpec[DateTime]): DISPLAY_NAME = "integer sequence datetime" RETURN_TYPE = DateTime - def __init__(self, start_datetime, start_int, duration): + def __init__(self, start_datetime: str | DateTime | datetime, start_int: int, duration: str | relativedelta | Duration): if not isinstance(start_datetime, DateTime): start_datetime = DateTime(start_datetime) if not isinstance(duration, Duration):