Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,9 @@ def isnamedtuple(self, t) -> bool:

def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__class__": obj.__name__,
"__module__": get_module(obj),
"__loader__": "FunctionNode",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
},
}
return res

Expand All @@ -201,26 +197,20 @@ def __init__(
super().__init__(state, load_context, trusted)
# TODO: what do we trust?
self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES)
self.children = {"content": state["content"]}
self.children = {}

def _construct(self):
return _import_obj(
self.children["content"]["module_path"],
self.children["content"]["function"],
)
return gettype(self.module_name, self.class_name)

def _get_function_name(self) -> str:
return (
self.children["content"]["module_path"]
+ "."
+ self.children["content"]["function"]
)
return f"{self.module_name}.{self.class_name}"

def get_unsafe_set(self) -> set[str]:
if (self.trusted is True) or (self._get_function_name() in self.trusted):
fn_name = self._get_function_name()
if (self.trusted is True) or (fn_name in self.trusted):
return set()

return {self._get_function_name()}
return {fn_name}


def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
Expand Down
19 changes: 2 additions & 17 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

from ._audit import Node, get_tree
from ._general import function_get_state
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

Expand Down Expand Up @@ -195,22 +196,6 @@ def _construct(self):
return gettype(self.module_name, self.class_name)(bit_generator=bit_generator)


# For numpy.ufunc we need to get the type from the type's module, but for other
# functions we get it from objet's module directly. Therefore sett a especial
# get_state method for them here. The load is the same as other functions.
def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__, # ufunc
"__module__": get_module(type(obj)), # numpy
"__loader__": "FunctionNode",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
},
}
return res


def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# we use numpy's internal save mechanism to store the dtype by
# saving/loading an empty array with that dtype.
Expand Down Expand Up @@ -247,7 +232,7 @@ def _construct(self):
(np.generic, ndarray_get_state),
(np.ndarray, ndarray_get_state),
(np.ma.MaskedArray, maskedarray_get_state),
(np.ufunc, ufunc_get_state),
(np.ufunc, function_get_state),
(np.dtype, dtype_get_state),
(np.random.RandomState, random_state_get_state),
(np.random.Generator, random_generator_get_state),
Expand Down