Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 5 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,11 @@ extism call plugin.wasm sum --input='{"a": 20, "b": 21}' --wasi
# => {"sum":41}
```

You can also specify your input and output types as dataclasses using
`extism.Json`:

```python
from typing import Optional, List
from dataclasses import dataclass

# ...

@dataclass
class User(extism.Json):
admin: bool
name: Optional[str]
email: str
addresses: List[Address]


@extism.plugin_fn
def reflect_user():
input = extism.input(User)
extism.output(input)
```
For automatic deserialization of input types and serialization of output types,
see [XTP Python Bindgen](https://github.com/dylibso/xtp-python-bindgen/) . The
`extism.Json` dataclass serialization has been removed in-favor of the
[Dataclass Wizard](https://dataclass-wizard.readthedocs.io/en/latest/index.html)
based solution there.

### Configs

Expand Down
8 changes: 1 addition & 7 deletions examples/count-vowels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import extism
import json
from dataclasses import dataclass

@dataclass
class Count(extism.Json):
count: int

@extism.plugin_fn
def count_vowels():
Expand All @@ -14,5 +8,5 @@ def count_vowels():
if ch in ['A', 'a', 'E', 'e', 'I', 'i', 'O', 'o', 'U', 'u']:
total += 1
extism.log(extism.LogLevel.Info, "Hello!")
extism.output(Count(total))
extism.output({"count": total})

3 changes: 1 addition & 2 deletions examples/imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import extism
import json

@extism.import_fn("example", "do_something")
def do_something():
Expand All @@ -10,7 +9,7 @@ def reflect(x: str) -> str:
pass

@extism.import_fn("example", "update_dict")
def update_dict(x: extism.JsonObject) -> extism.JsonObject:
def update_dict(x: dict) -> dict:
pass

@extism.plugin_fn
Expand Down
146 changes: 9 additions & 137 deletions lib/src/prelude.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from typing import Union, Optional
import json
from enum import Enum
from abc import ABC, abstractmethod
from datetime import datetime
from base64 import b64encode, b64decode
from dataclasses import is_dataclass

import extism_ffi as ffi

Expand All @@ -28,129 +24,13 @@ def log(level, msg):

IMPORT_INDEX = 0


class Codec(ABC):
"""
Codec is used to serialize and deserialize values in Extism memory
"""

@abstractmethod
def encode(self) -> bytes:
"""Encode the inner value to bytes"""
raise Exception("encode not implemented")

@classmethod
@abstractmethod
def decode(s: bytes):
"""Decode a value from bytes"""
raise Exception("encode not implemented")

def __post_init__(self):
self._fix_fields()

def _fix_fields(self):
if not hasattr(self, '__annotations__'):
return
for k in self.__annotations__:
ty = self.__annotations__[k]
v = getattr(self, k)
setattr(self, k, self._fix_field(ty, v))
return self

def _fix_field(self, ty: type, v):
def check_subclass(a, b):
try:
return issubclass(a, b)
except Exception as _:
return False
if isinstance(v, dict) and check_subclass(ty, Codec):
return ty(**v)._fix_fields()
elif isinstance(v, str) and check_subclass(ty, Enum):
return ty(v)
elif isinstance(v, list) and hasattr(ty, '__origin__') and ty.__origin__ is list:
ty = ty.__args__[0]
return [self._fix_field(ty, x) for x in v]
elif hasattr(ty, '__origin__') and ty.__origin__ is Union:
if len(ty.__args__) == 2 and ty.__args__[1] == type(None) and v is not None:
ty = ty.__args__[0]
return self._fix_field(ty, v)
return v


class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Json):
return json.loads(o.encode().decode())
elif isinstance(o, bytes):
return b64encode(o).decode()
elif isinstance(o, datetime):
return o.isoformat()
elif isinstance(o, Enum):
return str(o.value)
elif isinstance(o, list):
return [self.default(x) for x in o]
elif isinstance(o, dict):
return {k: self.default(x) for k, x in o.items()}
return super().default(o)


class JSONDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

def object_hook(self, dct):
if not isinstance(dct, dict):
return dct
for k, v in dct.items():
if isinstance(v, str):
try:
dct[k] = datetime.fromisoformat(v)
continue
except Exception as _:
pass

try:
dct[k] = b64decode(v.encode())
continue
except Exception as _:
pass
elif isinstance(v, dict):
dct[k] = self.object_hook(v)
elif isinstance(v, list):
dct[k] = [self.object_hook(x) for x in v]
return dct


class Json(Codec):
def encode(self) -> bytes:
v = self
if not isinstance(self, (dict, datetime, bytes)) and hasattr(self, "__dict__"):
if len(self.__dict__) > 0:
v = self.__dict__
return json.dumps(v, cls=JSONEncoder).encode()

@classmethod
def decode(cls, s: bytes):
x = json.loads(s.decode(), cls=JSONDecoder)
if is_dataclass(cls):
return cls(**x)
else:
return cls(**x)._fix_fields()


class JsonObject(Json, dict):
pass


def _store(x) -> int:
if isinstance(x, str):
return ffi.memory.alloc(x.encode()).offset
elif isinstance(x, bytes):
return ffi.memory.alloc(x).offset
elif isinstance(x, dict) or isinstance(x, list):
return ffi.memory.alloc(json.dumps(x, cls=JSONEncoder).encode()).offset
elif isinstance(x, Codec):
return ffi.memory.alloc(x.encode()).offset
return ffi.memory.alloc(json.dumps(x).encode()).offset
elif isinstance(x, Enum):
return ffi.memory.alloc(str(x.value).encode()).offset
elif isinstance(x, ffi.memory.MemoryHandle):
Expand All @@ -176,9 +56,7 @@ def _load(t, x):
elif t is bytes:
return ffi.memory.bytes(mem)
elif t is dict or t is list:
return json.loads(ffi.memory.string(mem), cls=JSONDecoder)
elif issubclass(t, Codec):
return t.decode(ffi.memory.bytes(mem))
return json.loads(ffi.memory.string(mem))
elif issubclass(t, Enum):
return t(ffi.memory.string(mem))
elif t is ffi.memory.MemoryHandle:
Expand Down Expand Up @@ -235,10 +113,8 @@ def inner(*args):
def input_json(t: Optional[type] = None):
"""Get input as JSON"""
if t is int or t is float:
return t(json.loads(input_str(), cls=JSONDecoder))
if issubclass(t, Json):
return t(**json.loads(input_str(), cls=JSONDecoder))
return json.loads(input_str(), cls=JSONDecoder)
return t(json.loads(input_str()))
return json.loads(input_str())


def output_json(x):
Expand All @@ -249,7 +125,7 @@ def output_json(x):

if hasattr(x, "__dict__"):
x = x.__dict__
output_str(json.dumps(x, cls=JSONEncoder))
output_str(json.dumps(x))


def input(t: type = None):
Expand All @@ -259,10 +135,8 @@ def input(t: type = None):
return input_str()
elif t is bytes:
return input_bytes()
elif issubclass(t, Codec):
return t.decode(input_bytes())
elif t is dict or t is list:
return json.loads(input_str(), cls=JSONDecoder)
return json.loads(input_str())
elif issubclass(t, Enum):
return t(input_str())
else:
Expand All @@ -276,8 +150,6 @@ def output(x=None):
output_str(x)
elif isinstance(x, bytes):
output_bytes(x)
elif isinstance(x, Codec):
output_bytes(x.encode())
elif isinstance(x, dict) or isinstance(x, list):
output_json(x)
elif isinstance(x, Enum):
Expand Down Expand Up @@ -306,7 +178,7 @@ def get_json(key: str):
x = Var.get_str(key)
if x is None:
return x
return json.loads(x, cls=JSONDecoder)
return json.loads(x)

@staticmethod
def set(key: str, value: Union[bytes, str]):
Expand All @@ -328,7 +200,7 @@ def get_json(key: str):
x = ffi.config_get(key)
if x is None:
return None
return json.loads(x, cls=JSONDecoder)
return json.loads(x)


class HttpResponse:
Expand All @@ -352,7 +224,7 @@ def data_str(self):

def data_json(self):
"""Get response body JSON"""
return json.loads(self.data_str(), cls=JSONDecoder)
return json.loads(self.data_str())

def headers(self):
"""Get HTTP response headers"""
Expand Down
Loading