Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/ml/anomaly/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(
threshold_criterion: Optional[ThresholdFn] = None,
**kwargs):
self._model_id = model_id if model_id is not None else getattr(
self, 'spec_type', 'unknown')
self, 'spec_type', lambda: "unknown")()
self._features = features
self._target = target
self._threshold_criterion = threshold_criterion
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(
aggregation_strategy: Optional[AggregationFn] = None,
**kwargs):
if "model_id" not in kwargs or kwargs["model_id"] is None:
kwargs["model_id"] = getattr(self, 'spec_type', 'custom')
kwargs["model_id"] = getattr(self, 'spec_type', lambda: 'custom')()

super().__init__(**kwargs)

Expand Down
175 changes: 117 additions & 58 deletions sdks/python/apache_beam/ml/anomaly/specifiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
import logging
import os
from typing import Any
from typing import ClassVar
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Type
from typing import TypeVar
from typing import Union
from typing import overload
from typing import runtime_checkable

from typing_extensions import Self
Expand All @@ -59,7 +60,8 @@
#: `spec_type` when applying the `specifiable` decorator to an existing class.
_KNOWN_SPECIFIABLE = collections.defaultdict(dict)

SpecT = TypeVar('SpecT', bound='Specifiable')
T = TypeVar('T', bound=type)
BUILTIN_TYPES_IN_SPEC = (int, float, complex, str, bytes, bytearray)


def _class_to_subspace(cls: Type) -> str:
Expand Down Expand Up @@ -104,33 +106,59 @@ class Spec():
config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)


@runtime_checkable
class Specifiable(Protocol):
"""Protocol that a specifiable class needs to implement."""
#: The value of the `type` field in the object's spec for this class.
spec_type: ClassVar[str]
#: The raw keyword arguments passed to `__init__` method during object
#: initialization.
init_kwargs: dict[str, Any]
def _specifiable_from_spec_helper(v, _run_init):
if isinstance(v, Spec):
return Specifiable.from_spec(v, _run_init)

if isinstance(v, List):
return [_specifiable_from_spec_helper(e, _run_init) for e in v]

# TODO: support spec treatment for more types
if not isinstance(v, BUILTIN_TYPES_IN_SPEC):
logging.warning(
"Type %s is not a recognized supported type for the"
"specification. It will be included without conversion.",
str(type(v)))
return v

# a boolean to tell whether the original `__init__` method is called
_initialized: bool
# a boolean used by new_getattr to tell whether it is in the `__init__` method
# call
_in_init: bool

@staticmethod
def _from_spec_helper(v, _run_init):
if isinstance(v, Spec):
return Specifiable.from_spec(v, _run_init)
def _specifiable_to_spec_helper(v):
if isinstance(v, Specifiable):
return v.to_spec()

if isinstance(v, List):
return [Specifiable._from_spec_helper(e, _run_init) for e in v]
if isinstance(v, List):
return [_specifiable_to_spec_helper(e) for e in v]

return v
if inspect.isfunction(v):
if not hasattr(v, "spec_type"):
_register(v, inject_spec_type=False)
return Spec(type=_get_default_spec_type(v), config=None)

if inspect.isclass(v):
if not hasattr(v, "spec_type"):
_register(v, inject_spec_type=False)
return Spec(type=_get_default_spec_type(v), config=None)

# TODO: support spec treatment for more types
if not isinstance(v, BUILTIN_TYPES_IN_SPEC):
logging.warning(
"Type %s is not a recognized supported type for the"
"specification. It will be included without conversion.",
str(type(v)))
return v


@runtime_checkable
class Specifiable(Protocol):
"""Protocol that a specifiable class needs to implement."""
@classmethod
def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
def spec_type(cls) -> str:
pass

@classmethod
def from_spec(cls,
spec: Spec,
_run_init: bool = True) -> Union[Self, type[Self]]:
"""Generate a `Specifiable` subclass object based on a spec.

Args:
Expand All @@ -155,7 +183,7 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
return subclass

kwargs = {
k: Specifiable._from_spec_helper(v, _run_init)
k: _specifiable_from_spec_helper(v, _run_init)
for k,
v in spec.config.items()
}
Expand All @@ -164,26 +192,6 @@ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
kwargs["_run_init"] = True
return subclass(**kwargs)

@staticmethod
def _to_spec_helper(v):
if isinstance(v, Specifiable):
return v.to_spec()

if isinstance(v, List):
return [Specifiable._to_spec_helper(e) for e in v]

if inspect.isfunction(v):
if not hasattr(v, "spec_type"):
_register(v, inject_spec_type=False)
return Spec(type=_get_default_spec_type(v), config=None)

if inspect.isclass(v):
if not hasattr(v, "spec_type"):
_register(v, inject_spec_type=False)
return Spec(type=_get_default_spec_type(v), config=None)

return v

def to_spec(self) -> Spec:
"""Generate a spec from a `Specifiable` subclass object.

Expand All @@ -195,14 +203,22 @@ def to_spec(self) -> Spec:
f"'{type(self).__name__}' not registered as Specifiable. "
f"Decorate ({type(self).__name__}) with @specifiable")

args = {k: self._to_spec_helper(v) for k, v in self.init_kwargs.items()}
args = {
k: _specifiable_to_spec_helper(v)
for k, v in self.init_kwargs.items()
}

return Spec(type=self.__class__.spec_type, config=args)
return Spec(type=self.spec_type(), config=args)

def run_original_init(self) -> None:
"""Invoke the original __init__ method with original keyword arguments"""
pass

@classmethod
def unspecifiable(cls) -> None:
"""Resume the class structure prior to specifiable"""
pass


def _get_default_spec_type(cls):
spec_type = cls.__name__
Expand All @@ -216,7 +232,7 @@ def _get_default_spec_type(cls):


# Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE`
def _register(cls, spec_type=None, inject_spec_type=True) -> None:
def _register(cls: type, spec_type=None, inject_spec_type=True) -> None:
assert spec_type is None or inject_spec_type, \
"need to inject spec_type to class if spec_type is not None"
if spec_type is None:
Expand All @@ -237,7 +253,8 @@ def _register(cls, spec_type=None, inject_spec_type=True) -> None:
_KNOWN_SPECIFIABLE[subspace][spec_type] = cls

if inject_spec_type:
cls.spec_type = spec_type
setattr(cls, cls.__name__ + '__spec_type', spec_type)
# cls.__spec_type = spec_type


# Keep a copy of arguments that are used to call the `__init__` method when the
Expand All @@ -250,13 +267,35 @@ def _get_init_kwargs(inst, init_method, *args, **kwargs):
return params


@overload
def specifiable(
my_cls: None = None,
/,
*,
spec_type: Optional[str] = None,
on_demand_init: bool = True,
just_in_time_init: bool = True) -> Callable[[T], T]:
pass


@overload
def specifiable(
my_cls=None,
my_cls: T,
/,
*,
spec_type=None,
on_demand_init=True,
just_in_time_init=True):
spec_type: Optional[str] = None,
on_demand_init: bool = True,
just_in_time_init: bool = True) -> T:
pass


def specifiable(
my_cls: Optional[T] = None,
/,
*,
spec_type: Optional[str] = None,
on_demand_init: bool = True,
just_in_time_init: bool = True) -> Union[T, Callable[[T], T]]:
"""A decorator that turns a class into a `Specifiable` subclass by
implementing the `Specifiable` protocol.

Expand Down Expand Up @@ -285,8 +324,8 @@ class Bar():
original `__init__` method will be called when the first time an attribute
is accessed.
"""
def _wrapper(cls):
def new_init(self: Specifiable, *args, **kwargs):
def _wrapper(cls: T) -> T:
def new_init(self, *args, **kwargs):
self._initialized = False
self._in_init = False

Expand Down Expand Up @@ -358,20 +397,40 @@ def new_getattr(self, name):
name)
return self.__getattribute__(name)

def spec_type_func(cls):
return getattr(cls, spec_type_attr_name)

def unspecifiable(cls):
delattr(cls, spec_type_attr_name)
cls.__init__ = original_init
if just_in_time_init:
delattr(cls, '__getattr__')
delattr(cls, 'spec_type')
delattr(cls, 'run_original_init')
delattr(cls, 'to_spec')
delattr(cls, 'from_spec')
delattr(cls, 'unspecifiable')

spec_type_attr_name = cls.__name__ + "__spec_type"

# the class is registered
if hasattr(cls, spec_type_attr_name):
return cls

# start of the function body of _wrapper
_register(cls, spec_type)

class_name = cls.__name__
original_init = cls.__init__
cls.__init__ = new_init
original_init = cls.__init__ # type: ignore[misc]
cls.__init__ = new_init # type: ignore[misc]
if just_in_time_init:
cls.__getattr__ = new_getattr

cls.spec_type = classmethod(spec_type_func)
cls.run_original_init = run_original_init
cls.to_spec = Specifiable.to_spec
cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper)
cls.from_spec = Specifiable.from_spec
cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper)
cls.unspecifiable = classmethod(unspecifiable)

return cls
# end of the function body of _wrapper
Expand Down
54 changes: 47 additions & 7 deletions sdks/python/apache_beam/ml/anomaly/specifiable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class B():

# apply the decorator function to an existing class
A = specifiable(A)
self.assertEqual(A.spec_type, "A")
self.assertEqual(A.spec_type(), "A")
self.assertTrue(isinstance(A(), Specifiable))
self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A)
Expand All @@ -63,13 +63,10 @@ class B():
# Raise an error when re-registering spec_type with a different class
self.assertRaises(ValueError, specifiable(spec_type='A'), B)

# apply the decorator function to an existing class with a different
# spec_type
# Applying the decorator function to an existing class with a different
# spec_type will have no effect.
A = specifiable(spec_type="A_DUP")(A)
self.assertEqual(A.spec_type, "A_DUP")
self.assertTrue(isinstance(A(), Specifiable))
self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A)
self.assertEqual(A.spec_type(), "A")

def test_decorator_in_syntactic_sugar_form(self):
# call decorator without parameters
Expand Down Expand Up @@ -585,6 +582,49 @@ def apply(self, x, y):
self.assertEqual(w_2.run_func_in_class(5, 3), 150)


class TestUncommonUsages(unittest.TestCase):
def test_double_specifiable(self):
@specifiable
@specifiable
class ZZ():
def __init__(self, a):
self.a = a

assert issubclass(ZZ, Specifiable)
c = ZZ("b")
c.run_original_init()
self.assertEqual(c.a, "b")

def test_unspecifiable(self):
class YY():
def __init__(self, x):
self.x = x
assert False

YY = specifiable(YY)
assert issubclass(YY, Specifiable)
y = YY(1)
# __init__ is called (with assertion error raised) when attribute is first
# accessed
self.assertRaises(AssertionError, lambda: y.x)

# unspecifiable YY
YY.unspecifiable()
# __init__ is called immediately
self.assertRaises(AssertionError, YY, 1)
self.assertFalse(hasattr(YY, 'run_original_init'))
self.assertFalse(hasattr(YY, 'spec_type'))
self.assertFalse(hasattr(YY, 'to_spec'))
self.assertFalse(hasattr(YY, 'from_spec'))
self.assertFalse(hasattr(YY, 'unspecifiable'))

# make YY specifiable again
YY = specifiable(YY)
assert issubclass(YY, Specifiable)
y = YY(1)
self.assertRaises(AssertionError, lambda: y.x)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
Loading