diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index ff1c5d5..b755744 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -21,10 +21,9 @@ EVENT_STATE_CHANGED, SERVICE_RELOAD, ) -from homeassistant.core import Config, HomeAssistant, ServiceCall +from homeassistant.core import Config, Event as HAEvent, HomeAssistant, ServiceCall from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv -from homeassistant.core import Event as HAEvent from homeassistant.helpers.restore_state import DATA_RESTORE_STATE from homeassistant.loader import bind_hass @@ -51,6 +50,7 @@ from .requirements import install_requirements from .state import State, StateVal from .trigger import TrigTime +from .webhook import Webhook _LOGGER = logging.getLogger(LOGGER_PATH) @@ -241,6 +241,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b Mqtt.init(hass) TrigTime.init(hass) State.init(hass) + Webhook.init(hass) State.register_functions() GlobalContextMgr.init() diff --git a/custom_components/pyscript/entity.py b/custom_components/pyscript/entity.py index 7b00e06..8150392 100644 --- a/custom_components/pyscript/entity.py +++ b/custom_components/pyscript/entity.py @@ -1,19 +1,19 @@ -"""Entity Classes""" +"""Entity Classes.""" +from homeassistant.const import STATE_UNKNOWN from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import StateType -from homeassistant.const import STATE_UNKNOWN class PyscriptEntity(RestoreEntity): - """Generic Pyscript Entity""" + """Generic Pyscript Entity.""" _attr_extra_state_attributes: dict _attr_state: StateType = STATE_UNKNOWN def set_state(self, state): - """Set the state""" + """Set the state.""" self._attr_state = state def set_attributes(self, attributes): - """Set Attributes""" + """Set Attributes.""" self._attr_extra_state_attributes = attributes diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index c12f4f4..bca4ed4 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -50,6 +50,7 @@ "state_trigger", "event_trigger", "mqtt_trigger", + "webhook_trigger", "state_active", "time_active", "task_unique", @@ -74,6 +75,14 @@ "trigger_time", "var_name", "value", + "webhook_id", +} + +WEBHOOK_METHODS = { + "GET", + "HEAD", + "POST", + "PUT", } @@ -363,6 +372,7 @@ async def trigger_init(self, trig_ctx, func_name): "mqtt_trigger", "state_trigger", "time_trigger", + "webhook_trigger", } arg_check = { "event_trigger": {"arg_cnt": {1, 2, 3}, "rep_ok": True}, @@ -373,6 +383,7 @@ async def trigger_init(self, trig_ctx, func_name): "task_unique": {"arg_cnt": {1, 2}}, "time_active": {"arg_cnt": {"*"}}, "time_trigger": {"arg_cnt": {0, "*"}, "rep_ok": True}, + "webhook_trigger": {"arg_cnt": {1, 2}, "rep_ok": True}, } kwarg_check = { "event_trigger": {"kwargs": {dict}}, @@ -388,6 +399,11 @@ async def trigger_init(self, trig_ctx, func_name): "state_hold_false": {int, float}, "watch": {set, list}, }, + "webhook_trigger": { + "kwargs": {dict}, + "local_only": {bool}, + "methods": {list, set}, + }, } for dec in self.decorators: @@ -517,6 +533,10 @@ async def do_service_call(func, ast_ctx, data): self.trigger_service.add(srv_name) continue + if dec_name == "webhook_trigger" and "methods" in dec_kwargs: + if len(bad := set(dec_kwargs["methods"]).difference(WEBHOOK_METHODS)) > 0: + raise TypeError(f"{exc_mesg}: {bad} aren't valid {dec_name} methods") + if dec_name not in trig_decs: trig_decs[dec_name] = [] if len(trig_decs[dec_name]) > 0 and "rep_ok" not in arg_info: diff --git a/custom_components/pyscript/trigger.py b/custom_components/pyscript/trigger.py index d842ab7..ece082e 100644 --- a/custom_components/pyscript/trigger.py +++ b/custom_components/pyscript/trigger.py @@ -21,6 +21,7 @@ from .function import Function from .mqtt import Mqtt from .state import STATE_VIRTUAL_ATTRS, State +from .webhook import Webhook _LOGGER = logging.getLogger(LOGGER_PATH + ".trigger") @@ -222,13 +223,22 @@ async def wait_until( time_trigger=None, event_trigger=None, mqtt_trigger=None, + webhook_trigger=None, + webhook_local_only=True, + webhook_methods=None, timeout=None, state_hold=None, state_hold_false=None, __test_handshake__=None, ): """Wait for zero or more triggers, until an optional timeout.""" - if state_trigger is None and time_trigger is None and event_trigger is None and mqtt_trigger is None: + if ( + state_trigger is None + and time_trigger is None + and event_trigger is None + and mqtt_trigger is None + and webhook_trigger is None + ): if timeout is not None: await asyncio.sleep(timeout) return {"trigger_type": "timeout"} @@ -238,6 +248,7 @@ async def wait_until( state_trig_eval = None event_trig_expr = None mqtt_trig_expr = None + webhook_trig_expr = None exc = None notify_q = asyncio.Queue(0) @@ -349,6 +360,26 @@ async def wait_until( State.notify_del(state_trig_ident, notify_q) raise exc await Mqtt.notify_add(mqtt_trigger[0], notify_q) + if webhook_trigger is not None: + if isinstance(webhook_trigger, str): + webhook_trigger = [webhook_trigger] + if len(webhook_trigger) > 1: + webhook_trig_expr = AstEval( + f"{ast_ctx.name} webhook_trigger", + ast_ctx.get_global_ctx(), + logger_name=ast_ctx.get_logger_name(), + ) + Function.install_ast_funcs(webhook_trig_expr) + webhook_trig_expr.parse(webhook_trigger[1], mode="eval") + exc = webhook_trig_expr.get_exception_obj() + if exc is not None: + if len(state_trig_ident) > 0: + State.notify_del(state_trig_ident, notify_q) + raise exc + if webhook_methods is None: + webhook_methods = {"POST", "PUT"} + Webhook.notify_add(webhook_trigger[0], webhook_local_only, webhook_methods, notify_q) + time0 = time.monotonic() if __test_handshake__: @@ -394,7 +425,12 @@ async def wait_until( state_trig_timeout = True time_next = now + dt.timedelta(seconds=this_timeout) if this_timeout is None: - if state_trigger is None and event_trigger is None and mqtt_trigger is None: + if ( + state_trigger is None + and event_trigger is None + and mqtt_trigger is None + and webhook_trigger is None + ): _LOGGER.debug( "trigger %s wait_until no next time - returning with none", ast_ctx.name, @@ -527,6 +563,17 @@ async def wait_until( if mqtt_trig_ok: ret = notify_info break + elif notify_type == "webhook": + if webhook_trig_expr is None: + ret = notify_info + break + webhook_trig_ok = await webhook_trig_expr.eval(notify_info) + exc = webhook_trig_expr.get_exception_obj() + if exc is not None: + break + if webhook_trig_ok: + ret = notify_info + break else: _LOGGER.error( "trigger %s wait_until got unexpected queue message %s", @@ -540,6 +587,8 @@ async def wait_until( Event.notify_del(event_trigger[0], notify_q) if mqtt_trigger is not None: Mqtt.notify_del(mqtt_trigger[0], notify_q) + if webhook_trigger is not None: + Webhook.notify_del(webhook_trigger[0], notify_q) if exc: raise exc return ret @@ -826,6 +875,10 @@ def __init__( self.event_trigger_kwargs = trig_cfg.get("event_trigger", {}).get("kwargs", {}) self.mqtt_trigger = trig_cfg.get("mqtt_trigger", {}).get("args", None) self.mqtt_trigger_kwargs = trig_cfg.get("mqtt_trigger", {}).get("kwargs", {}) + self.webhook_trigger = trig_cfg.get("webhook_trigger", {}).get("args", None) + self.webhook_trigger_kwargs = trig_cfg.get("webhook_trigger", {}).get("kwargs", {}) + self.webhook_local_only = self.webhook_trigger_kwargs.get("local_only", True) + self.webhook_methods = self.webhook_trigger_kwargs.get("methods", {"POST", "PUT"}) self.state_active = trig_cfg.get("state_active", {}).get("args", None) self.time_active = trig_cfg.get("time_active", {}).get("args", None) self.time_active_hold_off = trig_cfg.get("time_active", {}).get("kwargs", {}).get("hold_off", None) @@ -842,6 +895,7 @@ def __init__( self.state_trig_ident_any = set() self.event_trig_expr = None self.mqtt_trig_expr = None + self.webhook_trig_expr = None self.have_trigger = False self.setup_ok = False self.run_on_startup = False @@ -933,6 +987,21 @@ def __init__( return self.have_trigger = True + if self.webhook_trigger is not None: + if len(self.webhook_trigger) == 2: + self.webhook_trig_expr = AstEval( + f"{self.name} @webhook_trigger()", + self.global_ctx, + logger_name=self.name, + ) + Function.install_ast_funcs(self.webhook_trig_expr) + self.webhook_trig_expr.parse(self.webhook_trigger[1], mode="eval") + exc = self.webhook_trig_expr.get_exception_long() + if exc is not None: + self.webhook_trig_expr.get_logger().error(exc) + return + self.have_trigger = True + self.setup_ok = True def stop(self): @@ -945,6 +1014,8 @@ def stop(self): Event.notify_del(self.event_trigger[0], self.notify_q) if self.mqtt_trigger is not None: Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q) + if self.webhook_trigger is not None: + Webhook.notify_del(self.webhook_trigger[0], self.notify_q) if self.task: Function.reaper_cancel(self.task) self.task = None @@ -995,6 +1066,11 @@ async def trigger_watch(self): if self.mqtt_trigger is not None: _LOGGER.debug("trigger %s adding mqtt_trigger %s", self.name, self.mqtt_trigger[0]) await Mqtt.notify_add(self.mqtt_trigger[0], self.notify_q) + if self.webhook_trigger is not None: + _LOGGER.debug("trigger %s adding webhook_trigger %s", self.name, self.webhook_trigger[0]) + Webhook.notify_add( + self.webhook_trigger[0], self.webhook_local_only, self.webhook_methods, self.notify_q + ) last_trig_time = None last_state_trig_time = None @@ -1182,6 +1258,11 @@ async def trigger_watch(self): user_kwargs = self.mqtt_trigger_kwargs.get("kwargs", {}) if self.mqtt_trig_expr: trig_ok = await self.mqtt_trig_expr.eval(notify_info) + elif notify_type == "webhook": + func_args = notify_info + user_kwargs = self.webhook_trigger_kwargs.get("kwargs", {}) + if self.webhook_trig_expr: + trig_ok = await self.webhook_trig_expr.eval(notify_info) else: user_kwargs = self.time_trigger_kwargs.get("kwargs", {}) @@ -1237,6 +1318,8 @@ async def trigger_watch(self): Event.notify_del(self.event_trigger[0], self.notify_q) if self.mqtt_trigger is not None: Mqtt.notify_del(self.mqtt_trigger[0], self.notify_q) + if self.webhook_trigger is not None: + Webhook.notify_del(self.webhook_trigger[0], self.notify_q) return def call_action(self, notify_type, func_args, run_task=True): diff --git a/custom_components/pyscript/webhook.py b/custom_components/pyscript/webhook.py new file mode 100644 index 0000000..3c9b06a --- /dev/null +++ b/custom_components/pyscript/webhook.py @@ -0,0 +1,95 @@ +"""Handles webhooks and notification.""" + +import logging + +from aiohttp import hdrs + +from homeassistant.components import webhook + +from .const import LOGGER_PATH + +_LOGGER = logging.getLogger(LOGGER_PATH + ".webhook") + + +class Webhook: + """Define webhook functions.""" + + # + # Global hass instance + # + hass = None + + # + # notify message queues by webhook type + # + notify = {} + notify_remove = {} + + def __init__(self): + """Warn on Webhook instantiation.""" + _LOGGER.error("Webhook class is not meant to be instantiated") + + @classmethod + def init(cls, hass): + """Initialize Webhook.""" + + cls.hass = hass + + @classmethod + async def webhook_handler(cls, hass, webhook_id, request): + """Listen callback for given webhook which updates any notifications.""" + + func_args = { + "trigger_type": "webhook", + "webhook_id": webhook_id, + } + + if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): + func_args["payload"] = await request.json() + else: + # Could potentially return multiples of a key - only take the first + payload_multidict = await request.post() + func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + + await cls.update(webhook_id, func_args) + + @classmethod + def notify_add(cls, webhook_id, local_only, methods, queue): + """Register to notify for webhooks of given type to be sent to queue.""" + if webhook_id not in cls.notify: + cls.notify[webhook_id] = set() + _LOGGER.debug("webhook.notify_add(%s) -> adding webhook listener", webhook_id) + webhook.async_register( + cls.hass, + "pyscript", # DOMAIN + "pyscript", # NAME + webhook_id, + cls.webhook_handler, + local_only=local_only, + allowed_methods=methods, + ) + cls.notify_remove[webhook_id] = lambda: webhook.async_unregister(cls.hass, webhook_id) + + cls.notify[webhook_id].add(queue) + + @classmethod + def notify_del(cls, webhook_id, queue): + """Unregister to notify for webhooks of given type for given queue.""" + + if webhook_id not in cls.notify or queue not in cls.notify[webhook_id]: + return + cls.notify[webhook_id].discard(queue) + if len(cls.notify[webhook_id]) == 0: + cls.notify_remove[webhook_id]() + _LOGGER.debug("webhook.notify_del(%s) -> removing webhook listener", webhook_id) + del cls.notify[webhook_id] + del cls.notify_remove[webhook_id] + + @classmethod + async def update(cls, webhook_id, func_args): + """Deliver all notifications for an webhook of the given type.""" + + _LOGGER.debug("webhook.update(%s, %s)", webhook_id, func_args) + if webhook_id in cls.notify: + for queue in cls.notify[webhook_id]: + await queue.put(["webhook", func_args.copy()]) diff --git a/docs/reference.rst b/docs/reference.rst index 9742682..f1f82ef 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -787,7 +787,7 @@ Wildcards in topics are supported. The ``topic`` variables will be set to the fu the message arrived on. NOTE: The `MQTT Integration in Home Assistant `__ -must be set up to use ``@mqtt_trigger``. +must be set up to use ``@mqtt_trigger``. @state_active ^^^^^^^^^^^^^ @@ -857,6 +857,39 @@ true if the current time doesn't match any of the "not" (negative) specification allows multiple arguments with and without ``not``. The condition will be met if the current time matches any of the positive arguments, and none of the negative arguments. +@webhook_trigger +^^^^^^^^^^^^^^^^ + +.. code:: python + + @webhook_trigger(webhook_id, str_expr=None, local_only=True, methods={"POST", "PUT"}, kwargs=None) + +``@webhook_trigger`` listens for calls to a `Home Assistant webhook `__ at ``your_hass_url/api/webhook/webhook_id`` and triggers whenever a request is made at that endpoint. Multiple ``@webhook_trigger`` decorators can be applied to a single function if you want to trigger off different webhook ids. + +Setting ``local_only`` option to ``False`` will allow request made from anywhere on the internet (as opposed to just on local network). +The methods option needs to be an list or set with elements ``GET``, ``HEAD``, ``POST``, or ``PUT``. + +An optional ``str_expr`` can be used to match against payload message data, and the trigger will only occur if that expression evaluates to ``True`` or non-zero. This expression has available these three +variables: + +- ``trigger_type`` is set to "webhook" +- ``webhook_id`` is set to the webhook_id that was called. +- ``payload`` is the data/json that was sent in the request returned as a dictionary. + +When the ``@webhook_trigger`` occurs, those same variables are passed as keyword arguments to the function in case it needs them. Additional keyword parameters can be specified by setting the optional ``kwargs`` argument to a ``dict`` with the keywords and values. + +An simple example looks like + +.. code:: python + + @webhook_trigger("myid", kwargs={"extra": 10}) + def webhook_test(payload, extra): + log.info(f"It ran! {payload}, {extra}") + +which if called using the curl command ``curl -X POST -d 'key1=xyz&key2=abc' hass_url/api/webhook/myid`` outputs ``It ran! {'key1': 'xyz', 'key2': 'abc'}, 10`` + +NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error. + Other Function Decorators ------------------------- @@ -1313,6 +1346,11 @@ It takes the following keyword arguments (all are optional): - ``mqtt_trigger=None`` can be set to a string or list of two strings, just like ``@mqtt_trigger``. The first string is the MQTT topic, and the second string (when the setting is a two-element list) is an expression based on the message variables. +- ``webhook_trigger=None`` can be set to a string or list of two strings, just like ``@webhook_trigger``. The first string is the webhook id, and the second string (when the setting is a two-element list) is an expression based on the message variables. +- ``webhook_local_only=True`` is used with ``webhook_trigger`` to specify whether to only allow + local webhooks. +- ``webhook_methods={"POST", "PUT"}`` is used with ``webhook_trigger`` to specify allowed webhook + request methods. - ``timeout=None`` an overall timeout in seconds, which can be floating point. - ``state_check_now=True`` if set, ``task.wait_until()`` checks any ``state_trigger`` immediately to see if it is already ``True``, and will return immediately if so. If @@ -1566,7 +1604,7 @@ Pyscript supports importing two types of packages or modules: will cause all of the module files to be unloaded, and any scripts or apps that import that module will be reloaded. Imports of pyscript modules and packages are not affected by the ``allow_all_imports`` setting - if a file is in the ``/pyscript/modules`` folder then it can be imported. - + Package-style layout is also supported where a PACKAGE is defined in ``/pyscript/modules/PACKAGE/__init__.py``, and that file can, in turn, do relative imports of other files in that same directory. This form is most convenient for diff --git a/tests/test_decorator_errors.py b/tests/test_decorator_errors.py index 273a9eb..408a1f4 100644 --- a/tests/test_decorator_errors.py +++ b/tests/test_decorator_errors.py @@ -1,4 +1,5 @@ """Test pyscript decorator syntax error and eval-time exception reporting.""" + from ast import literal_eval import asyncio from datetime import datetime as dt @@ -217,7 +218,7 @@ def func4(): """, ) assert ( - "func4 defined in file.hello: needs at least one trigger decorator (ie: event_trigger, mqtt_trigger, state_trigger, time_trigger)" + "func4 defined in file.hello: needs at least one trigger decorator (ie: event_trigger, mqtt_trigger, state_trigger, time_trigger, webhook_trigger)" in caplog.text ) @@ -460,3 +461,23 @@ def func7(): "TypeError: function 'func7' defined in file.hello: decorator @state_trigger keyword 'watch' should be type list or set" in caplog.text ) + + +@pytest.mark.asyncio +async def test_webhooks_method(hass, caplog): + """Test invalid keyword arguments type generates an error.""" + + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_trigger("hook", methods=["bad"]) +def func8(): + pass +""", + ) + assert ( + "TypeError: function 'func8' defined in file.hello: {'bad'} aren't valid webhook_trigger methods" + in caplog.text + ) diff --git a/tests/test_jupyter.py b/tests/test_jupyter.py index f9fa0e7..c8fbb9a 100644 --- a/tests/test_jupyter.py +++ b/tests/test_jupyter.py @@ -487,6 +487,7 @@ async def test_jupyter_kernel_redefine_func(hass, caplog, socket_enabled): @time_trigger("once(2019/9/7 12:00)") @state_trigger("pyscript.var1 == '1'") @event_trigger("test_event") +@webhook_trigger("test_hook1") def func(): pass 123 @@ -504,6 +505,7 @@ def func(): @time_trigger("once(2019/9/7 13:00)") @state_trigger("pyscript.var1 == '1'") @event_trigger("test_event2") +@webhook_trigger("test_hook1") def func(): pass 321