From e383ad1a86aed3ceb3ded346d229448465fa689b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 12:08:12 +0100 Subject: [PATCH 1/2] Suggestion how to simplify serialization of funcs Description It looks like the state we save for functions contains unnecessary information (the class name, which is not used) and duplicate information (the module name appears twice). This change gets rid of that unnecessary information. This also removes the need for ufunc_get_state, which can now be replaced by function_get_state. Comment Maybe I'm missing something and there is a reason for the previous structure, but if there is, there should be a test to check whatever is covered by that, or at the very least a comment to explain. Btw. can't remember who wrote this, could be my fault :) --- skops/io/_general.py | 24 +++++++----------------- skops/io/_numpy.py | 11 ++++------- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index b5685c3f..bb5cac0e 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -180,13 +180,9 @@ def isnamedtuple(self, t) -> bool: def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { - "__class__": obj.__class__.__name__, + "__class__": obj.__name__, "__module__": get_module(obj), "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, } return res @@ -201,26 +197,20 @@ def __init__( super().__init__(state, load_context, trusted) # TODO: what do we trust? self.trusted = self._get_trusted(trusted, default=SCIPY_UFUNC_TYPE_NAMES) - self.children = {"content": state["content"]} + self.children = {} def _construct(self): - return _import_obj( - self.children["content"]["module_path"], - self.children["content"]["function"], - ) + return gettype(self.module_name, self.class_name) def _get_function_name(self) -> str: - return ( - self.children["content"]["module_path"] - + "." 
- + self.children["content"]["function"] - ) + return f"{self.module_name}.{self.class_name}" def get_unsafe_set(self) -> set[str]: - if (self.trusted is True) or (self._get_function_name() in self.trusted): + fn_name = self._get_function_name() + if (self.trusted is True) or (fn_name in self.trusted): return set() - return {self._get_function_name()} + return {fn_name} def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index e7bc40a9..8f17eb95 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -6,6 +6,7 @@ import numpy as np from ._audit import Node, get_tree +from ._general import function_get_state from ._utils import LoadContext, SaveContext, get_module, get_state, gettype from .exceptions import UnsupportedTypeException @@ -200,13 +201,9 @@ def _construct(self): # get_state method for them here. The load is the same as other functions. def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: res = { - "__class__": obj.__class__.__name__, # ufunc - "__module__": get_module(type(obj)), # numpy + "__class__": obj.__name__, + "__module__": get_module(obj), "__loader__": "FunctionNode", - "content": { - "module_path": get_module(obj), - "function": obj.__name__, - }, } return res @@ -247,7 +244,7 @@ def _construct(self): (np.generic, ndarray_get_state), (np.ndarray, ndarray_get_state), (np.ma.MaskedArray, maskedarray_get_state), - (np.ufunc, ufunc_get_state), + (np.ufunc, function_get_state), (np.dtype, dtype_get_state), (np.random.RandomState, random_state_get_state), (np.random.Generator, random_generator_get_state), From 1ae57e1ea39095ff1a40a338cefe08cc7c82c3d6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 15 Mar 2023 17:04:06 +0100 Subject: [PATCH 2/2] Reviewer comment: Remove obsolete ufunc_get_state --- skops/io/_numpy.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 
8f17eb95..7c19cc96 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -196,18 +196,6 @@ def _construct(self): return gettype(self.module_name, self.class_name)(bit_generator=bit_generator) -# For numpy.ufunc we need to get the type from the type's module, but for other -# functions we get it from objet's module directly. Therefore sett a especial -# get_state method for them here. The load is the same as other functions. -def ufunc_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: - res = { - "__class__": obj.__name__, - "__module__": get_module(obj), - "__loader__": "FunctionNode", - } - return res - - def dtype_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype.