Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions spinedb_api/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class SpineDBAPIError(Exception):
"""Basic exception for errors raised by the API."""

def __init__(self, msg=None):
def __init__(self, msg: str | None = None):
super().__init__(msg)
self.msg = msg

Expand All @@ -34,7 +34,7 @@ class SpineIntegrityError(SpineDBAPIError):
id (int): the id the instance that caused a unique violation
"""

def __init__(self, msg=None, id_=None):
def __init__(self, msg: str | None = None, id_=None):
super().__init__(msg)
self.id = id_

Expand Down Expand Up @@ -72,15 +72,14 @@ def __init__(self, msg):


class InvalidMappingComponent(InvalidMapping):
    """Error raised for a single invalid mapping component.

    Carries an optional ``rank`` identifying the component's position in the
    flattened mapping, used when comparing errors for equality.
    """

    def __init__(self, msg: str, rank: int | None = None):
        """
        Args:
            msg: error message
            rank: position of the offending component, if known
        """
        super().__init__(msg)
        self.rank = rank

    def __eq__(self, other):
        # Only comparable with other InvalidMappingComponent instances;
        # defer to the other operand otherwise.
        if not isinstance(other, InvalidMappingComponent):
            return NotImplemented
        return (self.msg, self.rank) == (other.msg, other.rank)


class ReaderError(SpineDBAPIError):
Expand Down
139 changes: 76 additions & 63 deletions spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
Contains `get_mapped_data()` that converts rows of tabular data into a dictionary for import to a Spine DB,
using ``import_functions.import_data()``
"""
from collections.abc import Callable
from collections.abc import Callable, Iterable
from copy import deepcopy
from itertools import dropwhile
from typing import Any, Optional
from typing import Any, Optional, TypeVar
from ..exception import ParameterValueFormatError
from ..helpers import string_to_bool
from ..import_functions import UnparseCallable
Expand All @@ -43,48 +43,50 @@
check_validity,
)
from .import_mapping_compat import import_mapping_from_dict
from .type_conversion import ConvertSpec

# Sentinel distinguishing "no value given" from an explicit None.
_NO_VALUE = object()

T = TypeVar("T")


def identity(x: T) -> T:
    """Returns the given value unchanged.

    Args:
        x: value to return

    Returns:
        the same object *x*
    """
    return x


def get_mapped_data(
data_source,
mappings,
data_header=None,
table_name="",
column_convert_fns=None,
default_column_convert_fn=None,
row_convert_fns=None,
unparse_value=identity,
mapping_names=None,
):
data_source: Iterable[list],
mappings: list[ImportMapping | list | dict],
data_header: list | None = None,
table_name: str = "",
column_convert_fns: dict[int, ConvertSpec] | None = None,
default_column_convert_fn: ConvertSpec | Callable[[Any], Any] | None = None,
row_convert_fns: ConvertSpec | Callable[[Any], Any] | None = None,
unparse_value: Callable[[Any], Any] = identity,
mapping_names: list[str] | None = None,
) -> tuple[dict, list[str]]:
"""
Args:
data_source (Iterable): Yields rows (lists)
mappings (list(ImportMapping)): Mappings from data rows into mapped data for ``import_data()``
data_header (list, optional): table header
table_name (str, optional): table name
column_convert_fns (dict(int,function), optional): mapping from column number to convert function
default_column_convert_fn (Callable, optional): default convert function for surplus columns
row_convert_fns (dict(int,function), optional): mapping from row number to convert function
unparse_value (Callable): a callable that converts values to database format
mapping_names (list, optional): list of mapping names (order corresponds to order of mappings).
data_source: Yields rows (lists)
mappings: Mappings from data rows into mapped data for ``import_data()``
data_header: table header
table_name: table name
column_convert_fns: mapping from column number to convert function
default_column_convert_fn: default convert function for surplus columns
row_convert_fns: mapping from row number to convert function
unparse_value: a callable that converts values to database format
mapping_names: list of mapping names (order corresponds to order of mappings).

Returns:
dict: Mapped data, ready for ``import_data()``
list: Conversion errors
Mapped data, ready for ``import_data()`` and conversion errors
"""
# Sanitize mappings
for k, mapping in enumerate(mappings):
Expand Down Expand Up @@ -199,11 +201,17 @@ def _last_column_convert_function(functions: Optional[dict]) -> Callable[[Any],
return functions[max(functions)] if functions else identity


def _is_valid_row(row):
def _is_valid_row(row: list | None) -> bool:
return row is not None and not all(i is None for i in row)


def _convert_row(row, convert_fns, row_number, errors, default_convert_fn=lambda x: x):
def _convert_row(
row: list,
convert_fns: dict[int, ConvertSpec],
row_number: int,
errors: list[str],
default_convert_fn: ConvertSpec | Callable[[Any], Any] = identity,
) -> list:
new_row = []
for j, item in enumerate(row):
if item is None:
Expand All @@ -219,17 +227,17 @@ def _convert_row(row, convert_fns, row_number, errors, default_convert_fn=lambda
return new_row


def _split_mapping(mapping):
def _split_mapping(
mapping: ImportMapping,
) -> tuple[list[ImportMapping], list[ImportMapping], list[ImportMapping], ImportMapping]:
"""Splits the given mapping into pivot components.

Args:
mapping (ImportMapping)
mapping: mapping to split

Returns:
list(ImportMapping): Pivoted mappings (reading from rows)
list(ImportMapping): Non-pivoted mappings ('regular', reading from columns)
list(ImportMapping): Pivoted from header mappings
ImportMapping: last mapping (typically representing the parameter value)
Pivoted mappings (reading from rows), non-pivoted mappings ('regular', reading from columns),
pivoted from header mappings and last mapping (typically representing the parameter value)
"""
flattened = mapping.flatten()
pivoted = []
Expand All @@ -252,25 +260,30 @@ def _split_mapping(mapping):


def _unpivot_rows(
rows, data_header, pivoted, non_pivoted, pivoted_from_header, skip_columns, read_start_row, pivoted_by_leaf
):
rows: list[list],
data_header: list[str],
pivoted: list[ImportMapping],
non_pivoted: list[ImportMapping],
pivoted_from_header: list[ImportMapping],
skip_columns: list[int],
read_start_row: int,
pivoted_by_leaf: bool,
) -> tuple[list[list], list[int], list[Position | int | None], list[int]]:
"""Unpivots rows.

Args:
rows (list of list): Source table rows
data_header (list): Source table header
pivoted (list of ImportMapping): Pivoted mappings (reading from rows)
non_pivoted (list of ImportMapping): Non-pivoted mappings ('regular', reading from columns)
pivoted_from_header (list of ImportMapping): Mappings pivoted from header
skip_columns (list of int): columns that should be skipped
read_start_row (int): first row to include
pivoted_by_leaf (bool): whether only the leaf mapping is pivoted
rows: Source table rows
data_header: Source table header
pivoted: Pivoted mappings (reading from rows)
non_pivoted: Non-pivoted mappings ('regular', reading from columns)
pivoted_from_header: Mappings pivoted from header
skip_columns: columns that should be skipped
read_start_row: first row to include
pivoted_by_leaf: whether only the leaf mapping is pivoted

Returns:
list of list: Unpivoted rows
int: Position of last pivoted row
int: Position of last non-pivoted row
list of int: Columns positions corresponding to unpivoted rows
Unpivoted rows, positions of pivoted rows, positions of non-pivoted rows
and column positions corresponding to unpivoted rows
"""
# First we collect pivoted and unpivoted positions
pivoted_pos = [-(m.position + 1) for m in pivoted] # (-1) -> (0), (-2) -> (1), (-3) -> (2), etc.
Expand Down Expand Up @@ -307,7 +320,7 @@ def _unpivot_rows(
return unpivoted_rows, pivoted_pos, non_pivoted_pos, unpivoted_column_pos


def _make_entity_classes(mapped_data: dict) -> None:
def _make_entity_classes(mapped_data: SemiMappedData) -> None:
try:
rows = mapped_data.pop("entity_classes")
except KeyError:
Expand All @@ -322,7 +335,7 @@ def _make_entity_classes(mapped_data: dict) -> None:
mapped_data["entity_classes"] = final_rows


def _make_entities(mapped_data):
def _make_entities(mapped_data: SemiMappedData) -> None:
try:
rows = mapped_data.pop("entities")
except KeyError:
Expand All @@ -337,7 +350,7 @@ def _make_entities(mapped_data):
mapped_data["entities"] = final_rows


def _make_entity_alternatives(mapped_data, errors):
def _make_entity_alternatives(mapped_data: SemiMappedData, errors: list[str]) -> None:
if "entity_alternatives" not in mapped_data:
return
rows = []
Expand Down Expand Up @@ -382,7 +395,7 @@ def _make_parameter_definitions(mapped_data: SemiMappedData, unparse_value: Unpa
mapped_data[key] = final_rows


def _make_parameter_values(mapped_data, unparse_value):
def _make_parameter_values(mapped_data: SemiMappedData, unparse_value: UnparseCallable) -> None:
key = "parameter_values"
try:
rows = mapped_data.pop(key)
Expand Down Expand Up @@ -410,14 +423,14 @@ def _make_parameter_values(mapped_data, unparse_value):
mapped_data[key] = final_rows


def _make_parameter_value_metadata(mapped_data: SemiMappedData) -> None:
    """Materializes "parameter_value_metadata" rows into a list, in place.

    Does nothing when the key is absent (or maps to None).

    Args:
        mapped_data: mapped data to modify
    """
    rows = mapped_data.get("parameter_value_metadata")
    if rows is not None:
        mapped_data["parameter_value_metadata"] = list(rows)


def _make_entity_metadata(mapped_data):
def _make_entity_metadata(mapped_data: SemiMappedData) -> None:
rows = mapped_data.get("entity_metadata")
if rows is None:
return
Expand Down Expand Up @@ -448,15 +461,15 @@ def _make_value(record: ValueRecord) -> IndexedValue:
raise RuntimeError(f"logic error: unknown record type '{type(record).__name__}'")


def _table_to_map(table: Iterable[list], compress: bool = False) -> IndexedValue:
    """Converts tabular data into a nested Map.

    Args:
        table: rows of data; the last item of each row is the value,
            the preceding items are nested indexes
        compress: when True, leaf Maps are converted to specialized containers

    Returns:
        the resulting indexed value
    """
    nested = _dict_to_map_recursive(_table_to_dict(table))
    if not compress:
        return nested
    return convert_leaf_maps_to_specialized_containers(nested)


def _table_to_dict(table):
def _table_to_dict(table: Iterable[list]) -> dict:
map_dict = {}
for row in table:
row = [item for item in row if item not in (None, "")]
Expand All @@ -469,7 +482,7 @@ def _table_to_dict(table):
return map_dict


def _dict_to_map_recursive(d):
def _dict_to_map_recursive(d: dict) -> Map:
indexes = []
values = []
for key, value in d.items():
Expand All @@ -480,19 +493,19 @@ def _dict_to_map_recursive(d):
return Map(indexes, values)


def _apply_index_names(indexed_value: IndexedValue, index_names: list[str]) -> None:
    """Recursively assigns index names to an indexed value and its nested Maps.

    Falsy names (e.g. empty strings) leave the corresponding depth untouched.

    Args:
        indexed_value: target value, modified in place
        index_names: index names, one per nesting depth
    """
    name, *deeper_names = index_names
    if name:
        indexed_value.index_name = name
    if not deeper_names:
        return
    for nested in indexed_value.values:
        if isinstance(nested, Map):
            _apply_index_names(nested, deeper_names)

Expand Down
6 changes: 3 additions & 3 deletions spinedb_api/import_mapping/import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,15 +1282,15 @@ def _default_parameter_value_metadata_mapping() -> EntityClassMapping:
return unflatten(mappings)


def from_dict(serialized):
def from_dict(serialized: list[dict]) -> ImportMapping:
"""
Deserializes mappings.

Args:
serialized (list): serialize mappings
serialized: serialize mappings

Returns:
Mapping: root mapping
root mapping
"""
mappings = {
klass.MAP_TYPE: klass
Expand Down
Loading