From ce66557067d079402ced65633d1cb0ebeda7ed9f Mon Sep 17 00:00:00 2001 From: Plamen Neykov Date: Mon, 17 Mar 2025 21:53:41 +0000 Subject: [PATCH] allow ..WithMeta data types to be initialized with objects of the base type --- src/rune/runtime/metadata.py | 125 ++++++++++++++++------------- test/test_basic_types_with_meta.py | 23 ++++++ 2 files changed, 93 insertions(+), 55 deletions(-) diff --git a/src/rune/runtime/metadata.py b/src/rune/runtime/metadata.py index 38eab15..4151548 100644 --- a/src/rune/runtime/metadata.py +++ b/src/rune/runtime/metadata.py @@ -6,7 +6,7 @@ from decimal import Decimal from typing import Any, Never, get_args import datetime -from typing_extensions import Self +from typing_extensions import Self, Tuple from pydantic import (PlainSerializer, PlainValidator, WrapValidator, WrapSerializer) # from rune.runtime.object_registry import get_object @@ -404,6 +404,17 @@ def validator(cls, allowed_meta: tuple[str] | tuple[Never, ...] = tuple()): class BasicTypeMetaDataMixin(BaseMetaDataMixin): '''holds the metadata associated with an instance''' + _INPUT_TYPES: Any | Tuple[Any, ...] = str # to be overridden by subclasses + _OUTPUT_TYPE: Any = str # to be overridden by subclasses + _JSON_OUTPUT = str | dict + + @classmethod + def _check_type(cls, value): + if not isinstance(value, cls._INPUT_TYPES): + raise ValueError(f'{cls.__name__} can be instantiated only with ' + f'one of the following type(s): {cls._INPUT_TYPES},' + f' however the value is of type {type(value)}') + @classmethod def serialise(cls, obj, base_type) -> dict: '''used as serialisation method with pydantic''' @@ -431,7 +442,7 @@ def deserialize(cls, obj, handler, base_types, allowed_meta: set[str]): @lru_cache def serializer(cls): '''should return the validator for the specific class''' - ser_fn = partial(cls.serialise, base_type=str) + ser_fn = partial(cls.serialise, base_type=cls._OUTPUT_TYPE) return PlainSerializer(ser_fn, return_type=dict) @classmethod @@ -440,15 +451,20 @@ def validator(cls, allowed_meta: tuple[str]): '''default validator for the specific class''' allowed = set(allowed_meta) return WrapValidator(partial(cls.deserialize, - base_types=str, + base_types=cls._INPUT_TYPES, allowed_meta=allowed), - json_schema_input_type=str | dict) + json_schema_input_type=cls._JSON_OUTPUT) class DateWithMeta(datetime.date, BasicTypeMetaDataMixin): '''date with metadata''' + _INPUT_TYPES = (datetime.date, str) + def __new__(cls, value, **kwds): # pylint: disable=signature-differs - ymd = datetime.date.fromisoformat(value).timetuple()[:3] + cls._check_type(value) + if isinstance(value, str): + value = datetime.date.fromisoformat(value) + ymd = value.timetuple()[:3] obj = datetime.date.__new__(cls, *ymd) obj.set_meta(check_allowed=False, **kwds) return obj @@ -456,33 +472,41 @@ def __new__(cls, value, **kwds): # pylint: disable=signature-differs class TimeWithMeta(datetime.time, BasicTypeMetaDataMixin): '''annotated time''' + _INPUT_TYPES = (datetime.time, str) + def __new__(cls, value, **kwds): # pylint: disable=signature-differs - aux = datetime.time.fromisoformat(value) + cls._check_type(value) + if isinstance(value, str): + value = datetime.time.fromisoformat(value) obj = datetime.time.__new__(cls, - aux.hour, - aux.minute, - aux.second, - aux.microsecond, - aux.tzinfo, - fold=aux.fold) + value.hour, + value.minute, + value.second, + value.microsecond, + value.tzinfo, + fold=value.fold) obj.set_meta(check_allowed=False, **kwds) return obj class DateTimeWithMeta(datetime.datetime, BasicTypeMetaDataMixin): '''annotated datetime''' + _INPUT_TYPES = (datetime.datetime, str) + def __new__(cls, value, **kwds): # pylint: disable=signature-differs - aux = datetime.datetime.fromisoformat(value) + cls._check_type(value) + if isinstance(value, str): + value = datetime.datetime.fromisoformat(value) obj = datetime.datetime.__new__(cls, - aux.year, - aux.month, - aux.day, - aux.hour, - aux.minute, - aux.second, - aux.microsecond, - aux.tzinfo, - fold=aux.fold) + value.year, + value.month, + value.day, + value.hour, + value.minute, + value.second, + value.microsecond, + value.tzinfo, + fold=value.fold) obj.set_meta(check_allowed=False, **kwds) return obj @@ -500,31 +524,22 @@ def __new__(cls, value, **kwds): class IntWithMeta(int, BasicTypeMetaDataMixin): '''annotated integer''' + _INPUT_TYPES = int + _OUTPUT_TYPE = int + _JSON_OUTPUT = int | dict + def __new__(cls, value, **kwds): obj = int.__new__(cls, value) obj.set_meta(check_allowed=False, **kwds) return obj - @classmethod - @lru_cache - def serializer(cls): - '''should return the validator for the specific class''' - ser_fn = partial(cls.serialise, base_type=int) - return PlainSerializer(ser_fn, return_type=dict) - - @classmethod - @lru_cache - def validator(cls, allowed_meta: tuple[str]): - '''default validator for the specific class''' - allowed = set(allowed_meta) - return WrapValidator(partial(cls.deserialize, - base_types=int, - allowed_meta=allowed), - json_schema_input_type=int | dict) - class NumberWithMeta(Decimal, BasicTypeMetaDataMixin): '''annotated number''' + _INPUT_TYPES = (Decimal, float, int, str) + _OUTPUT_TYPE = Decimal + _JSON_OUTPUT = float | int | str | dict + def __new__(cls, value, **kwds): # NOTE: it could be necessary to convert the value to str if it is a # float @@ -532,22 +547,22 @@ def __new__(cls, value, **kwds): obj.set_meta(check_allowed=False, **kwds) return obj - @classmethod - @lru_cache - def serializer(cls): - '''should return the validator for the specific class''' - ser_fn = partial(cls.serialise, base_type=Decimal) - return PlainSerializer(ser_fn, return_type=dict) - - @classmethod - @lru_cache - def validator(cls, allowed_meta: tuple[str]): - '''default validator for the specific class''' - allowed = set(allowed_meta) - return WrapValidator(partial(cls.deserialize, - base_types=(Decimal, float, int, str), - allowed_meta=allowed), - json_schema_input_type=float | int | str | dict) + # @classmethod + # @lru_cache + # def serializer(cls): + # '''should return the validator for the specific class''' + # ser_fn = partial(cls.serialise, base_type=Decimal) + # return PlainSerializer(ser_fn, return_type=dict) + + # @classmethod + # @lru_cache + # def validator(cls, allowed_meta: tuple[str]): + # '''default validator for the specific class''' + # allowed = set(allowed_meta) + # return WrapValidator(partial(cls.deserialize, + # base_types=(Decimal, float, int, str), + # allowed_meta=allowed), + # json_schema_input_type=float | int | str | dict) class _EnumWrapperDefaultVal(Enum): diff --git a/test/test_basic_types_with_meta.py b/test/test_basic_types_with_meta.py index eea893b..166e040 100644 --- a/test/test_basic_types_with_meta.py +++ b/test/test_basic_types_with_meta.py @@ -216,6 +216,29 @@ def test_dump_annotated_date_simple(): assert json_str == '{"date":{"@data":"2024-10-10"}}' +def test_dump_annotated_date_date(): + '''test the annotated string''' + model = AnnotatedDateModel(date=date(2024, 10, 10)) + json_str = model.model_dump_json(exclude_unset=True) + assert json_str == '{"date":{"@data":"2024-10-10"}}' + + model = AnnotatedDateModel(date=DateWithMeta(date(2024, 10, 10))) + json_str = model.model_dump_json(exclude_unset=True) + assert json_str == '{"date":{"@data":"2024-10-10"}}' + + +def test_annotated_date_fail(): + '''test instantiation failure with an incorrect type''' + with pytest.raises(AttributeError): + AnnotatedDateModel(date=10) + + +def test_date_with_meta_fail(): + '''test instantiation failure with an incorrect type''' + with pytest.raises(ValueError): + DateWithMeta(10) + + def test_load_annotated_date_scheme(): '''test the loading of annotated with a scheme strings''' scheme_json = '{"date":{"@data":"2024-10-10","@scheme":"http://fpml.org"}}'