diff --git a/skops/io/_general.py b/skops/io/_general.py index b5685c3f..bb5cac0e 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -180,13 +180,9 @@ def isnamedtuple(self, t) -> bool: def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { - "__class__": obj.__class__.__name__, + "__class__": obj.__name__, "__module__": get_module(obj), "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, } return res @@ -201,26 +197,20 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) - self.children = {"content": state["content"]} + self.children = {} def _construct(self): - return _import_obj( - self.children["content"]["module_path"], - self.children["content"]["function"], - ) + return gettype(self.module_name, self.class_name) def _get_function_name(self) -> str: - return ( - self.children["content"]["module_path"] - + "." - + self.children["content"]["function"] - ) + return f"{self.module_name}.{self.class_name}" def get_unsafe_set(self) -> set[str]: - if (self.trusted is True) or (self._get_function_name() in self.trusted): + fn_name = self._get_function_name() + if (self.trusted is True) or (fn_name in self.trusted): return set() - return {self._get_function_name()} + return {fn_name} def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index e7bc40a9..7c19cc96 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -6,6 +6,7 @@ import numpy as np from ._audit import Node, get_tree +from ._general import function_get_state from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -195,22 +196,6 @@ def _construct(self): return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) -# For numpy.ufunc we need to get the type from the type's module, but for other -# functions we get it from objet's module directly. Therefore sett a especial -# get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: - res = { - "__class__": obj.__class__.__name__, # ufunc - "__module__": get_module(type(obj)), # numpy - "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, - } - return res - - def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. @@ -247,7 +232,7 @@ def _construct(self): (np.generic, ndarray_get_state), (np.ndarray, ndarray_get_state), (np.ma.MaskedArray, maskedarray_get_state), - (np.ufunc, ufunc_get_state), + (np.ufunc, function_get_state), (np.dtype, dtype_get_state), (np.random.RandomState, random_state_get_state), (np.random.Generator, random_generator_get_state),