diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index a8ed542..0000000 --- a/.coveragerc +++ /dev/null @@ -1,12 +0,0 @@ -[run] -# don't collect the plugins subdirectories -omit = - plugins/*/env/*.py - plugins/conftest.py - plugins/*/*/tests/* - tests/plugins/* - sideboard/tests/* - tests/* - -include = - sideboard/* diff --git a/.gitignore b/.gitignore index 13ad716..29f3e03 100644 --- a/.gitignore +++ b/.gitignore @@ -10,14 +10,18 @@ dist/ *.pyc *.pyo development.ini +test.ini sideboard/docs plugins/* !plugins/conftest.py data/sessions/* +data/profiler/* *-bootstrap.py data/*.db .tox node_modules/ +.cache distribute-*.tar.gz distribute_setup.py -.eggs/ \ No newline at end of file +.eggs/ +.coverage diff --git a/.travis.yml b/.travis.yml index 9b35c7d..040cd9c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,27 @@ language: python -python: 2.7 -env: - - TOX_ENV=py27 - - TOX_ENV=py33 - - TOX_ENV=py34 +matrix: + include: + - python: "2.7" + env: TOX_ENV=pep8 + - python: "2.7" + env: TOX_ENV=py27 + - python: "3.3" + env: TOX_ENV=py33 + - python: "3.4" + env: TOX_ENV=py34 + - python: "3.5" + env: TOX_ENV=py35 +before_install: + - sudo apt-get -qq update + - sudo apt-get install -y build-essential libcap-dev install: - pip install tox - - pip install python-coveralls + - if [[ $TOX_ENV == py27 ]] || [[ $TOX_ENV == py35 ]]; then + pip install coveralls; + fi script: - tox -e $TOX_ENV after_success: - coveralls + - if [[ $TOX_ENV == py27 ]] || [[ $TOX_ENV == py35 ]]; then + coveralls; + fi diff --git a/Dockerfile b/Dockerfile index 512f9d0..57404ff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,38 +1,59 @@ -FROM python:3.4.3 +FROM python:3.4.5 MAINTAINER RAMS Project "code@magfest.org" LABEL version.sideboard ="1.0" - WORKDIR /app -# verify gpg and sha256: http://nodejs.org/dist/v0.10.30/SHASUMS256.txt.asc -# gpg: aka "Timothy J Fontaine (Work) " -# gpg: aka "Julien Gilli " +# This is actually the least bad way to compose two Dockerfile tech stacks 
right now. +# The following is copied and pasted from the Node Dockerfile at +# https://github.com/nodejs/docker-node/blob/28425ed95cebaea2ff589c1516d79c60181983b2/7.4/Dockerfile +# Update this comment and change the entire copypasta section to upgrade Node version + +######################################### +# START NODEJS DOCKERFILE COPYPASTA # +# https://github.com/nodejs/docker-node # +######################################### +RUN groupadd --gid 1000 node \ + && useradd --uid 1000 --gid node --shell /bin/bash --create-home node + +# gpg keys listed at https://github.com/nodejs/node#release-team RUN set -ex \ - && for key in \ - 7937DFD2AB06298B2293C3187D33FF9D0246406D \ - 114F43EE0176B71C7BC219DD50A3051F888C628D \ - ; do \ - gpg --keyserver ha.pool.sks-keyservers.net --recv-keys "$key"; \ - done - -ENV NODE_VERSION 0.12.7 -ENV NPM_VERSION 2.14.1 -ENV BOWER_VERSION 1.5.2 -ENV GRUNT_VERSION 0.1.13 - -RUN curl -SLO "https://nodejs.org/dist/v$NODE_VERSION/node-v$NODE_VERSION-linux-x64.tar.gz" \ - && curl -SLO "https://nodejs.org/dist/v$NODE_VERSION/SHASUMS256.txt.asc" \ - && gpg --verify SHASUMS256.txt.asc \ - && grep " node-v$NODE_VERSION-linux-x64.tar.gz\$" SHASUMS256.txt.asc | sha256sum -c - \ - && tar -xzf "node-v$NODE_VERSION-linux-x64.tar.gz" -C /usr/local --strip-components=1 \ - && rm "node-v$NODE_VERSION-linux-x64.tar.gz" SHASUMS256.txt.asc \ - && npm install -g npm@"$NPM_VERSION" \ - && npm install -g bower@"$BOWER_VERSION" grunt-cli@"$GRUNT_VERSION" \ - && echo '{ "allow_root": true }' > /root/.bowerrc \ - && npm cache clear + && for key in \ + 9554F04D7259F04124DE6B476D5A82AC7E37093B \ + 94AE36675C464D64BAFA68DD7434390BDBE9B9C5 \ + FD3A5288F042B6850C66B31F09FE44734EB7990E \ + 71DCFD284A79C3B38668286BC97EC7A07EDE3FC1 \ + DD8F2338BAE7501E3DD5AC78C273792F7D83545D \ + B9AE9905FFD7803F25714661B63B535A4C206CA9 \ + C4F0DFFF4E8C1A8236409D08E73BC641CC11F4C8 \ + 56730D5401028683275BD23C23EFEFE93C4CFFFE \ + ; do \ + gpg --keyserver ha.pool.sks-keyservers.net 
--recv-keys "$key" || \ + gpg --keyserver pgp.mit.edu --recv-keys "$key" || \ + gpg --keyserver keyserver.pgp.com --recv-keys "$key" ; \ + done + +ENV NPM_CONFIG_LOGLEVEL info +ENV NODE_VERSION 7.10.0 + +RUN curl -SLO "https://nodejs.org/dist/v$NODE_VERSION/node-v$NODE_VERSION-linux-x64.tar.xz" \ + && curl -SLO "https://nodejs.org/dist/v$NODE_VERSION/SHASUMS256.txt.asc" \ + && gpg --batch --decrypt --output SHASUMS256.txt SHASUMS256.txt.asc \ + && grep " node-v$NODE_VERSION-linux-x64.tar.xz\$" SHASUMS256.txt | sha256sum -c - \ + && tar -xJf "node-v$NODE_VERSION-linux-x64.tar.xz" -C /usr/local --strip-components=1 \ + && rm "node-v$NODE_VERSION-linux-x64.tar.xz" SHASUMS256.txt.asc SHASUMS256.txt \ + && ln -s /usr/local/bin/node /usr/local/bin/nodejs +################################### +# END NODEJS DOCKERFILE COPYPASTA # +################################### + +# required for python-prctl +RUN apt-get update && apt-get install -y libcap-dev && rm -rf /var/lib/apt/lists/* ADD . /app/ -RUN python3 setup.py develop +RUN pip3 install virtualenv \ + && virtualenv --always-copy /app/env \ + && /app/env/bin/pip3 install paver +RUN /app/env/bin/paver install_deps -CMD python3 /app/sideboard/run_server.py +CMD /app/env/bin/python3 /app/sideboard/run_server.py EXPOSE 8282 diff --git a/README.md b/README.md index 1b71a08..bffd097 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Sideboard [![Build Status](https://travis-ci.org/appliedsec/sideboard.svg)](https://travis-ci.org/appliedsec/sideboard)[![Coverage Status](https://coveralls.io/repos/appliedsec/sideboard/badge.png)](https://coveralls.io/r/appliedsec/sideboard) +Sideboard [![Build Status](https://travis-ci.org/magfest/sideboard.svg)](https://travis-ci.org/magfest/sideboard) [![Coverage Status](https://coveralls.io/repos/github/magfest/sideboard/badge.svg?branch=master)](https://coveralls.io/github/magfest/sideboard?branch=master) ========= Getting Started diff --git a/data/paver/skeleton/__init__.py 
b/data/paver/skeleton/__init__.py index 4f1a14c..a91f5a9 100644 --- a/data/paver/skeleton/__init__.py +++ b/data/paver/skeleton/__init__.py @@ -1,4 +1,5 @@ import os +import re from datetime import datetime import six @@ -9,6 +10,7 @@ __here__ = os.path.dirname(os.path.abspath(__file__)) + def render(to_render, settings): if isinstance(to_render, six.string_types): return env.from_string(to_render).render(settings) @@ -16,18 +18,18 @@ def render(to_render, settings): with open(os.path.join(__here__, *to_render)) as template_file: return env.from_string(template_file.read()).render(settings) + def create_plugin(plugins_dir, plugin, **settings): assert ' ' not in plugin, "plugins probably shouldn't have spaces; but either way we aren't specifically handling spaces" module = plugin.replace('-', '_') plugin = plugin.replace('_', '-') settings.update({'plugin': plugin, 'module': module, 'generated_date': datetime.utcnow()}) - + package_dir = os.path.join(plugins_dir, plugin) assert not os.path.exists(package_dir), '{} plugin already exists at {}'.format(plugin, package_dir) os.makedirs(os.path.join(package_dir, module, 'tests')) for fname, template in TEMPLATES.items(): fname = render(fname, settings) - if fname: fpath = os.path.join(package_dir, fname) try: @@ -36,17 +38,19 @@ def create_plugin(plugins_dir, plugin, **settings): pass with open(fpath, 'w') as f: - f.write(render(template, settings) + '\n') + # our templates often have a lot of {% if %} clauses which lead to a lot of blank lines, + # so we collapse those such that we never have more than 1 blank line in a row + f.write(re.sub(r'\n{3,}', '\n\n', render(template, settings).strip() + '\n')) - if settings.get('generate_sphinx', True): + if settings.get('sphinx', True): sphinx_settings = dict( path=os.path.join(package_dir, 'docs'), sep=False, dot='_', project=plugin, author='{} Team'.format(plugin), - release='0.1', - version='0.1', + release='0.1.0', + version='0.1.0', suffix='.rst', master='index', 
epub=False, @@ -62,10 +66,8 @@ def create_plugin(plugins_dir, plugin, **settings): makefile=True, batchfile=False ) - sphinx.quickstart.generate(sphinx_settings) - TEMPLATES = { '{{ module }}/_version.py': ('templates', '_version.py.template'), 'requirements.txt': ('templates', 'requirements.txt.template'), diff --git a/data/paver/skeleton/templates/__init__.py.template b/data/paver/skeleton/templates/__init__.py.template index 73728fc..ec09510 100644 --- a/data/paver/skeleton/templates/__init__.py.template +++ b/data/paver/skeleton/templates/__init__.py.template @@ -1,8 +1,16 @@ -from __future__ import unicode_literals - -{% if webapp %}import cherrypy +from __future__ import unicode_literals{% if django %} +import os +import sys +{% endif %} +{% if webapp or django %} +import cherrypy +{% if django %} +import django +from django.core.handlers.wsgi import WSGIHandler +{% endif %} {% endif %} + from sideboard.lib import log, parse_config{% if webapp %}, render_with_templates{% endif %}{% if service or sqlalchemy %}, services{% endif %} from {{ module }}._version import __version__ @@ -13,6 +21,7 @@ from {{ module }} import service services.register(service, '{{ module }}') {% endif %} + {% if sqlalchemy %} from {{ module }} import sa services.register(sa.Session.crud, '{{ module }}_crud') @@ -30,3 +39,23 @@ class Root(object): cherrypy.tree.mount(Root(), '/{{ module }}') {% endif %} + +{% if django %} +# add our Django site to our Python path so we can import it +sys.path.append(os.path.join(config['root'], '{{ django }}')) + +# since we're not using mod_wsgi we'll use the env var approach to setting up Django +os.environ['DJANGO_SETTINGS_MODULE'] = '{{ django }}.settings' +django.setup() +cherrypy.tree.graft(WSGIHandler(), '/{{ django }}') + +# expose the static files used by the Django admin interface +# NOTE: if you have Apache serving these files directly then you can remove this part +class Static(object): + admin = cherrypy.tools.staticdir.handler( + 
section="/admin", + dir=os.path.dirname(django.__file__) + '/contrib/admin/static/admin' + ) + +cherrypy.tree.mount(Static(), '/{{ django }}/static') +{% endif %} diff --git a/data/paver/skeleton/templates/_version.py.template b/data/paver/skeleton/templates/_version.py.template index 5222413..748c016 100644 --- a/data/paver/skeleton/templates/_version.py.template +++ b/data/paver/skeleton/templates/_version.py.template @@ -1,2 +1,3 @@ from __future__ import unicode_literals -__version__ = '0.1' + +__version__ = '0.1.0' diff --git a/data/paver/skeleton/templates/fabfile.py.template b/data/paver/skeleton/templates/fabfile.py.template index e34e3f5..aaa17f2 100644 --- a/data/paver/skeleton/templates/fabfile.py.template +++ b/data/paver/skeleton/templates/fabfile.py.template @@ -1,4 +1,5 @@ from __future__ import unicode_literals +import time from os.path import abspath, basename, dirname, join from sh import pip, fpm, chmod @@ -12,24 +13,32 @@ POSTINSTALL = '/tmp/postinstall.sh' def _make_postinstall_script(): with open(POSTINSTALL, 'w') as f: f.write('#!/bin/bash\n') - f.write('/opt/sideboard/bin/pip install --use-wheel --find-links /opt/sideboard/plugins/{package_name}/wheelhouse/ -r /opt/sideboard/plugins/{package_name}/requirements.txt' + f.write('set -e\n') + f.write('source /opt/sideboard/bin/activate && /opt/sideboard/bin/pip install --use-wheel --find-links /opt/sideboard/plugins/{package_name}/wheelhouse/ -r /opt/sideboard/plugins/{package_name}/requirements.txt\n' .format(package_name=package_name)) + f.write('chown -R sideboard.sideboard /opt/sideboard\n') chmod('755', POSTINSTALL) -def package(package_type): +def package(package_type, iteration='testing'): import sideboard plugin = __import__(plugin_name) pip('wheel', r='requirements.txt') _make_postinstall_script() + if iteration == 'testing': + iteration = '0.{}'.format(int(time.time())) fpm('-t', package_type, '-s', 'dir', '--{}-user'.format(package_type), 'sideboard', 
'--{}-group'.format(package_type), 'sideboard', '--name', 'sideboard-{}'.format(package_name), '--version', plugin.__version__, + '--license', 'COMPANY-PROPRIETARY', + '--iteration', iteration, '--depends', 'sideboard >= {}'.format(sideboard.__version__), '--after-install', POSTINSTALL, + '--config-files', '/etc/sideboard/plugins.d/{package_name}.cfg'.format(package_name=package_name), './package-support/{package_name}.cfg=/etc/sideboard/plugins.d/{package_name}.cfg'.format(package_name=package_name), './requirements.txt=/opt/sideboard/plugins/{}/requirements.txt'.format(package_name), - './wheelhouse=/opt/sideboard/plugins/{}'.format(package_name), + './wheelhouse=/opt/sideboard/plugins/{}'.format(package_name),{% if django %} + './{{ django }}=/opt/sideboard/plugins/{}/{{ django }}'.format(package_name),{% endif %} './{}=/opt/sideboard/plugins/{}'.format(plugin_name, package_name)) diff --git a/data/paver/skeleton/templates/index.html.template b/data/paver/skeleton/templates/index.html.template index 82555a2..0bbc0f1 100644 --- a/data/paver/skeleton/templates/index.html.template +++ b/data/paver/skeleton/templates/index.html.template @@ -1,9 +1,9 @@ - +{% raw %} - $(( plugin ))$ skeleton page + {{ plugin }} skeleton page - ((% if header %)) -

Hello $(( plugin ))$ developer!

- ((% endif %)) + {% if header %} +

Hello {{ plugin }} developer!

+ {% endif %} - +{% endraw%} diff --git a/development-defaults.ini b/development-defaults.ini index 2353ddf..6ce0355 100644 --- a/development-defaults.ini +++ b/development-defaults.ini @@ -1,22 +1,11 @@ debug = True ws.auth_required = False -# set priority plugin so the main uber plugin always loads first -priority_plugins = uber, [cherrypy] server.socket_host = "0.0.0.0" engine.autoreload.on = True -tools.sessions.timeout = 4320 # 30 days (in minutes) + [loggers] root = "DEBUG" - -[handlers] -[[stdout]] -class = "logging.StreamHandler" -stream = "ext://sys.stderr" - -#[[syslog]] -#class = "logging.handlers.SysLogHandler" -#address = "/dev/log" diff --git a/docs/source/index.rst b/docs/source/index.rst index b257063..7e49746 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -56,7 +56,7 @@ Let's start by cloning the Sideboard repo and running it without any plugins: .. code-block:: none - $ git clone https://github.com/appliedsec/sideboard.git + $ git clone https://github.com/magfest/sideboard $ cd sideboard/ $ paver make_venv $ ./env/bin/python sideboard/run_server.py @@ -305,18 +305,14 @@ So this sets us up to be able to change our index.html to be a template that use Ragnarok Aggregation -

$(( apocalypse ))$

- ((% for website, status in all_checks.items() %)) -

$(( website ))$ - $(( status.result ))$

- ((% endfor %)) +

{{ apocalypse }}

+ {% for website, status in all_checks.items() %} +

{{ website }} - {{ status.result }}

+ {% endfor %} -So now we can go back to ``_ and see a summary of our end-of-the-world checks. A few things to note about our latest code: - -* We return a dictionary from our page handler; since the page handler is called ``index``, the dictionary it returns is used to render the ``index.html`` `jinja template `_ in our configured templates directory - -* Notice that the jinja template tokens are not the default; they have been swapped out so that they do not conflict with `angular `_ which is our Javascript framework of choice (Sideboard doesn't require you to use Angular but we highly recommend it!) +So now we can go back to ``_ and see a summary of our end-of-the-world checks. One thing to note about this page handler is that it returns a dictionary. Since the page handler is called ``index``, the dictionary it returns is used to render the ``index.html`` `jinja template `_ in our configured templates directory. So let's make this extra-dynamic; we'll use websockets to subscribe to our service so that anytime our data changes, we'll automatically get an update. We're using `Angular `_ because Sideboard comes with some WebSocket helpers which are written with Angular. @@ -361,14 +357,18 @@ So let's make this extra-dynamic; we'll use websockets to subscribe to our servi Note that when you press the "Refresh" button the data gets automatically updated even though all we did was make call to the server without doing anything with the response. 
That happened because of the following sequence of steps: + * we subscribe to the ``ragnarok.all_checks`` method when the page loads, so our callback will be called anytime we get a message from the server with new data + * when the refresh button is pressed, it calls the ``ragnarok.check_for_apocalypse`` method which updates the database + * because of how we used the ``@subscribes`` and ``@notifies`` decorators on these methods, calling ``check_for_apocalypse`` automatically causes the latest data to be pushed to the client which is subscribed to ``all_checks`` + * our callback is fired again, which updates the data on the scope and the latest data is rendered to the page Even without pressing the refresh button, the data on this page would still update every 24 hours since we defined that ``DaemonTask`` which calls ``check_for_apocalypse`` once per day. -Since this is our only plugin, we'd probably like this webpage to be the defauly page for this Sideboard site, so let's open our plugin's ``sideboard/configspec.ini`` and add the following line: +Since this is our only plugin, we'd probably like this webpage to be the default page for this Sideboard site, so let's open our plugin's ``sideboard/configspec.ini`` and add the following line: .. code-block:: none @@ -378,6 +378,55 @@ So now if we re-start our web server by re-running ``./env/bin/python sideboard/ +Using Django With Sideboard +=========================== + +CherryPy is a WSGI container, which means that anything which runs in Apache with ``mod_wsgi`` can run in CherryPy. In this section we'll focus on creating a Django project inside a Sideboard plugin. We're specifically documenting how to use Django because it's the most popular Python web framework, but other WSGI-compatible frameworks such as Flask can be used in the same way. 
+ +Let's create a new Sideboard plugin, this time without any of the usual pieces and tell it that we'll be including a Django site called ``mysite`` (we'll be following the Django tutorial, which uses that name). + +.. code-block:: none + + ./env/bin/paver create_plugin --no_webapp --no_sqlalchemy --no_service --no_sphinx --django=mysite --name=unchained + +After doing this, we now have the following directory structure created in the ``plugins`` directory: + +.. code-block:: none + + unchained/ + |-- development-defaults.ini + |-- fabfile.py + |-- MANIFEST.in + |-- package-support + | `-- unchained.cfg + |-- requirements.txt + |-- setup.cfg + |-- setup.py + `-- unchained + |-- configspec.ini + |-- __init__.py + |-- tests + | `-- __init__.py + `-- _version.py + +Note that this did **not** automatically create the Django project. The plugin that was created expects that Django project to exist, and it won't work until we create that project manually. First, we'll need to add Django as a dependency by opening up ``plugins/unchained/requirements.txt`` and adding something like ``Django==1.9.2`` or whatever version of Django you'd like to use. Then you can run ``python setup.py develop`` in your plugin's directory (or run ``paver install_deps`` from the main Sideboard directory). + +After that, you can follow the `Django tutorial `_ to create a site. As explained in the tutorial, in your top-level ``unchained`` directory you can run ``django-admin startproject mysite`` to creates the Django project alongside your plugin module. The one thing you'll need to do differently from what the tutorial says is that you'll need to set + +.. code-block:: python + + STATIC_URL = '/unchained/static/' + +in ``plugins/unchained/mysite/mysite/settings.py`` because we're mounting our Django app at the ``/unchained`` mount point in CherryPy. + +This approach maintains a Sideboard plugin whose module lives alongside a standalone Django project. 
We do this in order to more easily run ``manage.py`` commands, which shouldn't generally need to know or care about Sideboard. This also means that you can potentially write a Django app that will run in any mod_wsgi container, then have the Sideboard plugin call into it when you need to do Sideboard-specific things such as exposing ``services`` API calls. + +From here you can run through the Django tutorial. You'll be able to visit ``_ to see the Django admin interface, and once you write the "polls" app you'll be able to visit ``_ to access its views. (You don't currently get links to these in the ``/list_plugins`` page of Sideboard.) + +Your Django project will be included in your RPM as packaged by your ``fabfile``. + + + Writing Unit Tests ------------------ @@ -455,7 +504,7 @@ Services .. method:: register(module[, namespace]) - exposes everything in a module (or any object with callable functions) + Exposes all methods whose names do not begin with underscores in a module (or any object with callable functions). If the module defines ``__all__``, only methods included in ``__all__`` will be exposed. :param module: the module you are exposing; any function in this module which is not prefixed with an underscore will be callable :param namespace: the prefix which consumers calling your method over RPC will use; if omitted this defaults to your module name @@ -652,13 +701,16 @@ When a function which has been declared to update those channels is called, the return "ok" -.. function:: notify(channels) +.. function:: notify(channels, delay=False) - Explicitly cause all client listening on those files + Explicitly cause all client listening on the given channel(s) to check for new data. Unless you have some specific reason to use this function, you should probably just decorate the appropriate function with the ``@notifies`` decorator. + By default this triggers Sideboard's broadcaster thread to check for new data on the given channel(s) immediately. 
However, plugins might find themselves in a situation where they want to trigger a notification after the current RPC method has finished running. For example, notify() might be called while in the middle of a database transaction and we want to make sure that the transaction has been committed before the ``notify`` occurs. To support this case, we can set the ``delay`` parameter to ``True``, in which case Sideboard will wait until we finish executing the current RPC method to signal the broadcaster thread. (If called with a ``delay`` outside of an RPC request, no notify will occur.) + :param channels: a string or list of strings, which are the names of the channels to notify + :param delay: boolean indicating whether to fire immediately or at the end of the execution of the current RPC method So in the above examples listed with the ``@subscribes`` and ``@notifies`` decorators, we might see the following sequence of requests and responses: @@ -730,11 +782,13 @@ Sideboard provides several useful classes for establishing websocket connections * *callback* (required): a function taking a single argument; this is called every time we receive data back from the server - * *errback* (required): a function taking a single argument; this is called every time we recieve an error response from the server and passed the error message + * *errback* (optional): a function taking a single argument; this is called every time we recieve an error response from the server and passed the error message (if omitted we log an error message and continue) + + * *paramback* (optional): a function taking no arguments; if present, this is called to generate the params for this subscription, both initially and every time the websocket reconnects * *client* (optional): the client id to use; this will be automatically generated, and you should omit this unless you really know what you're doing - :param method: the name of the method you want to subscribe to; you may pass either positional or 
keyword arguments (but not both) to this method which will be sent as part of the RPC message + :param method: the name of the method you want to subscribe to; you may pass either positional or keyword arguments (but not both) to this method which will be sent as part of the RPC message. If "paramback" is passed (see above) then args/kwargs will be ignored entirely and that will be used to generate the parameters. :returns: the automatically generated client id; you can pass this to ``unsubscribe`` if you want to keep this connection open but cancel this one subscription .. method:: unsubscribe(client1[, client2[, ...]]) @@ -919,7 +973,7 @@ When we refer to "Sideboard starting" and "Sideboard stopping" we are referring .. function:: on_startup(func[, priority=50]) - Cause a function to get called on startup. You can use this as a decorator, or directly call it and pass a priority level which indicates the order in which the startup functions should be called. + Cause a function to get called on startup. You can use this as a decorator or directly call it. The priority level indicates the order in which the startup functions should be called, where low numbers come before high numbers. .. code-block:: python @@ -927,10 +981,14 @@ When we refer to "Sideboard starting" and "Sideboard stopping" we are referring def f(): print("Hello World!") + @on_startup(priority=40) def g(): print("Hello Kitty!") - on_startup(g, priority=40) + def h(): + print("Goodbye World!") + + on_startup(h, priority=60) :param func: The function to be called on startup; this function must be callable with no arguments. :param priority: Order in which startup functions will be called (lower priority items are called first); this can be any integer and the numbers are used for nothing other than ordering the handlers. @@ -976,7 +1034,7 @@ When we refer to "Sideboard starting" and "Sideboard stopping" we are referring .. 
class:: TimeDelayQueue() - Subclass of `Queue.Queue `_ which adds an optional ``delay`` parameter to the ``put`` method which does not add the item to the queue until after the specified amount of time. This is used internally but is included in our public API in case it's useful to anyone else. + Subclass of `Queue.Queue `_ which adds an optional ``delay`` parameter to the ``put`` method which does not add the item to the queue until after the specified amount of time. This is included in our public API in case it's useful, though Sideboard itself no longer makes use of it as part of its internal implementation. .. method:: put(item[, block=True[, timeout=None[, delay=0]]]): @@ -988,8 +1046,7 @@ When we refer to "Sideboard starting" and "Sideboard stopping" we are referring Utility class allowing code to call the provided function in a separate pool of threads. For example, if you need to call a long-running function in the handler for an HTTP request, you might want to just kick off the method in a background thread so that you can return from the page handler immediately. >>> caller = Caller(long_running_func) - >>> caller.defer('arg1', arg2=True) # called immediately (in another thread) - >>> caller.delay(5, 'argOne', arg2=False) # called after a 5 second delay (in another thread) + >>> caller.defer('arg1', arg2=True) # triggers an immediate call in another thread :param func: the function to be executed in the background; this must be callable with no arguments :param threads: the number of threads which will call this function; sometimes you may want a pool of threads all calling the same function @@ -999,20 +1056,15 @@ When we refer to "Sideboard starting" and "Sideboard stopping" we are referring .. method:: defer(*args, **kwargs) Pass a set of arguments and keyword arguments which will be used to call this instance's function in a background thread. - - .. 
method:: delay(seconds, *args, **kwargs) - - Call this instance's function in a background thread after the specified delay with the passed position and keyword arguments. .. class:: GenericCaller([threads=1]) Like the ``Caller`` class above, except that instead of calling the same method with provided arguments, this lets you spin up a pool of background threads which will call any methods you specify, e.g. - >>> from __future__ import print_function + >>> from __future__ import print_function # unnecessary in Python 3 >>> gc = GenericCaller() - >>> gc.defer(print, 'Hello', 'World', sep=', ', end='!') # prints "Hello, World!" - >>> gc.delay(5, print, 'Hello', 'World', sep=', ', end='!') # prints "Hello, World!" after 5 seconds + >>> gc.defer(print, 'Hello', 'World', sep=', ', end='!') # prints "Hello, World!" :param threads: the number of threads which will call this function; sometimes you may want a pool of threads all calling the same function @@ -1073,6 +1125,13 @@ Miscellaneous :param key: name (as a string) of the value to store :param value: the value to store; this can be anything and does not need to be pickleable or otherwise serializable + .. classmethod:: setdefault(key, value) + + Check whether the given key already has a value set; if not then set the provided value. Either way, return whatever the current value now is. + + :param key: name (as a string) of the value to optionally-set-and-definitely-return + :param value: default value to set if no value is already set for this key + .. classmethod:: get_client() Returns the websocket client id used to make this request, or None if not applicable. This value may be present in both websocket and jsonrpc requests; in the latter case it would be present as the ``websocket_client`` key of the request, e.g. a jsonrpc request that looks like this: @@ -1086,6 +1145,11 @@ Miscellaneous "params": ["20191"], "websocket_client": "client-623" } + + .. 
attribute:: client_data + + Class property which returns a dictionary. For websocket subscriptions, this dictionary is persisted and then restored before the function is re-called, so that methods can store data on a per-subscription basis. This is like a server "session" except that it's per-subscription instead of per-user. + sideboard.lib.sa diff --git a/fabfile.py b/fabfile.py index 336cc5f..04082a8 100644 --- a/fabfile.py +++ b/fabfile.py @@ -1 +1,32 @@ -from ship_it import fpm +from __future__ import unicode_literals +import time +import os.path + +import yaml +import ship_it + + +__here__ = os.path.abspath(os.path.dirname(__file__)) +MANIFEST_YAML = os.path.join(__here__, 'manifest.yaml') +MANIFEST_TEMPLATE = MANIFEST_YAML + '.template' + + +def _populate_manifest_and_invoke_fpm(iteration): + import sideboard + with open(MANIFEST_TEMPLATE) as f: + manifest = yaml.load(f) + manifest[b'version'] = sideboard.__version__ + manifest[b'iteration'] = iteration + + with open(MANIFEST_YAML, 'w') as f: + yaml.dump(manifest, f) + + ship_it.fpm(MANIFEST_YAML) + + +def fpm_stable(iteration): + _populate_manifest_and_invoke_fpm(iteration) + + +def fpm_testing(): + _populate_manifest_and_invoke_fpm(b'0.{}'.format(int(time.time()))) diff --git a/manifest.yaml b/manifest.yaml deleted file mode 100644 index 74b50fc..0000000 --- a/manifest.yaml +++ /dev/null @@ -1,13 +0,0 @@ -description: Sideboard Framework -name: sideboard -version: 0.1.0 -before_install: package-support/preinstall.sh -config_files: - /etc/sideboard/sideboard-server.cfg: package-support/sideboard-server.cfg - /etc/init.d/sideboard-server: package-support/sideboard-server -build_dependencies: - - rpm-build - - prelink - - ruby-devel -run_dependencies: - - python diff --git a/manifest.yaml.template b/manifest.yaml.template new file mode 100644 index 0000000..1a735a2 --- /dev/null +++ b/manifest.yaml.template @@ -0,0 +1,10 @@ +description: Sideboard Framework +name: sideboard +before_install: 
package-support/preinstall.sh +after_install: package-support/postinstall.sh +config_files: + /etc/sideboard/sideboard-server.cfg: package-support/sideboard-server.cfg + /etc/init.d/sideboard: package-support/init.d/sideboard + /etc/sysconfig/sideboard: package-support/sysconfig/sideboard +depends: + - python27 diff --git a/package-support/init.d/sideboard b/package-support/init.d/sideboard new file mode 100755 index 0000000..03383bd --- /dev/null +++ b/package-support/init.d/sideboard @@ -0,0 +1,181 @@ +#!/bin/bash + +# chkconfig: - 88 14 +# description: Sideboard exit node manager +# processname: sideboard +# +### BEGIN INIT INFO +# Provides: sideboard +# Required-Start: $local_fs $remote_fs $network $named +# Required-Stop: $local_fs $remote_fs $network +# Short-Description: start and stop sideboard +# Description: Sideboard exit node manager +### END INIT INFO + +RETVAL=0 +prog="sideboard" + +DESC=sideboard +VENV=/opt/sideboard +PYTHON=$VENV/bin/python +SEP=$VENV/bin/sep +CHERRYD=$VENV/bin/cherryd +USER=sideboard +PID_FILE=/var/run/sideboard/sideboard.pid +COUNTDOWN=10 + +SIDEBOARDMODE=server +if [ -f /etc/sysconfig/sideboard ]; then + # import $SIDEBOARDMODE from defaults file +. 
/etc/sysconfig/sideboard +fi + +OPTIONS="mainloop_daemon --pidfile=$PID_FILE" +if [ "$SIDEBOARDMODE" == "server" ]; then + OPTIONS="-d --pidfile=$PID_FILE --import=sideboard.server" +fi + +procgrep() { + grep "python.*${DESC}" | grep -v grep +} + +filepid() { + cat $PID_FILE 2>/dev/null +} + +procpid() { + ps aux | procgrep | awk '{print $2}' +} + +allpids () { + FP=$(filepid) + PP=$(procpid) + if [ "$PP" == "$FP" ]; then + echo $PP + else + FP_IS_CORRECT_PROC=$(cat /proc/$FP/cmdline 2>/dev/null | procgrep) + if [ -n "$FP_IS_CORRECT_PROC" ]; then + echo $PP $FP + else + echo $PP + fi + fi +} + +isrunning() { + if [ -n "$(filepid)" ]; then + RETVAL=0 + elif [ -n "$(procpid)" ]; then + RETVAL=0 + else + RETVAL=1 + fi + return $RETVAL +} + +start() { + if isrunning; then + echo "Starting $prog: $prog is already running [FAIL]" + RETVAL=1 + else + echo -n $"Starting $prog: " + if [ "$SIDEBOARDMODE" == "server" ]; then + sudo -u $USER $CHERRYD $OPTIONS + else + sudo -u $USER $SEP $OPTIONS + fi + RETVAL=$? + if [ $RETVAL -eq 0 ]; then + echo '[OK]' + else + echo '[FAIL]' + fi + fi + return $RETVAL +} + +exitcountdown() { + while [ "$COUNTDOWN" -gt 0 ]; do + if [ -z "$(procpid)" ]; then + break + fi + echo -n . 
+ sleep 1 + COUNTDOWN=`expr $COUNTDOWN - 1` + done +} + +stop() { + if isrunning; then + echo -n "Shutting down $prog" + kill $(allpids) 2>/dev/null + exitcountdown + if [ -n "$(procpid)" ]; then + echo -n " $prog failed to exit cleanly, terminating" + kill -9 $(allpids) 2>/dev/null + COUNTDOWN=3 + exitcountdown + fi + rm -f $PID_FILE + if isrunning; then + echo ' [FAIL]' + else + echo ' [OK]' + fi + RETVAL=0 + else + echo "$prog is not running" + RETVAL=1 + fi + return $RETVAL +} + +restart() { + stop + start +} + +condrestart() { + if isrunning; then + restart + fi +} + +status() { + if isrunning; then + FP=$(filepid) + PP=$(procpid) + if [ "$FP" == "$PP" ]; then + RETVAL=0 + echo "$prog is running with pid $PP" + else + RETVAL=1 + echo "$prog has a pidfile with pid ${FP:-} but a running process with pid ${PP:-}" + fi + else + RETVAL=3 + echo "$prog is not running" + fi + return $RETVAL +} + +case "$1" in + start) + start + ;; + stop) + stop + ;; + status) + status + ;; + restart) + restart + ;; + condrestart|try-restart) + condrestart + ;; + *) + echo $"Usage: $0 {start|stop|status|restart|condrestart}" + RETVAL=1 +esac diff --git a/package-support/postinstall.sh b/package-support/postinstall.sh new file mode 100644 index 0000000..dad1b91 --- /dev/null +++ b/package-support/postinstall.sh @@ -0,0 +1,16 @@ + +for dname in /var/run/sideboard /var/tmp/sideboard /var/tmp/sideboard/sessions /opt/sideboard /opt/sideboard/db /opt/sideboard/plugins; do + mkdir -p $dname + chmod 750 $dname + chown sideboard.sideboard $dname +done + +# unlike all of the other directories in the above loop, we want this directory (and also its contents) to be sideboard.root +chown -R root.sideboard /etc/sideboard + +chown root.root /etc/init.d/sideboard +chown root.root /etc/sysconfig/sideboard + +# TODO: instead of doing this in postinstall, we should eventually do ---use-file-permissions +chmod 700 /etc/init.d/sideboard +chmod 600 /etc/sysconfig/sideboard diff --git 
a/package-support/preinstall.sh b/package-support/preinstall.sh index 0113693..36d0831 100644 --- a/package-support/preinstall.sh +++ b/package-support/preinstall.sh @@ -1,8 +1,5 @@ + if ! id -u sideboard &>/dev/null; then - adduser sideboard + groupadd --force -r sideboard -g 600 + useradd -r --shell /sbin/nologin -uid 600 --gid sideboard sideboard fi - -for dname in /var/run/sideboard /var/tmp/sideboard/sessions /opt/sideboard/db /opt/sideboard/plugins; do - mkdir -p $dname - chown sideboard.sideboard $dname -done diff --git a/package-support/sideboard-server b/package-support/sideboard-server deleted file mode 100755 index 4549307..0000000 --- a/package-support/sideboard-server +++ /dev/null @@ -1,110 +0,0 @@ -#!/bin/bash - -RETVAL=0 -prog="sideboard" - -DESC=sideboard -VENV=/opt/sideboard -PYTHON=$VENV/bin/python -CHERRYD=$VENV/bin/cherryd -USER=sideboard -PID_FILE=/var/run/sideboard/$DESC.pid -LOCK_FILE=/var/run/sideboard/$DESC.lock - -OPTIONS="-d --pidfile=$PID_FILE --import=sideboard.server" - -isrunning() { - PID=`cat $PID_FILE 2>/dev/null` - if [ -n "$PID" ]; then - EXISTS=`ps aux | grep $PID | grep -v grep` - if [ -n "$EXISTS" ]; then - RETVAL=0 - else - RETVAL=1 - fi - else - RETVAL=1 - fi - return $RETVAL -} - -start() { - if isrunning; then - echo "Starting $prog: $prog is already running [FAIL]" - RETVAL=1 - else - echo -n $"Starting $prog: " - sudo -u $USER $CHERRYD $OPTIONS - RETVAL=$? 
- if [ $RETVAL -eq 0 ]; then - touch $LOCK_FILE - echo '[OK]' - else - echo '[FAIL]' - fi - fi - return $RETVAL -} - -stop() { - PID=`cat $PID_FILE 2>/dev/null` - if [ -n "$PID" ]; then - echo -n $"Shutting down $prog: " - kill $PID - sleep 5 - if isrunning; then - echo '[FAIL]' - RETVAL=1 - else - rm -f $LOCK_FILE - echo '[OK]' - RETVAL=0 - fi - else - echo "$prog is not running" - RETVAL=1 - fi - return $RETVAL -} - -restart() { - stop - start -} - -condrestart() { - [-e $LOCK_FILE] && restart || : -} - -status() { - PID=`cat $PID_FILE 2>/dev/null` - if [ -n "$PID" ]; then - echo "$prog is running with pid $PID" - RETVAL=0 - else - echo "$prog is not running" - RETVAL=1 - fi - return $RETVAL -} - -case "$1" in - start) - start - ;; - stop) - stop - ;; - status) - status - ;; - restart) - restart - ;; - condrestart|try-restart) - condrestart - ;; - *) - echo $"Usage: $0 {start|stop|status|restart|condrestart}" - RETVAL=1 -esac diff --git a/package-support/sideboard-server.cfg b/package-support/sideboard-server.cfg index caedb66..a1cf43f 100644 --- a/package-support/sideboard-server.cfg +++ b/package-support/sideboard-server.cfg @@ -1,7 +1,7 @@ plugins_dir = "/opt/sideboard/plugins" [cherrypy] -server.socket_host = "0.0.0.0" +server.socket_host = "127.0.0.1" tools.sessions.storage_path = "/var/tmp/sideboard/sessions" [loggers] @@ -11,3 +11,12 @@ root = "INFO" [[syslog]] class = "logging.handlers.SysLogHandler" address = "/dev/log" +formatter = syslog + +[formatters] +[[syslog]] +format = "$$(levelname)-5.5s $$(threadName)s [$$(name)s] $$(message)s" + +[[default]] +format = "$$(asctime)s,$$(msecs)03d $$(levelname)-5.5s $$(threadName)s [$$(name)s] $$(message)s" +datefmt = "$$m-$$d $$H:$$M:$$S" diff --git a/package-support/sysconfig/sideboard b/package-support/sysconfig/sideboard new file mode 100644 index 0000000..6b81aa9 --- /dev/null +++ b/package-support/sysconfig/sideboard @@ -0,0 +1,6 @@ +# defaults for sideboard package +# SIDEBOARDMODE=server 
+SIDEBOARDMODE=daemon + +# DESC=sideboard +DESC=sideboard-daemon diff --git a/pavement.py b/pavement.py index 57bf87c..b3006b6 100644 --- a/pavement.py +++ b/pavement.py @@ -65,6 +65,7 @@ def collect_plugin_dirs(module=False): else: yield potential_folder + @task def make_venv(): """ @@ -73,23 +74,32 @@ def make_venv(): bootstrap_venv(__here__ / path('env'), 'sideboard') develop_sideboard() + def install_pip_requirements_in_dir(dir_of_requirements_txt): path_to_pip = __here__ / path('env/bin/pip') + + print("---- installing dependencies in {} ----" + .format(dir_of_requirements_txt)) + sh('{pip} install -e {dir_of_requirements_txt}' .format( pip=path_to_pip, dir_of_requirements_txt=dir_of_requirements_txt)) + def run_setup_py(path): + venv_python = str(__here__ / 'env' / 'bin' / 'python') sh('cd {path} && {python_path} {setup_path} develop' .format( path=path, - python_path=sys.executable, + python_path=venv_python if exists(venv_python) else sys.executable, setup_path=join(path, 'setup.py'))) + def develop_sideboard(): run_setup_py(__here__) + @task def pull_plugins(): """ @@ -99,6 +109,7 @@ def pull_plugins(): for plugin_dir in collect_plugin_dirs(): sh('cd "{}";git pull'.format(plugin_dir)) + @task def assert_all_files_import_unicode_literals(): """ @@ -120,6 +131,7 @@ def assert_all_files_import_unicode_literals(): raise BuildFailure("there were files that didn't include " '"from __future__ import unicode_literals"') + @task def assert_all_projects_correctly_define_a_version(): """ @@ -144,6 +156,7 @@ def assert_all_projects_correctly_define_a_version(): raise BuildFailure("there were projects that didn't include correctly specify __version__") + @task @needs(['assert_all_files_import_unicode_literals', 'assert_all_projects_correctly_define_a_version']) @@ -152,6 +165,7 @@ def run_all_assertions(): run all the assertion tasks that sideboard supports """ + @task @cmdopts([ ('name=', 'n', 'name of the plugin to create'), @@ -159,6 +173,8 @@ def 
run_all_assertions(): ('no_webapp', 'w', 'do not expose webpages in the plugin'), ('no_sqlalchemy', 'a', 'do not use SQLAlchemy in the plugin'), ('no_service', 'r', 'do not expose a service in the plugin'), + ('no_sphinx', 's', 'do not generate Sphinx docs'), + ('django=', 'j', 'create a Django project alongside the plugin with this name'), ('cli', 'c', 'make this a cli application; implies -w/-r') ]) def create_plugin(options): @@ -166,34 +182,37 @@ def create_plugin(options): # this is actually needed thanks to the skeleton using jinja2 (and six, although that's changeable) try: - pkg_resources.get_distribution("sideboard") + pkg_resources.get_distribution("sideboard") except pkg_resources.DistributionNotFound: - raise BuildFailure("This command must be run from within a configured virtual environment.") + raise BuildFailure("This command must be run from within a configured virtual environment.") plugin_name = options.create_plugin.name if getattr(options.create_plugin, 'drop', False) and (PLUGINS_DIR / path(plugin_name.replace('_', '-'))).exists(): # rmtree fails if the dir doesn't exist apparently (PLUGINS_DIR / path(plugin_name.replace('_', '-'))).rmtree() - + kwargs = {} - for opt in ['webapp', 'sqlalchemy', 'service']: + for opt in ['webapp', 'sqlalchemy', 'service', 'sphinx']: kwargs[opt] = not getattr(options.create_plugin, 'no_' + opt, False) kwargs['cli'] = getattr(options.create_plugin, 'cli', False) + kwargs['django'] = getattr(options.create_plugin, 'django', None) if kwargs['cli']: kwargs['webapp'] = False kwargs['service'] = False - + from data.paver import skeleton skeleton.create_plugin(PLUGINS_DIR, plugin_name, **kwargs) print('{} successfully created'.format(options.create_plugin.name)) + @task def install_deps(): install_pip_requirements_in_dir(__here__) for pdir in collect_plugin_dirs(): install_pip_requirements_in_dir(pdir) + @task def clean(): """ diff --git a/requirements.txt b/requirements.txt index ae3f855..55bc090 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,15 @@ -configobj==5.0.5 -cherrypy==3.6.0 -#python-ldap==2.3.13 -ws4py==0.3.2 -SQLAlchemy==1.0.0 -six==1.5.2 -Jinja2==2.6 -rpctools==0.2.1 -logging_unterpolation==0.2.0 -requests==2.2.1 -pytest==2.7.2 -mock==1.0.1 -Sphinx==1.2.1 -coverage==3.6 -paver==1.2.2 -wheel==0.24.0 -pip==1.5.6 -sh==1.09 +configobj>=5.0.5 +cherrypy>=3.6.0 +ws4py>=0.3.2 +SQLAlchemy>=1.1.0 +six>=1.5.2 +Jinja2>=2.7 +rpctools>=0.3.1 +logging_unterpolation>=0.2.0 +requests>=2.2.1 +paver>=1.2.2 +wheel>=0.24.0 +pip>=1.5.6 +sh>=1.09 +python-prctl>=1.6.1; 'linux' in sys_platform +psutil>=5.4.1 diff --git a/run-tests.sh b/run-tests.sh deleted file mode 100644 index 7ae625c..0000000 --- a/run-tests.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -../env/bin/py.test -m 'not nonfunctional' sideboard \ No newline at end of file diff --git a/setup.cfg b/setup.cfg old mode 100755 new mode 100644 index 199f488..335a779 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,23 @@ [easy_install] zip_ok = False -[pytest] +[coverage:run] +# don't collect the plugins subdirectories +omit = + plugins/*/env/*.py + plugins/conftest.py + plugins/*/*/tests/* + tests/plugins/* + sideboard/tests/* + tests/* + +include = + sideboard/* + +[tool:pytest] norecursedirs = tests/plugins build env dist plugins/*/build plugins/*/env plugins/*/dist .tox -python_files = test_*.py tests/__init__.py tests/*/__init__.py \ No newline at end of file +python_files = test_*.py tests/__init__.py tests/*/__init__.py + +[pep8] +max-line-length=999 +ignore=E121,E123,E126,E226,E24,E704,E221,E127,E128,W503,E731,E131,E711,E712,E402 diff --git a/setup.py b/setup.py index 0fd3cf0..2365747 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,5 @@ +import os +import platform import sys import os.path from setuptools import setup, find_packages @@ -9,9 +11,29 @@ with open(os.path.join(__here__, pkg_name, '_version.py')) as version: exec(version.read()) # __version__ is now defined -req_data = 
open(os.path.join(__here__, 'requirements.txt')).read() -requires = [r.strip() for r in req_data.split() if r.strip() != ''] -requires = list(reversed(requires)) +req_data = open(os.path.join(__here__, 'requirements.txt')).readlines() +raw_requires = [r.strip() for r in req_data if r.strip() != ''] + +# Ugly hack to reconcile pip requirements.txt and setup.py install_requires +os_name = os.name +sys_platform = sys.platform +platform_release = platform.release() +implementation_name = sys.implementation.name +platform_machine = platform.machine() +platform_python_implementation = platform.python_implementation() +requires = [] +for s in reversed(raw_requires): + if ';' in s: + req, env_marker = s.split(';') + if eval(env_marker): + requires.append(s) + else: + requires.append(s) + +# testing dependencies +req_data = open(os.path.join(__here__, 'test_requirements.txt')).read() +tests_require = [r.strip() for r in req_data.split() if r.strip() != ''] +tests_require = list(reversed(tests_require)) # temporary workaround for a Python 2 CherryPy bug, for which we opened a pull request: # https://bitbucket.org/cherrypy/cherrypy/pull-request/85/1285-python-2-now-accepts-both-bytestrings/ @@ -19,21 +41,26 @@ requires = ['CherryPy==3.2.2' if 'cherrypy' in r.lower() else r for r in requires] if __name__ == '__main__': + setup_requires = {'setup_requires': ['distribute']} if sys.version_info[0] == 2 else {} setup( name=pkg_name, version=__version__, description='Sideboard plugin container.', license='BSD', scripts=[], - setup_requires=['distribute'], install_requires=requires, + tests_require=tests_require, packages=find_packages(), include_package_data=True, package_data={pkg_name: []}, zip_safe=False, - entry_points = { + entry_points={ 'console_scripts': [ 'sep = sideboard.sep:run_plugin_entry_point' ] - } + }, + extras_require={ + 'perftrace': ['python-prctl>=1.6.1', 'psutil>=4.3.0'] + }, + **setup_requires ) diff --git a/sideboard/__init__.py b/sideboard/__init__.py index 
a4b7f60..825ec9a 100644 --- a/sideboard/__init__.py +++ b/sideboard/__init__.py @@ -1,18 +1,17 @@ from __future__ import unicode_literals +import os import importlib import six import cherrypy from sideboard._version import __version__ -try: - import sideboard.server -except: - from sideboard.lib import log - log.warning('Error importing server', exc_info=True) +import sideboard.server from sideboard.internal.imports import _discover_plugins from sideboard.internal.logging import _configure_logging +import sideboard.run_mainloop -_discover_plugins() -_configure_logging() +if 'SIDEBOARD_MODULE_TESTING' not in os.environ: + _discover_plugins() + _configure_logging() diff --git a/sideboard/_version.py b/sideboard/_version.py index edeb193..5b59135 100644 --- a/sideboard/_version.py +++ b/sideboard/_version.py @@ -1,2 +1,3 @@ from __future__ import unicode_literals -__version__ = '0.1.0' + +__version__ = '1.0.14' diff --git a/sideboard/config.py b/sideboard/config.py index 6f549b5..c78f985 100755 --- a/sideboard/config.py +++ b/sideboard/config.py @@ -1,7 +1,9 @@ from __future__ import unicode_literals import os +import re from os import unlink +from collections import Sized, Iterable, Mapping from copy import deepcopy from tempfile import NamedTemporaryFile @@ -9,153 +11,221 @@ from validate import Validator +def uniquify(xs): + """ + Returns an order-preserved copy of `xs` with duplicate items removed. 
+ + >>> uniquify(['a', 'z', 'a', 'b', 'a', 'y', 'a', 'c', 'a', 'x']) + ['a', 'z', 'b', 'y', 'c', 'x'] + + """ + is_listy = isinstance(xs, Sized) \ + and isinstance(xs, Iterable) \ + and not isinstance(xs, (Mapping, type(b''), type(''))) + assert is_listy, 'uniquify requires a listy argument' + + seen = set() + return [x for x in xs if x not in seen and not seen.add(x)] + + class ConfigurationError(RuntimeError): pass -def os_path_split_asunder(path, debug=False): + +def get_config_overrides(): """ - http://stackoverflow.com/a/4580931/171094 + Returns a list of config file paths used to override the default config. + + The SIDEBOARD_CONFIG_OVERRIDES environment variable may be set to a + semicolon separated list of absolute and/or relative paths. If the + SIDEBOARD_CONFIG_OVERRIDES is set, this function returns a list of its + contents, split on semicolons:: + + # SIDEBOARD_CONFIG_OVERRIDES='/absolute/config.ini;relative/config.ini' + return ['/absolute/config.ini', 'relative/config.ini'] + + If any of the paths listed in SIDEBOARD_CONFIG_OVERRIDES ends with the + suffix "-defaults." then a similarly named path + "." 
will also be included:: + + # SIDEBOARD_CONFIG_OVERRIDES='test-defaults.ini' + return ['test-defaults.ini', 'test.ini'] + + If the SIDEBOARD_CONFIG_OVERRIDES environment variable is NOT set, this + function returns a list with two relative paths:: + + return ['development-defaults.ini', 'development.ini'] """ - parts = [] - while True: - newpath, tail = os.path.split(path) - if newpath == path: - assert not tail - if path: parts.append(path) - break - parts.append(tail) - path = newpath - parts.reverse() - return parts - - -def is_subdirectory(potential_subdirectory, expected_parent_directory): + config_overrides = os.environ.get( + 'SIDEBOARD_CONFIG_OVERRIDES', + 'development-defaults.ini') + + defaults_re = re.compile(r'(.+)-defaults(\.\w+)$') + config_paths = [] + for config_path in uniquify([s.strip() for s in config_overrides.split(';')]): + config_paths.append(config_path) + m = defaults_re.match(config_path) + if m: + config_paths.append(m.group(1) + m.group(2)) + + return config_paths + + +def get_config_root(): """ - Is the first argument a sub-directory of the second argument? 
- - :param potential_subdirectory: - :param expected_parent_directory: - :return: True if the potential_subdirectory is a child of the expected parent directory - - >>> is_subdirectory('/var/test2', '/var/test') - False - >>> is_subdirectory('/var/test', '/var/test2') - False - >>> is_subdirectory('var/test2', 'var/test') - False - >>> is_subdirectory('var/test', 'var/test2') - False - >>> is_subdirectory('/var/test/sub', '/var/test') - True - >>> is_subdirectory('/var/test', '/var/test/sub') - False - >>> is_subdirectory('var/test/sub', 'var/test') - True - >>> is_subdirectory('var/test', 'var/test') - True - >>> is_subdirectory('var/test', 'var/test/fake_sub/..') - True - >>> is_subdirectory('var/test/sub/sub2/sub3/../..', 'var/test') - True - >>> is_subdirectory('var/test/sub', 'var/test/fake_sub/..') - True - >>> is_subdirectory('var/test', 'var/test/sub') - False + Returns the config root for the system, defaults to '/etc/sideboard'. + + If the SIDEBOARD_CONFIG_ROOT environment variable is set, its contents + will be returned instead. """ + default_root = '/etc/sideboard' + config_root = os.environ.get('SIDEBOARD_CONFIG_ROOT', default_root) + if config_root != default_root and not os.path.isdir(config_root): + raise AssertionError('cannot find {!r} directory'.format(config_root)) + elif os.path.isdir(config_root) and not os.access(config_root, os.R_OK): + raise AssertionError('{!r} directory is not readable'.format(config_root)) + return config_root + - def _get_normalized_parts(path): - return os_path_split_asunder(os.path.realpath(os.path.abspath(os.path.normpath(path)))) +def get_module_and_root_dirs(requesting_file_path, is_plugin): + """ + Returns the "module_root" and "root" directories for the given file path. 
- # make absolute and handle symbolic links, split into components - sub_parts = _get_normalized_parts(potential_subdirectory) - parent_parts = _get_normalized_parts(expected_parent_directory) + Sideboard and its plugins often want to find other files. Sometimes they + need files which ship as part of the module itself, and for those they need + to know the module directory. Other times they might need files which are + bundled with their Git repo or which shipped with their RPM, and for those + they need to know their "root" directory. This "root" directory in + development is just the root of the Git repo and in production is the + package under the configured "plugins_dir" directory. - if len(parent_parts) > len(sub_parts): - # a parent directory never has more path segments than its child - return False + These two directories are also automatically inserted into plugin config + files as "root" and "module_root" and are available for interpolation. For + example, a plugin could have a line in their config file like:: - # we expect the zip to end with the short path, which we know to be the parent - return all(part1==part2 for part1, part2 in zip(sub_parts, parent_parts)) + template_dir = "%(module_root)s/templates" + and that would be interpolated to the correct absolute path. -def get_dirnames(pyname): - module_dir = os.path.dirname(os.path.abspath(pyname)) + Args: + requesting_file_path (str): The __file__ of the module requesting the + "root" and "module_root" directories. + is_plugin (bool): Indicates whether a plugin is making the request or + Sideboard itself is making the request. 
- # this can blow up if we decide that production plugins are somewhere different - expected_prod_plugin_dir = ('/', 'opt', 'sideboard', 'plugins') - if is_subdirectory(pyname, os.path.join(*expected_prod_plugin_dir)): - # we're in production, so the root, is really the directory in plugins that holds our - # virtualenv - root_dir = os.path.join(*os_path_split_asunder(pyname)[:len(expected_prod_plugin_dir) + 1]) + Returns: + tuple(str): The "module_root" and "root" directories for the + given module. + """ + module_dir = os.path.dirname(os.path.abspath(requesting_file_path)) + if is_plugin: + from sideboard.lib import config + plugin_name = os.path.basename(module_dir) + root_dir = os.path.join(config['plugins_dir'], plugin_name) + if '_' in plugin_name and not os.path.exists(root_dir): + root_dir = os.path.join(config['plugins_dir'], plugin_name.replace('_', '-')) else: root_dir = os.path.realpath(os.path.join(module_dir, '..')) - return module_dir, root_dir -def get_config_files(requesting_file_path, plugin): - """ - get a list of the config files that should be parsed, merged and returned by parse_config - - :param requesting_file_path: the path of the file requesting a parsed config file - :param plugin: if True (default) return the expected production-config directory. This is based - on the folder name of the requesting module, although in the future this could be the - based on the plugin name, no matter where you request a config from. - :return: list of config file paths that should be parsed, this list is ordered from lowest - to highest priority - :type: list +def get_config_files(requesting_file_path, is_plugin): """ + Returns a list of absolute paths to config files for the given file path. - module_dir, root_dir = get_dirnames(requesting_file_path) - module_name = os.path.basename(module_dir) + When the returned config files are parsed by ConfigObj each subsequent + file will override values in earlier files. 
- # this first two are expected to be per-plugin (or sideboard itself) - default_file_paths = ('development-defaults.ini', 'development.ini') + If `is_plugin` is `True` the first of the returned files is: - if plugin: - # TODO: this should ideally be the plugin name, even if it's overridden - plugin_config_name = '%s.cfg' % module_name.replace('_', '-') - extra_configs = [os.path.join('/etc', 'sideboard', 'plugins.d', plugin_config_name)] - else: - if module_name != 'sideboard': - raise RuntimeError('Unexpected module name {!r} requesting "non-plugin" ' - 'configuration files'.format(module_name)) + * /etc/sideboard/plugins.d/.cfg, which is the config file we + expect in production - extra_configs = [ - os.path.join('/etc', 'sideboard', 'sideboard-core.cfg'), - os.path.join('/etc', 'sideboard', 'sideboard-server.cfg'), - ] - old_production_path = os.path.join('/etc', 'sideboard', 'sideboard.cfg') - if os.path.exists(old_production_path): - raise RuntimeError("Old-style production path {}, exists. Configuration you've set " - "should be migrated to one of the following new-style " - "configuration files:\n{}".format(old_production_path, - '\n'.join(extra_configs))) + If `is_plugin` is `False` the first two returned files are: + + * /etc/sideboard/sideboard-core.cfg, which is the sideboard core config + file we expect in production + + * /etc/sideboard/sideboard-server.cfg, which is the sideboard server config + file we expect in production + + + The rest of the files returned are as follows, though we wouldn't + necessarily expect these to exist on a production install (these are + controlled by SIDEBOARD_CONFIG_OVERRIDES): + + * /development-defaults.ini, which can be checked into source + control and include whatever we want to be present in a development + environment. + + * /development.ini, which shouldn't be checked into source + control, allowing a developer to include local settings not shared with + others. 
+ + + When developing on a machine with an installed production config file, we + might want to ignore the "real" config file and limit ourselves to only the + development files. This behavior is turned on by setting the environment + variable SIDEBOARD_MODULE_TESTING to any value. (This environment variable + is also used elsewhere to turn off automatically loading all plugins in + order to facilitate testing modules which rely on Sideboard but which are + not themselves Sideboard plugins.) + + Args: + requesting_file_path (str): The Python __file__ of the module + requesting its config files. + is_plugin (bool): Indicates whether a plugin is making the request or + Sideboard itself is making the request, since this affects which + config files we return. - return ([os.path.join(root_dir, default_path) for default_path in default_file_paths] + - extra_configs) + Returns: + list(str): List of absolute paths to config files for the given module. + """ + config_root = get_config_root() + module_dir, root_dir = get_module_and_root_dirs(requesting_file_path, is_plugin) + module_name = os.path.basename(module_dir) + + if 'SIDEBOARD_MODULE_TESTING' in os.environ: + base_configs = [] + elif is_plugin: + base_configs = [os.path.join(config_root, 'plugins.d', '{}.cfg'.format(module_name.replace('_', '-')))] + else: + assert module_name == 'sideboard', 'Unexpected module name {!r} requesting "non-plugin" configuration files'.format(module_name) + base_configs = [ + os.path.join(config_root, 'sideboard-core.cfg'), + os.path.join(config_root, 'sideboard-server.cfg') + ] + + config_overrides = [os.path.join(root_dir, config_path) for config_path in get_config_overrides()] + return base_configs + config_overrides -def parse_config(requesting_file_path, plugin=True): +def parse_config(requesting_file_path, is_plugin=True): """ - parse the configuration files for a given sideboard module (or the sideboard server itself). 
It's - expected that this function is called from one of the files in the top-level of your module - (typically the __init__.py file) - - :param requesting_file_path: the path of the file requesting a parsed config file. An example - value is: - ~/sideboard/plugins/plugin_nickname/plugin_module_name/__init__.py - the containing directory (here, 'plugin_module_name') is assumed to be the module name of - the plugin that is requesting a parsed config. - :type requesting_file_path: binary or unicode string - :param plugin: if True (default) add plugin-relevant information to the returning config. Also, - treat it as if it's a plugin - :type plugin: bool - :return: the resulting configuration object - :rtype: ConfigObj + Parse the config files for a given sideboard plugin, or sideboard itself. + + It's expected that this function is called from one of the files in the + top-level of your module (typically the __init__.py file) + + Args: + requesting_file_path (str): The __file__ of the module requesting the + parsed config file. An example value is:: + + /opt/sideboard/plugins/plugin-package-name/plugin_module_name/__init__.py + + the containing directory (here, `plugin_module_name`) is assumed + to be the module name of the plugin that is requesting a parsed + config. + is_plugin (bool): Indicates whether a plugin is making the request or + Sideboard itself is making the request. If True (default) add + plugin-relevant information to the returned config. Also, treat it + as if it's a plugin + + Returns: + ConfigObj: The resulting configuration object. 
""" - module_dir, root_dir = get_dirnames(requesting_file_path) + module_dir, root_dir = get_module_and_root_dirs(requesting_file_path, is_plugin) specfile = os.path.join(module_dir, 'configspec.ini') spec = configobj.ConfigObj(specfile, interpolation=False, list_values=False, encoding='utf-8', _inspec=True) @@ -164,7 +234,7 @@ def parse_config(requesting_file_path, plugin=True): root_conf = ['root = "{}"\n'.format(root_dir), 'module_root = "{}"\n'.format(module_dir)] temp_config = configobj.ConfigObj(root_conf, interpolation=False, encoding='utf-8') - for config_path in get_config_files(requesting_file_path, plugin): + for config_path in get_config_files(requesting_file_path, is_plugin): # this gracefully handles nonexistent files temp_config.merge(configobj.ConfigObj(config_path, encoding='utf-8', interpolation=False)) @@ -183,13 +253,13 @@ def parse_config(requesting_file_path, plugin=True): configobj.flatten_errors(config, validation)) ) - if plugin: + if is_plugin: sideboard_config = globals()['config'] config['plugins'] = deepcopy(sideboard_config['plugins']) if 'rpc_services' in config: - from sideboard.lib import register_rpc_services - register_rpc_services(config['rpc_services']) - + from sideboard.lib._services import _register_rpc_services + _register_rpc_services(config['rpc_services']) + if 'default_url' in config: priority = config.get('default_url_priority', 0) if priority >= sideboard_config['default_url_priority']: @@ -197,4 +267,4 @@ def parse_config(requesting_file_path, plugin=True): return config -config = parse_config(__file__, plugin=False) +config = parse_config(__file__, is_plugin=False) diff --git a/sideboard/configspec.ini b/sideboard/configspec.ini index 2254a2b..326c0e1 100644 --- a/sideboard/configspec.ini +++ b/sideboard/configspec.ini @@ -1,30 +1,82 @@ debug = boolean(default=False) +# Directory where core Sideboard pages look for plugins. 
template_dir = string(default="%(module_root)s/templates") + +# Directory in which Sideboard looks for plugins. plugins_dir = string(default="%(root)s/plugins") + +# Sometimes plugins want to import from other plugins. This requires that some +# must be loaded before others. Plugins are loaded first in the order they are +# found here and then in an arbitrary order after that. priority_plugins = string_list(default=list()) +# When someone visits an unknown URL then we redirect them to this configured +# URL, which by default is the page which shows a list of installed plugins. +# +# Each plugin may define its own default URL. However, since there can be only +# one default, there's a second "priority" option. If more than one plugin +# defines a default URL, the defined URL with the highest priority is used. default_url = string(default="/list_plugins") default_url_priority = integer(default=0) +# Default client cert information. Any rpc_services entry will default to +# these values if not overridden in their own section. ca = string(default="") client_key = string(default="") client_cert = string(default="") +ssl_version = string(default="PROTOCOL_TLSv1") ws.thread_pool = integer(default=25) ws.call_timeout = integer(default=10) # seconds ws.poll_interval = integer(default=300) # seconds ws.reconnect_interval = integer(default=60) # seconds + +# Sideboard exposes two websocket endpoints. The first is at /wsrpc and doesn't +# require authentication, with the expectation being that the frontend webserver +# which reverse proxies to Sideboard will either block or require a client cert +# for this endpoint. The second is at /ws and by default requires a logged-in +# user to work. This setting can turn off that authentication check, which is +# useful for development or for applications which require no authentication. 
ws.auth_required = boolean(default=True) -ldap.url = string(default="") -ldap.basedn = force_list(default="") -ldap.userattr = string(default="uid") -ldap.start_tls = boolean(default=True) -ldap.cacert = string(default="") -ldap.cert = string(default="") -ldap.key = string(default="") +# When performing authentication for the /ws websocket endpoint, this setting +# determines which session field must be set for the request to be considered +# "logged in". If your application sets a session field other than "username" +# when a user logs in, you should change this setting to the name of that field. +ws.auth_field = string(default="username") + +# When an authenticated websocket is established on the /ws endpoint, we copy +# this configurable list of session fields into the websocket and make them +# available as threadlocal fields on every websocket RPC request. By default +# we only do this with the username of the logged-in user, but applications +# which store other data for logged in users can add those fields to this list. +ws.session_fields = string_list(default=list("username")) + +# When a frontend server performs authentication before proxying a request, +# the username is often placed in an HTTP header. We copy this configurable +# list of HTTP headers and make them available on every websocket RPC request +# as fields inside threadlocal['headers'] +ws.header_fields = string_list(default=list()) + +# If the "debug" option is set, the default login form will allow people to log +# in with any username using this password. +debug_password = string(default="testpassword") + +# Sideboard has numerous background threads which wait on sideboard.lib.stopped +# to either sleep or bail immediately on shutdown. Since these threads wait in +# a loop, we don't want to set an interval too small or we'll eat a lot of CPU +# while doing absolutely nothing. 
A hard-coded value of 1 second would probably +# be fine for all workloads, but we've made it configurable just in case. +thread_wait_interval = float(default=1) + +# Plugins can register different authenticators, since different applications may +# have different ideas about what it means to be "logged in". The default +# authenticator is mainly used for the /ws and /json RPC endpoints, so this +# option should be changed on systems where a plugin provides a new method we +# want to use as our default. +default_authenticator = string(default="default") [plugins] @@ -32,8 +84,41 @@ sqlite_dir = string(default="%(root)s/db") [cherrypy] +checker.check_skipped_app_config = boolean(default=False) + engine.autoreload.on = boolean(default=False) +# True to enable (or False to disable) the Sideboard profiler. +# If the profiler is enabled, profiler data can be collected using the +# @sideboard.lib.profile decorator. A web based interface will be exposed on +# http://servername/profiler/ to view profiler results. +# If the profiler is disabled, then the @sideboard.lib.profile decorator +# becomes a no-op, thus no performance penalty is incurred. The web interface +# will not be exposed and visits to http://servername/profiler/ will 404. +profiling.on = boolean(default=False) + +# The directory where the Sideboard profiler will store its data files. If the +# directory does not exist, it will be created. Data files are generated using +# pstats.dump_stats() from the standard library. +# See https://docs.python.org/3/library/profile.html#pstats.Stats.dump_stats +profiling.path = string(default="%(root)s/data/profiler") + +# If True, the profile data for each instance of the @sideboard.lib.profile +# decorator will be aggregated over time. Individual profiler files will be +# created for each call, but the stats reported in each file will be the +# aggregate of all previous runs, plus the current run. 
This will smooth out +# how the profiler data changes over time, but it will be harder to gauge the +# results of any individual profiler run. +# If False, each profiler data file will represent a single run. +profiling.aggregate = boolean(default=False) + +# True to include the full path to each python file reported in the profiler +# stats. False to strip the full path and only include the filename. NOTE: +# this does not change how the data is COLLECTED, only how it is DISPLAYED. +# You can switch this on and off between runs without losing data. +# See https://docs.python.org/3/library/profile.html#pstats.Stats.strip_dirs +profiling.strip_dirs = boolean(default=False) + server.socket_host = string(default="127.0.0.1") server.socket_port = integer(default=8282) @@ -46,24 +131,29 @@ tools.sessions.storage_type = string(default="file") tools.sessions.storage_path = string(default="%(root)s/data/sessions") tools.sessions.secure = boolean(default=False) -checker.check_skipped_app_config = boolean(default=False) - [rpc_services] ___many___ = string -#[[___many___]] -#jsonrpc_only = boolean(default=False) +[[__many__]] +jsonrpc_only = boolean(default=False) [loggers] -root = option("DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL", default="INFO") -cherrypy.error = option("DEBUG", "INFO", "WARNING", "WARN", "ERROR", "CRITICAL", default="DEBUG") -cherrypy.access = option("DEBUG", "INFO", "WARNING", "WARN", "ERROR", "CRITICAL", default="CRITICAL") -__many__ = option("DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL", default="INFO") +root = option("TRACE", "DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL", default="DEBUG") +cherrypy.error = option("TRACE", "DEBUG", "INFO", "WARNING", "WARN", "ERROR", "CRITICAL", default="DEBUG") +cherrypy.access = option("TRACE", "DEBUG", "INFO", "WARNING", "WARN", "ERROR", "CRITICAL", default="CRITICAL") +__many__ = option("TRACE", "DEBUG", "INFO", "WARN", "WARNING", "ERROR", "CRITICAL", default="INFO") [handlers] 
[[__many__]] formatter = string(default="default") ___many___ = string() +[formatters] +[[default]] +format = string(default="$$(asctime)s [$$(levelname)s] $$(name)s: $$(message)s") +datefmt = string(default="") +[[__many__]] +format = string +datefmt = string(default="") diff --git a/sideboard/debugger.py b/sideboard/debugger.py deleted file mode 100644 index 64dbebd..0000000 --- a/sideboard/debugger.py +++ /dev/null @@ -1,14 +0,0 @@ -import os - - -# when debugging, if you kill the server, occasionally there will be lockfiles leftover. -# we'll kill them here. DO NOT CALL THIS IN PRODUCTION -def debugger_helper_remove_any_lockfiles(): - path_of_this_python_script = os.path.dirname(os.path.realpath(__file__)) - session_path = path_of_this_python_script + "/../data/sessions/" - for lockfile in os.listdir(session_path): - if lockfile.endswith(".lock"): - os.remove(session_path + lockfile) - -def debugger_helpers_all_init(): - debugger_helper_remove_any_lockfiles() \ No newline at end of file diff --git a/sideboard/debugging.py b/sideboard/debugging.py new file mode 100644 index 0000000..08943fa --- /dev/null +++ b/sideboard/debugging.py @@ -0,0 +1,49 @@ +from __future__ import unicode_literals +import os + +# create a list of status functions which can inspect information of the running process +status_functions = [] + + +def gather_diagnostics_status_information(): + """ + Return textual information about current system state / diagnostics + Useful for debugging threading / db / cpu load / etc + """ + out = '' + for func in status_functions: + out += '--------- {} ---------\n{}\n\n\n'.format(func.__name__.replace('_', ' ').upper(), func()) + return out + + +def register_diagnostics_status_function(func): + status_functions.append(func) + return func + + +def _get_all_session_lock_filenames(): + path_of_this_python_script = os.path.dirname(os.path.realpath(__file__)) + session_path = path_of_this_python_script + "/../data/sessions/" + return [session_path + lockfile 
for lockfile in os.listdir(session_path) if lockfile.endswith(".lock")] + + +def _debugger_helper_remove_any_session_lockfiles(): + """ + When debugging, if you force kill the server, occasionally + there will be cherrypy session lockfiles leftover. + Calling this function will remove any stray lockfiles. + + DO NOT CALL THIS IN PRODUCTION. + """ + for lockfile in _get_all_session_lock_filenames(): + os.remove(lockfile) + + +def debugger_helpers_all_init(): + """ + Prepare sideboard to launch from a debugger. + This will do a few extra steps to make sure the environment is friendly. + + DO NOT CALL THIS IN PRODUCTION. + """ + _debugger_helper_remove_any_session_lockfiles() diff --git a/sideboard/internal/autolog.py b/sideboard/internal/autolog.py index a2da2b9..7508f14 100755 --- a/sideboard/internal/autolog.py +++ b/sideboard/internal/autolog.py @@ -187,8 +187,9 @@ def __init__(self, adapter_class=None, adapter_args=None, self.adapter_kwargs = adapter_kwargs def __getattr__(self, name): - if 'self' in inspect.currentframe().f_back.f_locals: - other = inspect.currentframe().f_back.f_locals['self'] + f_locals = inspect.currentframe().f_back.f_locals + if 'self' in f_locals and f_locals['self'] is not None: + other = f_locals['self'] caller_name = '%s.%s' % (other.__class__.__module__, other.__class__.__name__) else: caller_name = inspect.currentframe().f_back.f_globals['__name__'] @@ -234,7 +235,9 @@ def wrapper(*args, **kwargs): TRACE_LEVEL = 5 logging.addLevelName(TRACE_LEVEL, "TRACE") + + def trace(self, message, *args, **kws): # Yes, logger takes its '*args' as 'args'. 
- self._log(TRACE_LEVEL, message, args, **kws) + self._log(TRACE_LEVEL, message, args, **kws) logging.Logger.trace = trace diff --git a/sideboard/internal/connection_checker.py b/sideboard/internal/connection_checker.py index 99eef50..eeb7673 100644 --- a/sideboard/internal/connection_checker.py +++ b/sideboard/internal/connection_checker.py @@ -1,7 +1,9 @@ from __future__ import unicode_literals, print_function import ssl import socket +from contextlib import closing +import six from six.moves.urllib_parse import urlparse from sideboard.lib import services, entry_point @@ -9,7 +11,7 @@ def _check(url, **ssl_params): status = ['checking {}'.format(url)] - + try: parsed = urlparse(url) except Exception as e: @@ -18,14 +20,14 @@ def _check(url, **ssl_params): host = parsed.hostname port = parsed.port or (443 if parsed.scheme in ['https', 'wss'] else 80) status.append('using hostname {} and port {}'.format(host, port)) - + try: ip = socket.gethostbyname(host) except Exception as e: return status + ['failed to resolve host with DNS: {!s}'.format(e)] else: status.append('successfully resolved host {} to {}'.format(host, ip)) - + sock = None try: sock = socket.create_connection((host, port)) @@ -33,29 +35,36 @@ def _check(url, **ssl_params): return status + ['failed to establish a socket connection to {} on port {}: {!s}'.format(host, port, e)] else: status.append('successfully opened socket connection to {}:{}'.format(host, port)) - - if any(ssl_params.values()): + + # check if any of the non-version SSL options have been set + if any(val for val in ssl_params.values() if not isinstance(val, int)): try: wrapped = ssl.wrap_socket(sock, **ssl_params) except Exception as e: return status + ['failed to complete SSL handshake ({}): {!s}'.format(ssl_params, e)] else: - status.append('succeeded at SSL handshake') + status.append('succeeded at SSL handshake (without validating server cert)') finally: if sock: sock.close() - + + try: + with 
closing(socket.create_connection((host, port))) as sock: + wrapped = ssl.wrap_socket(sock, **dict(ssl_params, cert_reqs=ssl.CERT_REQUIRED)) + status.append('succeeded at validating server cert') + except Exception as e: + return status + ['failed to validate server cert ({}): {!s}'.format(ssl_params, e)] + status.append('everything seems to work') - return status def check_all(): checks = {} for name, jservice in services._jsonrpc.items(): - jproxy = jservice._send.im_self # ugly kludge to get the ServerProxy object + jproxy = jservice._send.im_self if six.PY2 else jservice._send.__self__ # ugly kludge to get the ServerProxy object url = '{}://{}/'.format(jproxy.type, jproxy.host) - checks[name] = _check(url, ca_certs=jproxy.ca_certs, keyfile=jproxy.key_file, certfile=jproxy.cert_file) + checks[name] = _check(url, **jproxy.ssl_opts) return checks diff --git a/sideboard/internal/imports.py b/sideboard/internal/imports.py index fa0139f..7103497 100755 --- a/sideboard/internal/imports.py +++ b/sideboard/internal/imports.py @@ -1,17 +1,34 @@ from __future__ import unicode_literals import sys import importlib +from collections import OrderedDict from glob import glob +from itertools import chain from os.path import join, isdir, basename from sideboard.config import config -plugins = {} + +plugins = OrderedDict() +plugin_dirs = OrderedDict() + + +def _discover_plugin_dirs(): + unsorted_dirs = {basename(d).replace('-', '_'): d + for d in glob(join(config['plugins_dir'], '*')) + if isdir(d) and not basename(d).startswith('_')} + + priority_plugins = config['priority_plugins'] + nonpriority_plugins = sorted(set(unsorted_dirs.keys()).difference(priority_plugins)) + sorted_plugins = chain(priority_plugins, nonpriority_plugins) + + return [(name, unsorted_dirs[name]) for name in sorted_plugins] + def _discover_plugins(): - ordered = list(reversed(config['priority_plugins'])) - plugin_dirs = [d for d in glob(join(config['plugins_dir'], '*')) if isdir(d) and not 
basename(d).startswith('_')] - for plugin_path in sorted(plugin_dirs, reverse=True, key=lambda d: (ordered.index(d) if d in ordered else 0)): - sys.path.append(plugin_path) - plugin_name = basename(plugin_path).replace('-', '_') - plugins[plugin_name] = importlib.import_module(plugin_name) + for name, path in _discover_plugin_dirs(): + sys.path.append(path) + plugin_dirs[name] = path + + for name, path in plugin_dirs.items(): + plugins[name] = importlib.import_module(name) diff --git a/sideboard/internal/logging.py b/sideboard/internal/logging.py index ab77417..92a24cc 100644 --- a/sideboard/internal/logging.py +++ b/sideboard/internal/logging.py @@ -4,29 +4,46 @@ import logging_unterpolation -from sideboard.config import config +from sideboard.config import config, get_config_root + + +class IndentMultilinesLogFormatter(logging.Formatter): + """ + Provide a formatter (unused by default) which adds indentation to messages + which are split across multiple lines. + """ + def format(self, record): + s = super(IndentMultilinesLogFormatter, self).format(record) + # indent all lines that start with a newline so they are easier for external log programs to parse + s = s.rstrip('\n').replace('\n', '\n ') + return s def _configure_logging(): logging_unterpolation.patch_logging() - fname='/etc/sideboard/logging.cfg' + fname = os.path.join(get_config_root(), 'logging.cfg') if os.path.exists(fname): logging.config.fileConfig(fname, disable_existing_loggers=True) else: - logging.config.dictConfig({ - 'version': 1, - 'root': { - 'level': config['loggers']['root'], - 'handlers': config['handlers'].dict().keys() - }, - 'loggers': { - name: {'level': level} - for name, level in config['loggers'].items() if name != 'root' - }, - 'handlers': config['handlers'].dict(), - 'formatters': { - 'default': { - 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s' - } + # ConfigObj doesn't support interpolation escaping, so we manually work around it here + formatters = 
config['formatters'].dict() + for formatter in formatters.values(): + formatter['format'] = formatter['format'].replace('$$', '%') + formatter['datefmt'] = formatter['datefmt'].replace('$$', '%') or None + formatters['indent_multiline'] = { + '()': IndentMultilinesLogFormatter, + 'format': formatters['default']['format'] } - }) + logging.config.dictConfig({ + 'version': 1, + 'root': { + 'level': config['loggers']['root'], + 'handlers': config['handlers'].dict().keys() + }, + 'loggers': { + name: {'level': level} + for name, level in config['loggers'].items() if name != 'root' + }, + 'handlers': config['handlers'].dict(), + 'formatters': formatters + }) diff --git a/sideboard/jsonrpc.py b/sideboard/jsonrpc.py index 14ec89d..895d88c 100755 --- a/sideboard/jsonrpc.py +++ b/sideboard/jsonrpc.py @@ -6,6 +6,7 @@ from cherrypy.lib.jsontools import json_decode from sideboard.lib import log, config, serializer +from sideboard.websockets import trigger_delayed_notifications ERR_INVALID_RPC = -32600 @@ -34,16 +35,16 @@ def force_json_in(): except ValueError: raise cherrypy.HTTPError(400, 'Invalid JSON document') - cherrypy.tools.force_json_in = cherrypy.Tool('before_request_body', force_json_in, priority=30) + def _make_jsonrpc_handler(services, debug=config['debug'], precall=lambda body: None, errback=lambda err, message: log.error(message, exc_info=True)): @cherrypy.expose @cherrypy.tools.force_json_in() @cherrypy.tools.json_out(handler=json_handler) - def jsonrpc_handler(self): + def jsonrpc_handler(self=None): id = None def error(code, message): @@ -90,5 +91,7 @@ def error(code, message): if debug: message += ': ' + traceback.format_exc() return error(ERR_FUNC_EXCEPTION, message) + finally: + trigger_delayed_notifications() return jsonrpc_handler diff --git a/sideboard/lib/__init__.py b/sideboard/lib/__init__.py index ef8997e..2ad5752 100644 --- a/sideboard/lib/__init__.py +++ b/sideboard/lib/__init__.py @@ -1,132 +1,27 @@ from __future__ import unicode_literals -import 
os -import json -from functools import wraps -from datetime import datetime, date -from collections import Sized, Iterable, Mapping import six from sideboard.internal.autolog import log from sideboard.config import config, ConfigurationError, parse_config -from sideboard.lib._cp import stopped, on_startup, on_shutdown, mainloop, ajax, renders_template, render_with_templates +from sideboard.lib._utils import is_listy, listify, serializer, cached_property, request_cached_property, class_property, entry_point, RWGuard +from sideboard.lib._cp import stopped, on_startup, on_shutdown, mainloop, ajax, renders_template, render_with_templates, restricted, all_restricted, register_authenticator +from sideboard.lib._profiler import cleanup_profiler, profile, Profiler, ProfileAggregator from sideboard.lib._threads import DaemonTask, Caller, GenericCaller, TimeDelayQueue -from sideboard.lib._websockets import WebSocket, Model, Subscription -from sideboard.websockets import subscribes, notifies, notify, threadlocal +from sideboard.lib._websockets import WebSocket, Model, Subscription, MultiSubscription +from sideboard.websockets import subscribes, locally_subscribes, notifies, notify, threadlocal from sideboard.lib._services import services __all__ = ['log', 'services', 'ConfigurationError', 'parse_config', + 'is_listy', 'listify', 'serializer', 'cached_property', 'class_property', 'entry_point', 'stopped', 'on_startup', 'on_shutdown', 'mainloop', 'ajax', 'renders_template', 'render_with_templates', + 'restricted', 'all_restricted', 'register_authenticator', + 'cleanup_profiler', 'profile', 'Profiler', 'ProfileAggregator', 'DaemonTask', 'Caller', 'GenericCaller', 'TimeDelayQueue', - 'WebSocket', 'Model', 'Subscription', - 'listify', 'serializer', 'cached_property', 'is_listy', 'entry_point', - 'threadlocal', 'subscribes', 'notifies', 'notify'] + 'WebSocket', 'Model', 'Subscription', 'MultiSubscription', + 'listify', 'serializer', 'cached_property', 'request_cached_property', 
'is_listy', 'entry_point', 'RWGuard', + 'threadlocal', 'subscribes', 'locally_subscribes', 'notifies', 'notify'] if six.PY2: __all__ = [s.encode('ascii') for s in __all__] - - -def listify(x): - """ - returns a list version of x if x is a non-string iterable, otherwise - returns a list with x as its only element - """ - return list(x) if is_listy(x) else [x] - - -class serializer(json.JSONEncoder): - """ - JSONEncoder subclass for plugins to register serializers for types. - Plugins should not need to instantiate this class directly, but - they are expected to call serializer.register() for new data types. - """ - - _registry = {} - _datetime_format = '%Y-%m-%d %H:%M:%S.%f' - - def default(self, o): - if type(o) in self._registry: - preprocessor = self._registry[type(o)] - else: - for klass, preprocessor in self._registry.items(): - if isinstance(o, klass): - break - else: - raise json.JSONEncoder.default(self, o) - - return preprocessor(o) - - @classmethod - def register(cls, type, preprocessor): - """ - Associates a type with a preprocessor so that RPC handlers may - pass non-builtin JSON types. For example, Sideboard already - does the equivalent of - - >>> serializer.register(datetime, lambda dt: dt.strftime('%Y-%m-%d %H:%M:%S.%f')) - - This method raises an exception if you try to register a - preprocessor for a type which already has one. 
- - :param type: the type you are registering - :param preprocessor: function which takes one argument which is - the value to serialize and returns a json- - serializable value - """ - assert type not in cls._registry, '{} already has a preprocessor defined'.format(type) - cls._registry[type] = preprocessor - -serializer.register(date, lambda d: d.strftime('%Y-%m-%d')) -serializer.register(datetime, lambda dt: dt.strftime(serializer._datetime_format)) - - -def cached_property(func): - """decorator for making readonly, memoized properties""" - pname = "_" + func.__name__ - @property - @wraps(func) - def caching(self, *args, **kwargs): - if not hasattr(self, pname): - setattr(self, pname, func(self, *args, **kwargs)) - return getattr(self, pname) - return caching - - -def is_listy(x): - """ - returns a boolean indicating whether the passed object is "listy", - which we define as a sized iterable which is not a map or string - """ - return isinstance(x, Sized) and isinstance(x, Iterable) and not isinstance(x, (Mapping, type(b''), type(''))) - - -def entry_point(func): - """ - Decorator used to define entry points for command-line scripts. Sideboard - ships with a "sep" (Sideboard Entry Point) command line script which can be - used to call into any plugin-defined entry point after deleting sys.argv[0] - so that the entry point name will be the first argument. 
For example, if a - plugin had this entry point: - - @entry_point - def some_action(): - print(sys.argv) - - Then someone in a shell ran the command: - - sep some_action foo bar - - It would print: - - ['some_action', 'foo', 'bar'] - - :param func: a function which takes no arguments; its name will be the name - of the command, and an exception is raised if a function with - the same name has already been registered as an entry point - """ - assert func.__name__ not in _entry_points, 'An entry point named {} has already been implemented'.format(func.__name__) - _entry_points[func.__name__] = func - return func - -_entry_points = {} diff --git a/sideboard/lib/_cp.py b/sideboard/lib/_cp.py index c5dcbee..84fac17 100644 --- a/sideboard/lib/_cp.py +++ b/sideboard/lib/_cp.py @@ -10,22 +10,62 @@ import cherrypy import sideboard.lib -from sideboard.lib import log +from sideboard.lib import log, config, serializer +auth_registry = {} _startup_registry = defaultdict(list) _shutdown_registry = defaultdict(list) -def on_startup(func, priority=50): +def _on_startup(func, priority): _startup_registry[priority].append(func) return func -def on_shutdown(func, priority=50): +def _on_shutdown(func, priority): _shutdown_registry[priority].append(func) return func +def on_startup(func=None, priority=50): + """ + Register a function to be called when Sideboard starts. Startup functions + have a priority, and the functions are invoked in priority order, where + low-priority-numbered functions are invoked before higher numbers. + + Startup functions may be registered in one of three ways: + + 1) A function can be passed directly, e.g. + on_startup(callback_function) + on_startup(callback_function, priority=25) + + 2) This function can be used as a decorator, e.g. + @on_startup + def callback_function(): + ... + + 3) This function can be used as a decorator with a priority value, e.g. + @on_startup(priority=25) + def callback_function(): + ... 
+ """ + if func: + return _on_startup(func, priority) + else: + return lambda func: _on_startup(func, priority) + + +def on_shutdown(func=None, priority=50): + """ + Register a function to be called when Sideboard exits. See the on_startup + function above for how this is used. + """ + if func: + return _on_shutdown(func, priority) + else: + return lambda func: _on_shutdown(func, priority) + + def _run_startup(): for priority, functions in sorted(_startup_registry.items()): for func in functions: @@ -40,7 +80,6 @@ def _run_shutdown(): except Exception: log.warn('Ignored exception during shutdown', exc_info=True) - stopped = Event() on_startup(stopped.clear, priority=0) on_shutdown(stopped.set, priority=0) @@ -60,7 +99,7 @@ def mainloop(): try: while not stopped.is_set(): try: - stopped.wait(0.1) + stopped.wait(config['thread_wait_interval']) except KeyboardInterrupt: break finally: @@ -79,11 +118,40 @@ def to_json(self, *args, **kwargs): return to_json -def renders_template(method, restricted=False): +def restricted(x): + """ + Decorator for CherryPy page handler methods. This can either be called + to provide an authenticator ident or called directly as a decorator, e.g. + + @restricted + def some_page(self): ... + + is equivalent to + + @restricted(sideboard.lib.config['default_authenticator']) + def some_page(self): ... 
+ """ + def make_decorator(ident): + def decorator(func): + @cherrypy.expose + @wraps(func) + def with_checking(*args, **kwargs): + if not auth_registry[ident]['check'](): + raise cherrypy.HTTPRedirect(auth_registry[ident]['login_path']) + else: + return func(*args, **kwargs) + return with_checking + return decorator + + if hasattr(x, '__call__'): + return make_decorator(config['default_authenticator'])(x) + else: + return make_decorator(x) + + +def renders_template(method): """ Decorator for CherryPy page handler methods implementing default behaviors: - - if your @render_with_templates class decorator used the "restricted" - argument, this redirects to /login if the user has not authenticated - if your page handler returns a string, return that un-modified - if your page handler returns a non-jsonrpc dictionary, render a template with that dictionary; the function my_page will render my_page.html @@ -91,9 +159,6 @@ def renders_template(method, restricted=False): @cherrypy.expose @wraps(method) def renderer(self, *args, **kwargs): - if restricted and 'username' not in cherrypy.session: - raise cherrypy.HTTPRedirect('/login?return_to=' + quote(cherrypy.request.app.script_name)) - output = method(self, *args, **kwargs) if isinstance(output, dict) and output.get('jsonrpc') != '2.0': return self.env.get_template(method.__name__ + '.html').render(**output) @@ -112,31 +177,58 @@ def _guess_autoescape(template_name): class render_with_templates(object): """ - Class decorator for CherryPy application objects with two optional arguments: - - template_dir: if present, this will cause all of your page handler methods - which return dictionaries to render Jinja templates found in this - directory using those dictionaries. So if you have a page handler called - my_page which returns a dictionary, the template my_page.html in the - template_dir directory will be rendered with that dictionary. 
- - restricted: boolean which if True (this is False by default) will cause all - page handlers in this class to redirect to /login if the client has not - logged in already - """ - def __init__(self, template_dir=None, restricted=False): + Class decorator for CherryPy application objects which causes all of your page + handler methods which return dictionaries to render Jinja templates found in this + directory using those dictionaries. So if you have a page handler called my_page + which returns a dictionary, the template my_page.html in the template_dir + directory will be rendered with that dictionary. An "env" attribute gets added + to the class which is a Jinja environment. + + For convenience, if the optional "restricted" parameter is passed, this class is + also passed through the @all_restricted class decorator. + """ + def __init__(self, template_dir, restricted=False): self.template_dir, self.restricted = template_dir, restricted def __call__(self, klass): - if self.template_dir: - klass.env = jinja2.Environment( - autoescape=_guess_autoescape, - loader=jinja2.FileSystemLoader(self.template_dir), - block_start_string='((%', - block_end_string='%))', - variable_start_string='$((', - variable_end_string='))$', - ) - klass.env.filters['jsonify'] = lambda x: klass.env.filters['safe'](json.dumps(x)) + klass.env = jinja2.Environment(autoescape=_guess_autoescape, loader=jinja2.FileSystemLoader(self.template_dir)) + klass.env.filters['jsonify'] = lambda x: klass.env.filters['safe'](json.dumps(x, cls=serializer)) + + if self.restricted: + all_restricted(self.restricted)(klass) + for name, func in list(klass.__dict__.items()): if hasattr(func, '__call__'): - setattr(klass, name, renders_template(func, self.restricted)) + setattr(klass, name, renders_template(func)) + return klass + + +class all_restricted(object): + """Invokes the @restricted decorator on all methods of a class.""" + def __init__(self, ident): + self.ident = ident + assert ident in auth_registry, 
'{!r} is not a recognized authenticator'.format(ident) + + def __call__(self, klass): + for name, func in list(klass.__dict__.items()): + if hasattr(func, '__call__'): + setattr(klass, name, restricted(self.ident)(func)) + return klass + + +def register_authenticator(ident, login_path, checker): + """ + Register a new authenticator, which consists of three things: + - A string ident, used to identify the authenticator in @restricted calls. + - The path to the login page we should redirect to when not authenticated. + - A function callable with no parameters which returns a truthy value if the + user is logged in and a falsey value if they are not. + """ + assert ident not in auth_registry, '{} is already a registered authenticator'.format(ident) + auth_registry[ident] = { + 'check': checker, + 'login_path': login_path + } + +register_authenticator('default', '/login', lambda: 'username' in cherrypy.session) diff --git a/sideboard/lib/_profiler.py b/sideboard/lib/_profiler.py new file mode 100644 index 0000000..f189a92 --- /dev/null +++ b/sideboard/lib/_profiler.py @@ -0,0 +1,268 @@ +""" +Adds profiling tools and a web interface for viewing profiling results. + +The Sideboard profiler borrows heavily from the `CherryPy profiler +`_, +but with a few added features and nicer formatting. + + * Adds the ability to sort results by different columns. + * Adds the ability to cleanup profile data files. + * Uses a better naming scheme for profile data files. + * Uses `cProfile` instead of `profile` for better performance. + +Profiling data can be collected using the @profile decorator on functions and +methods. The profiling results can be viewed at http://servername/profiler/. 
+ +Good candidates for profiling are the outermost functions that generate your +web pages, usually exposed as cherrypy endpoints via @cherrypy.expose:: + + import cherrypy + from sideboard.lib import profile + + class Root(object): + @cherrypy.expose + @profile + def index(self): + # Create and return the index page + return '' + + +But any regular function can be profiled using the @profile decorator:: + + from sideboard.lib import profile + + @profile + def some_interesting_function(): + # Do some stuff + + +The following config options control how the profiler operates, see +configspec.ini for more details:: + + [cherrypy] + profiling.on = True + profiling.path = "%(root)s/data/profiler" + profiling.aggregate = False + profiling.strip_dirs = False + +""" +from __future__ import unicode_literals +import io +import os +import os.path +import cProfile +import pstats +from datetime import datetime +from functools import wraps +from glob import glob + +import cherrypy +from sideboard.lib import config, entry_point, listify + + +def _new_func_strip_path(func_name): + """ + Adds the parent module to profiler output for `__init__.py` files. + + Copied verbatim from cherrypy/lib/profiler.py. + """ + filename, line, name = func_name + if filename.endswith('__init__.py'): + return os.path.basename(filename[:-12]) + filename[-12:], line, name + return os.path.basename(filename), line, name + +pstats.func_strip_path = _new_func_strip_path + + +@entry_point +def cleanup_profiler(): + """ + Deletes all `*.prof` files in the profiler's data directory. + + This is useful when you've created tons of profile files that you're no + longer interested in. 
Exposed as a `sep` command:: + + $ sep cleanup_profiler + + The profiler directory is specified in the config by:: + + [cherrypy] + profiling.path = 'path/to/profile/data' + + """ + profiling_path = config['cherrypy']['profiling.path'] + for f in glob(os.path.join(profiling_path, '*.prof')): + os.remove(f) + + +def profile(func): + """ + Decorator to capture profile data from a method or function. + + If profiling is disabled then this decorator is a no-op, and the original + function is returned unmodified. Since the original function is returned, + this decorator does not incur any performance penalty if profiling is + disabled. To enable or disable profiling use the following setting in your + config:: + + [cherrypy] + profiling.on = True # Or False to disable + + Args: + func (function): The function to profile. + + Returns: + function: Either a wrapped version of `func` with profiling enabled, + or `func` itself if profiling is disabled. + + See Also: + configspec.ini + """ + if config['cherrypy']['profiling.on']: + profiling_path = config['cherrypy']['profiling.path'] + if config['cherrypy']['profiling.aggregate']: + p = ProfileAggregator(profiling_path) + else: + p = Profiler(profiling_path) + + @wraps(func) + def wrapper(*args, **kwargs): + return p.run(func, *args, **kwargs) + return wrapper + else: + return func + + +class Profiler(object): + """ + Mostly copied from cherrypy/lib/profiler.py. + + * Adds the ability to sort results by different columns. + * Adds the ability to cleanup profile data files. + * Uses a better naming scheme for profile data files. 
+ """ + + # https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats + sort_fields = [ + ('cumulative', 'Cumulative Time'), + ('filename', 'File Name'), + ('ncalls', 'Call Count'), + ('pcalls', 'Primitive Call Count'), + ('line', 'Line Number'), + ('name', 'Function Name'), + ('nfl', 'Function/File/Line'), + ('stdname', 'Standard Name'), + ('tottime', 'Total Time')] + + def __init__(self, path=config['cherrypy']['profiling.path']): + self.path = path + if not os.path.exists(path): + os.makedirs(path) + + def new_filename(self, func): + date = datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f") + name = func.__name__ if func.__name__ else 'unknown' + return '{}_{}.prof'.format(date, name) + + def run(self, func, *args, **params): + """Dump profile data into self.path.""" + path = os.path.join(self.path, self.new_filename(func)) + prof = cProfile.Profile() + result = prof.runcall(func, *args, **params) + prof.dump_stats(path) + return result + + def statfiles(self): + """:rtype: list of available profiles. + """ + return [f for f in os.listdir(self.path) if f.endswith('.prof')] + + def stats(self, filename, sortby='cumulative'): + """:rtype stats(index): output of print_stats() for the given profile. + """ + sio = io.StringIO() + s = pstats.Stats(os.path.join(self.path, filename), stream=sio) + if config['cherrypy']['profiling.strip_dirs']: + s.strip_dirs() + s.sort_stats(sortby) + s.print_stats() + response = sio.getvalue() + sio.close() + return response + + @cherrypy.expose + def index(self): + return ''' + Sideboard Profiler + + + + + + ''' + + @cherrypy.expose + def menu(self): + yield '

Profiling Runs

' + runs = self.statfiles() + if not runs: + yield 'No profiling runs' + else: + yield '
' + runs.sort() + for run in runs: + yield '{0}' \ + '
'.format(run) + yield '



' + yield '' \ + 'Delete all profiling runs' + + @cherrypy.expose + def report(self, filename, sortby='cumulative'): + yield 'Sort by: ' + for (field, label) in Profiler.sort_fields: + if field == sortby: + yield '{} '.format(label) + else: + yield '{}' \ + ' '.format(filename, field, label) + yield '
'
+        yield self.stats(filename, sortby)
+        yield '
'
+
+    @cherrypy.expose
+    def cleanup(self):
+        """
+        Deletes all `*.prof` files in the profiler's data directory.
+
+        To delete all profile data files hit
+        http://servername/profile/cleanup/.
+
+        The profiler directory is specified by::
+
+            [cherrypy]
+            profiling.path = 'path/to/profile/data'
+
+        See Also:
+            `cleanup_profiler`
+        """
+        cleanup_profiler()
+        raise cherrypy.HTTPRedirect('.')
+
+
+class ProfileAggregator(Profiler):
+    """
+    Mostly copied from cherrypy/lib/profiler.py.
+
+    * Uses a better naming scheme for profile data files.
+    """
+
+    def __init__(self, path=None):
+        # Passing None through to Profiler.__init__ would crash in
+        # os.path.exists(None); fall back to the configured profiling
+        # directory, mirroring Profiler's own default.
+        if path is None:
+            path = config['cherrypy']['profiling.path']
+        super(ProfileAggregator, self).__init__(path)
+        self.profiler = cProfile.Profile()
+
+    def run(self, func, *args, **params):
+        path = os.path.join(self.path, self.new_filename(func))
+        result = self.profiler.runcall(func, *args, **params)
+        self.profiler.dump_stats(path)
+        return result
diff --git a/sideboard/lib/_services.py b/sideboard/lib/_services.py
index d75b877..5bfe9c4 100644
--- a/sideboard/lib/_services.py
+++ b/sideboard/lib/_services.py
@@ -1,5 +1,6 @@
 from __future__ import unicode_literals
 import os
+import ssl
 
 from rpctools.jsonrpc import ServerProxy
 
@@ -11,8 +12,10 @@ def __init__(self, services, name):
         self.services, self.name = services, name
 
     def __getattr__(self, method):
+        from sideboard.lib import is_listy
         assert self.name in self.services, '{} is not registered as a service'.format(self.name)
         service = self.services[self.name]
+        assert not is_listy(getattr(service, '__all__', None)) or method in service.__all__, 'unable to call non-whitelisted method {}.{}'.format(self.name, method)
         func = service.make_caller('{}.{}'.format(self.name, method)) if isinstance(service, WebSocket) else getattr(service, method, None)
         assert func and hasattr(func, '__call__') and not method.startswith('_'), 'no such method {}.{}'.format(self.name, method)
         return func
@@ -30,26 +33,26 @@ class _Services(object):
     """
     This class is used by plugins to register services, and to call services
registered by other plugins. You call services by attribute lookup, e.g. - + >>> from sideboard.lib import services >>> services.foo.bar() 'Hello World!' - + You may get a service which has not yet been registered; you'll only get an exception when calling a method on the service if it doesn't exist yet; this is to facilitate getting a namespace before the relevant plugin has been imported by Sideboard: - + >>> foo, baz = services.foo, services.baz >>> foo.bar() 'Hello World!' >>> baz.baf() AssertionError: baz is not registered as a service - + Services may be local or websocket, but they're called in the same way. If you know that service is remote, and you want to use Jsonrpc, you can use the .jsonrpc attribute of this class, e.g. - + >>> services.jsonrpc.foo.bar() 'Hello World!' >>> foo = services.jsonrpc.foo @@ -91,9 +94,9 @@ def get_services(self): """ return self._services - def _register_websocket(self, url=None, **ws_kwargs): + def _register_websocket(self, url=None, connect_immediately=True, **ws_kwargs): if url not in self._websockets: - self._websockets[url] = WebSocket(url, **ws_kwargs) + self._websockets[url] = WebSocket(url, connect_immediately=connect_immediately, **ws_kwargs) return self._websockets[url] def get_websocket(self, service_name=None): @@ -114,30 +117,94 @@ def __getattr__(self, name): services = _Services() +def _rpc_opts(host, service_config=None): + """ + Sideboard uses client certs for backend service authentication. There's a + global set of config options which determine the SSL settings we pass to our + RPC libraries, but sometimes different services require client certs issued + by different CAs. In those cases, we define a config subsection of the main + [rpc_services] section to override those settings. + + This function takes a hostname and for each config option, it returns either + the hostname-specific config option if it exists, or the global config option + if it doesn't. 
Specifically, this returns a dict of option names/values. + + If the service_config parameter is passed, it uses that as the config section + from which to draw the hostname-specific options. Otherwise it searches + the [rpc_services] config section for Sideboard and for all Sideboard plugins + which have a "config" object defined in order to find options for that host. + """ + from sideboard.internal.imports import plugins + section = service_config + if service_config is not None: # check explicitly for None because service_config might be {} + section = service_config + else: + rpc_sections = {host: section for host, section in config['rpc_services'].items() if isinstance(section, dict)} + for plugin in plugins.values(): + plugin_config = getattr(plugin, 'config', None) + if isinstance(plugin_config, dict): + rpc_sections.update({host: section for host, section in plugin_config.get('rpc_services', {}).items() if isinstance(section, dict)}) + section = rpc_sections.get(host, {}) + + opts = {} + for setting in ['client_key', 'client_cert', 'ca', 'ssl_version']: + path = section.get(setting, config[setting]) + if path and setting != 'ssl_version': + assert os.path.exists(path), '{} config option set to path not found on the filesystem: {}'.format(setting, path) + + opts[setting] = path + return opts + + +def _ssl_opts(rpc_opts): + """ + Given a dict of config options returned by _rpc_opts, return a dict of + options which can be passed to the ssl module. + """ + ssl_opts = { + 'ca_certs': rpc_opts['ca'], + 'keyfile': rpc_opts['client_key'], + 'certfile': rpc_opts['client_cert'], + 'cert_reqs': ssl.CERT_REQUIRED if rpc_opts['ca'] else None, + 'ssl_version': getattr(ssl, rpc_opts['ssl_version']) + } + return {k: v for k, v in ssl_opts.items() if v} + + +def _ws_url(host, rpc_opts): + """ + Given a hostname and set of config options returned by _rpc_opts, return the + standard URL websocket endpoint for a Sideboard remote service. 
+ """ + return '{protocol}://{host}/wsrpc'.format(host=host, protocol='wss' if rpc_opts['ca'] else 'ws') + + def _register_rpc_services(rpc_services): + """ + Sideboard has a config file, and it provides a parse_config method for its + plugins to parse their own config files. In both cases, we check for the + presence of an [rpc_services] config section, which we use to register any + services defined there with our sideboard.lib.services API. Note that this + means a server can provide information about a remote service in either the + main Sideboard config file OR the config file of any plugin. + + This function takes the [rpc_services] config section from either Sideboard + itself or one of its plugins and registers all remote services found there. + """ for service_name, host in rpc_services.items(): if not isinstance(host, dict): - opts = {} - for setting in ['client_key', 'client_cert', 'ca']: - path = rpc_services.get(host, {}).get(setting, config[setting]) - if path: - assert os.path.exists(path), '{} config option set to path not found on the filesystem: {}'.format(setting, path) - - _check(setting, path) - opts[setting] = path - - jsonrpc_url = '{protocol}://{host}/jsonrpc'.format(host=host, protocol='https' if opts['ca'] else 'http') - jproxy = ServerProxy(url, key_file=opts['client_key'], cert_file=opts['client_cert'], - ca_certs=opts['ca'], validate_cert_hostname=bool(opts['ca'])) - jservice = getattr(jproxy, service_name) - if rpc_services.get(host, {}).get('jsonrpc_only'): - service = jservice - else: - ws_url = '{protocol}://{host}/wsprc'.format(host=host, protocol='wss' if opts['ca'] else 'ws') - ssl_opts = {'key_file': opts['client_key'], 'cert_file': opts['client_cert'], 'ca_certs': opts['ca']} - service = services._register_websocket(ws_url, ssl_opts={k: v for k, v in ssl_opts if v}) - - services.register(service, name, _jsonrpc=jservice, _override=True) + rpc_opts = _rpc_opts(host, rpc_services.get(host, {})) + ssl_opts = _ssl_opts(rpc_opts) + 
+ jsonrpc_url = '{protocol}://{host}/jsonrpc'.format(host=host, protocol='https' if rpc_opts['ca'] else 'http') + jproxy = ServerProxy(jsonrpc_url, ssl_opts=ssl_opts, validate_cert_hostname=bool(rpc_opts['ca'])) + jservice = getattr(jproxy, service_name) + if rpc_services.get(host, {}).get('jsonrpc_only'): + service = jservice + else: + service = services._register_websocket(_ws_url(host, rpc_opts), ssl_opts=ssl_opts, connect_immediately=False) + + services.register(service, service_name, _jsonrpc=jservice, _override=True) _register_rpc_services(config['rpc_services']) diff --git a/sideboard/lib/_threads.py b/sideboard/lib/_threads.py index c8b8d37..559df7b 100644 --- a/sideboard/lib/_threads.py +++ b/sideboard/lib/_threads.py @@ -1,20 +1,85 @@ from __future__ import unicode_literals +import sys import time import heapq +import ctypes +import platform +import traceback +import threading from warnings import warn from threading import Thread, Timer, Event, Lock +import six from six.moves.queue import Queue, Empty -from sideboard.lib import log, on_startup, on_shutdown +from sideboard.lib import log, config, on_startup, on_shutdown +from sideboard.debugging import register_diagnostics_status_function + +try: + import prctl + import psutil +except ImportError: + prctl = psutil = None # For platforms without this support. 
+ + +def _get_linux_thread_tid(): + """ + Get the current linux thread ID as it appears in /proc/[pid]/task/[tid] + :return: Linux thread ID if available, or -1 if any errors / not on linux + """ + try: + if not platform.system().startswith('Linux'): + raise ValueError('Can only get thread id on Linux systems') + syscalls = { + 'i386': 224, # unistd_32.h: #define __NR_gettid 224 + 'x86_64': 186, # unistd_64.h: #define __NR_gettid 186 + } + syscall_num = syscalls[platform.machine()] + tid = ctypes.CDLL('libc.so.6').syscall(syscall_num) + except: + tid = -1 + return tid + + +def _set_current_thread_ids_from(thread): + # thread ID part 1: set externally visible thread name in /proc/[pid]/tasks/[tid]/comm to our internal name + if prctl and thread.name: + # linux doesn't allow thread names > 15 chars, and we ideally want to see the end of the name. + # attempt to shorten the name if we need to. + shorter_name = thread.name if len(thread.name) < 15 else thread.name.replace('CP Server Thread', 'CPServ') + prctl.set_name(shorter_name) + + # thread ID part 2: capture linux-specific thread ID (TID) and store it with this thread object + # if TID can't be obtained or system call fails, tid will be -1 + thread.linux_tid = _get_linux_thread_tid() + + +# inject our own code at the start of every thread's start() method which sets the thread name via prctl(). +# Python thread names will now be shown in external system tools like 'top', '/proc', etc. 
+def _thread_name_insert(self): + _set_current_thread_ids_from(self) + threading.Thread._bootstrap_inner_original(self) + +if six.PY3: + threading.Thread._bootstrap_inner_original = threading.Thread._bootstrap_inner + threading.Thread._bootstrap_inner = _thread_name_insert +else: + threading.Thread._bootstrap_inner_original = threading.Thread._Thread__bootstrap + threading.Thread._Thread__bootstrap = _thread_name_insert + +# set the ID's of the main thread +threading.current_thread().name = 'sideboard_main' +_set_current_thread_ids_from(threading.current_thread()) class DaemonTask(object): - def __init__(self, func, interval=0.1, threads=1): + def __init__(self, func, interval=None, threads=1, name=None): self.lock = Lock() self.threads = [] self.stopped = Event() self.func, self.interval, self.thread_count = func, interval, threads + self.name = name or self.func.__name__ + on_startup(self.start) on_shutdown(self.stop) @@ -29,8 +94,9 @@ def run(self): except: log.error('unexpected error', exc_info=True) - if self.interval: - self.stopped.wait(self.interval) + interval = config['thread_wait_interval'] if self.interval is None else self.interval + if interval: + self.stopped.wait(interval) def start(self): with self.lock: @@ -38,8 +104,8 @@ def start(self): self.stopped.clear() del self.threads[:] for i in range(self.thread_count): - t = Thread(target = self.run) - t.name = '{}-{}'.format(self.func.__name__, i + 1) + t = Thread(target=self.run) + t.name = '{}-{}'.format(self.name, i + 1) t.daemon = True t.start() self.threads.append(t) @@ -93,55 +159,81 @@ def _put_and_notify(self): class Caller(DaemonTask): - def __init__(self, func, interval=0, threads=1): - self.q = TimeDelayQueue() - DaemonTask.__init__(self, self.call, interval=interval, threads=threads) + def __init__(self, func, interval=0, threads=1, name=None): + self.q = Queue() + DaemonTask.__init__(self, self.call, interval=interval, threads=threads, name=name or func.__name__) self.callee = func def 
call(self): try: - args, kwargs = self.q.get(timeout = 0.1) + args, kwargs = self.q.get(timeout=config['thread_wait_interval']) self.callee(*args, **kwargs) except Empty: pass - def start(self): - self.q.task.start() - DaemonTask.start(self) - - def stop(self): - self.q.task.stop() - DaemonTask.stop(self) - def defer(self, *args, **kwargs): self.q.put([args, kwargs]) - def delayed(self, delay, *args, **kwargs): - self.q.put([args, kwargs], delay=delay) - class GenericCaller(DaemonTask): - def __init__(self, interval=0, threads=1): - DaemonTask.__init__(self, self.call, interval=interval, threads=threads) - self.q = TimeDelayQueue() + def __init__(self, interval=0, threads=1, name=None): + DaemonTask.__init__(self, self.call, interval=interval, threads=threads, name=name) + self.q = Queue() def call(self): try: - func, args, kwargs = self.q.get(timeout = 0.1) + func, args, kwargs = self.q.get(timeout=config['thread_wait_interval']) func(*args, **kwargs) except Empty: pass - def start(self): - self.q.task.start() - DaemonTask.start(self) - - def stop(self): - self.q.task.stop() - DaemonTask.stop(self) - def defer(self, func, *args, **kwargs): self.q.put([func, args, kwargs]) - def delayed(self, delay, func, *args, **kwargs): - self.q.put([func, args, kwargs], delay=delay) + +def _get_thread_current_stacktrace(thread_stack, thread): + out = [] + linux_tid = getattr(thread, 'linux_tid', -1) + status = '[unknown]' + if psutil and linux_tid != -1: + status = psutil.Process(linux_tid).status() + out.append('\n--------------------------------------------------------------------------') + out.append('# Thread name: "%s"\n# Python thread.ident: %d\n# Linux Thread PID (TID): %d\n# Run Status: %s' + % (thread.name, thread.ident, linux_tid, status)) + for filename, lineno, name, line in traceback.extract_stack(thread_stack): + out.append('File: "%s", line %d, in %s' % (filename, lineno, name)) + if line: + out.append(' %s' % (line.strip())) + return out + + 
+@register_diagnostics_status_function +def threading_information(): + out = [] + threads_by_id = dict([(thread.ident, thread) for thread in threading.enumerate()]) + for thread_id, thread_stack in sys._current_frames().items(): + thread = threads_by_id.get(thread_id, '') + out += _get_thread_current_stacktrace(thread_stack, thread) + return '\n'.join(out) + + +def _to_megabytes(bytes): + return str(int(bytes / 0x100000)) + 'MB' + + +@register_diagnostics_status_function +def general_system_info(): + """ + Print general system info + TODO: + - print memory nicer, convert mem to megabytes + - disk partitions usage, + - # of open file handles + - # free inode count + - # of cherrypy session files + - # of cherrypy session locks (should be low) + """ + out = [] + out += ['Mem: ' + repr(psutil.virtual_memory()) if psutil else ''] + out += ['Swap: ' + repr(psutil.swap_memory()) if psutil else ''] + return '\n'.join(out) diff --git a/sideboard/lib/_utils.py b/sideboard/lib/_utils.py new file mode 100644 index 0000000..93b3706 --- /dev/null +++ b/sideboard/lib/_utils.py @@ -0,0 +1,262 @@ +from __future__ import unicode_literals +import os +import json +from functools import wraps +from datetime import datetime, date +from contextlib import contextmanager +from threading import RLock, Condition, current_thread +from collections import Sized, Iterable, Mapping, defaultdict + + +def is_listy(x): + """ + returns a boolean indicating whether the passed object is "listy", + which we define as a sized iterable which is not a map or string + """ + return isinstance(x, Sized) and isinstance(x, Iterable) and not isinstance(x, (Mapping, type(b''), type(''))) + + +def listify(x): + """ + returns a list version of x if x is a non-string iterable, otherwise + returns a list with x as its only element + """ + return list(x) if is_listy(x) else [x] + + +class serializer(json.JSONEncoder): + """ + JSONEncoder subclass for plugins to register serializers for types. 
+    Plugins should not need to instantiate this class directly, but
+    they are expected to call serializer.register() for new data types.
+    """
+
+    _registry = {}
+    _datetime_format = '%Y-%m-%d %H:%M:%S.%f'
+
+    def default(self, o):
+        if type(o) in self._registry:
+            preprocessor = self._registry[type(o)]
+        else:
+            for klass, preprocessor in self._registry.items():
+                if isinstance(o, klass):
+                    break
+            else:
+                # Defer to the base class, which raises TypeError itself for
+                # unserializable types; this must be "return", not "raise".
+                return json.JSONEncoder.default(self, o)
+
+        return preprocessor(o)
+
+    @classmethod
+    def register(cls, type, preprocessor):
+        """
+        Associates a type with a preprocessor so that RPC handlers may
+        pass non-builtin JSON types.  For example, Sideboard already
+        does the equivalent of
+
+        >>> serializer.register(datetime, lambda dt: dt.strftime('%Y-%m-%d %H:%M:%S.%f'))
+
+        This method raises an exception if you try to register a
+        preprocessor for a type which already has one.
+
+        :param type: the type you are registering
+        :param preprocessor: function which takes one argument which is
+                             the value to serialize and returns a json-
+                             serializable value
+        """
+        assert type not in cls._registry, '{} already has a preprocessor defined'.format(type)
+        cls._registry[type] = preprocessor
+
+serializer.register(date, lambda d: d.strftime('%Y-%m-%d'))
+serializer.register(datetime, lambda dt: dt.strftime(serializer._datetime_format))
+serializer.register(set, lambda s: sorted(list(s)))
+
+
+def cached_property(func):
+    """decorator for making readonly, memoized properties"""
+    pname = '_cached_{}'.format(func.__name__)
+
+    @property
+    @wraps(func)
+    def caching(self, *args, **kwargs):
+        if not hasattr(self, pname):
+            setattr(self, pname, func(self, *args, **kwargs))
+        return getattr(self, pname)
+    return caching
+
+
+def request_cached_property(func):
+    """
+    Sometimes we want a property to be cached for the duration of a request,
+    with concurrent requests each having their own cached version. 
This does + that via the threadlocal class, such that each HTTP request CherryPy serves + and each RPC request served via websocket or JSON-RPC will have its own + cached value, which is cleared and then re-generated on later requests. + """ + from sideboard.lib import threadlocal + name = func.__module__ + '.' + func.__name__ + + @property + @wraps(func) + def with_caching(self): + val = threadlocal.get(name) + if val is None: + val = func(self) + threadlocal.set(name, val) + return val + return with_caching + + +class _class_property(property): + def __get__(self, cls, owner): + return self.fget.__get__(None, owner)() + + +def class_property(cls): + """ + For whatever reason, the @property decorator isn't smart enough to recognize + classmethods and behave differently on them than on instance methods. This + property may be used to create a class-level property, useful for singletons + and other one-per-class properties. Class properties are read-only. + """ + return _class_property(classmethod(cls)) + + +def entry_point(func): + """ + Decorator used to define entry points for command-line scripts. Sideboard + ships with a "sep" (Sideboard Entry Point) command line script which can be + used to call into any plugin-defined entry point after deleting sys.argv[0] + so that the entry point name will be the first argument. 
For example, if a + plugin had this entry point: + + @entry_point + def some_action(): + print(sys.argv) + + Then someone in a shell ran the command: + + sep some_action foo bar + + It would print: + + ['some_action', 'foo', 'bar'] + + :param func: a function which takes no arguments; its name will be the name + of the command, and an exception is raised if a function with + the same name has already been registered as an entry point + """ + assert func.__name__ not in _entry_points, 'An entry point named {} has already been implemented'.format(func.__name__) + _entry_points[func.__name__] = func + return func + +_entry_points = {} + + +class RWGuard(object): + """ + This utility class provides the ability to perform read/write locking, such + that we can have any number of readers OR a single writer. We give priority + to writers, who will get the lock before any readers. + + These locks are reentrant, meaning that the same thread can acquire a read + or write lock multiple times, and will then need to release the lock the + same number of times it was acquired. A thread with an acquired read lock + cannot acquire a write lock, or vice versa. Locks can only be released by + the threads which acquired them. + + This class is named RWGuard rather than RWLock because it is not itself a + lock, e.g. it doesn't have an acquire method, it cannot be directly used as + a context manager, etc. + """ + def __init__(self): + self.lock = RLock() + self.waiting_writer_count = 0 + self.acquired_writer = defaultdict(int) + self.acquired_readers = defaultdict(int) + self.ready_for_reads = Condition(self.lock) + self.ready_for_writes = Condition(self.lock) + + @property + @contextmanager + def read_locked(self): + """ + Context manager which acquires a read lock on entrance and releases it + on exit. Any number of threads may acquire a read lock. 
+ """ + self.acquire_for_read() + try: + yield + finally: + self.release() + + @property + @contextmanager + def write_locked(self): + """ + Context manager which acquires a write lock on entrance and releases it + on exit. Only one thread may acquire a write lock at a time. + """ + self.acquire_for_write() + try: + yield + finally: + self.release() + + def acquire_for_read(self): + """ + NOTE: consumers are encouraged to use the "read_locked" context manager + instead of this method where possible. + + This method acquires the read lock for the current thread, blocking if + necessary until there are no other threads with the write lock acquired + or waiting for the write lock to be available. + """ + tid = current_thread().ident + assert tid not in self.acquired_writer, 'Threads which have already acquired a write lock may not lock for reading' + with self.lock: + while self.acquired_writer or (self.waiting_writer_count and tid not in self.acquired_readers): + self.ready_for_reads.wait() + self.acquired_readers[tid] += 1 + + def acquire_for_write(self): + """ + NOTE: consumers are encouraged to use the "write_locked" context manager + instead of this method where possible. + + This method acquires the write lock for the current thread, blocking if + necessary until no other threads have the write lock acquired and no + thread has the read lock acquired. + """ + tid = current_thread().ident + assert tid not in self.acquired_readers, 'Threads which have already acquired a read lock may not lock for writing' + with self.lock: + while self.acquired_readers or (self.acquired_writer and tid not in self.acquired_writer): + self.waiting_writer_count += 1 + self.ready_for_writes.wait() + self.waiting_writer_count -= 1 + self.acquired_writer[tid] += 1 + + def release(self): + """ + Release the read or write lock held by the current thread. Since these + locks are reentrant, this method must be called once for each time the + lock was acquired. 
This method raises an exception if called by a + thread with no read or write lock acquired. + """ + tid = current_thread().ident + assert tid in self.acquired_readers or tid in self.acquired_writer, 'this thread does not hold a read or write lock' + with self.lock: + for counts in [self.acquired_readers, self.acquired_writer]: + counts[tid] -= 1 + if counts[tid] <= 0: + del counts[tid] + + wake_readers = not self.waiting_writer_count + wake_writers = self.waiting_writer_count and not self.acquired_readers + + if wake_writers: + with self.ready_for_writes: + self.ready_for_writes.notify() + elif wake_readers: + with self.ready_for_reads: + self.ready_for_reads.notify_all() diff --git a/sideboard/lib/_websockets.py b/sideboard/lib/_websockets.py index 8fa215d..e47b10e 100644 --- a/sideboard/lib/_websockets.py +++ b/sideboard/lib/_websockets.py @@ -4,7 +4,7 @@ import json from copy import deepcopy from itertools import count -from threading import RLock +from threading import RLock, Event from datetime import datetime, timedelta from collections import Mapping, MutableMapping @@ -59,26 +59,29 @@ def received_message(self, message): class _Subscriber(object): - def __init__(self, method, client, src_ws, dest_ws): - self.src_ws, self.dest_ws, self.method, self.client = src_ws, dest_ws, method, client + def __init__(self, method, src_client, dst_client, src_ws, dest_ws): + self.method, self.src_ws, self.dest_ws, self.src_client, self.dst_client = method, src_ws, dest_ws, src_client, dst_client def unsubscribe(self): - self.dest_ws.unsubscribe(self.client) + self.dest_ws.unsubscribe(self.dst_client) def callback(self, data): - self.src_ws.send(data=data, client=self.client) + self.src_ws.send(data=data, client=self.src_client) def errback(self, error): - self.src_ws.send(error=error, client=self.client) + self.src_ws.send(error=error, client=self.src_client) def __call__(self, *args, **kwargs): self.dest_ws.subscribe({ - 'client': self.client, + 'client': 
self.dst_client, 'callback': self.callback, 'errback': self.errback }, self.method, *args, **kwargs) return self.src_ws.NO_RESPONSE + def __del__(self): + self.unsubscribe() + class WebSocket(object): """ @@ -112,6 +115,17 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def preprocess(self, method, params): + """ + Each message we send has its parameters passed to this function and + the actual parameters sent are whatever this function returns. By + default this just returns the message unmodified, but plugins can + override this to add whatever logic is needed. We pass the method + name in its full "service.method" form in case the logic depends on + the service being invoked. + """ + return params + @property def _should_reconnect(self): interval = min(config['ws.reconnect_interval'], 2 ** self._reconnect_attempts) @@ -143,7 +157,8 @@ def _refire_subscriptions(self): try: for cb in self._callbacks.values(): if 'client' in cb: - self._send(method=cb['method'], params=cb['params'], client=cb['client']) + params = cb['paramback']() if 'paramback' in cb else cb['params'] + self._send(method=cb['method'], params=params, client=cb['client']) except: pass # self._send() already closes and logs on error @@ -196,7 +211,7 @@ def fallback(self, message): aren't valid responses to an outstanding call or subscription. By default this just logs an error message. You can override this by subclassing this class, or just by assigning a hander method, e.g. - + >>> ws = WebSocket() >>> ws.fallback = some_handler_function >>> ws.connect() @@ -245,46 +260,64 @@ def subscribe(self, callback, method, *args, **kwargs): Send a websocket request which you expect to subscribe you to a channel with a callback which will be called every time there is new data, and return the client id which uniquely identifies this subscription. 
- + Callback may be either a function or a dictionary in the form { 'callback': , - 'errback': + 'errback': , # optional + 'paramback: , # optional + 'client': # optional } Both callback and errback take a single argument; for callback, this is the return value of the method, for errback it is the error message returning. If no errback is specified, we will log errors at the ERROR level and do nothing further. - + + The paramback function exists for subscriptions where we might want to + pass different parameters every time we reconnect. This might be used + for e.g. time-based parameters. This function takes no arguments and + returns the parameters which should be passed every time we connect + and fire (or re-fire) all of our subscriptions. + + The client id is automatically generated if omitted, and you should not + set this yourself unless you really know what you're doing. + The positional and keyword arguments passed to this function will be - used as the arguments to the remote method. + used as the arguments to the remote method, unless paramback is passed, + in which case that will be used to generate the params, and args/kwargs + will be ignored. 
""" client = self._next_id('client') if isinstance(callback, Mapping): - assert 'callback' in callback and 'errback' in callback, 'callback and errback are required' + assert 'callback' in callback, 'callback is required' client = callback.setdefault('client', client) self._callbacks[client] = callback else: self._callbacks[client] = { 'client': client, - 'callback': callback, - 'errback': lambda result: log.error('{}(*{}, **{}) returned an error: {!r}', method, args, kwargs, result) + 'callback': callback } + + paramback = self._callbacks[client].get('paramback') + params = self.preprocess(method, paramback() if paramback else (args or kwargs)) + self._callbacks[client].setdefault('errback', lambda result: log.error('{}(*{}, **{}) returned an error: {!r}', method, args, kwargs, result)) self._callbacks[client].update({ 'method': method, - 'params': args or kwargs + 'params': params }) + try: - self._send(method=method, params=args or kwargs, client=client) + self._send(method=method, params=params, client=client) except: log.warn('initial subscription to {} at {!r} failed, will retry on reconnect', method, self.url) + return client def unsubscribe(self, client): """ Cancel the websocket subscription identified by the specified client id. This id is returned from the subscribe() method, e.g. - + >>> client = ws.subscribe(some_callback, 'foo.some_function') >>> ws.unsubscribe(client) """ @@ -302,20 +335,23 @@ def call(self, method, *args, **kwargs): kind was received. The positional and keyword arguments to this method are used as the arguments to the rpc function call. 
""" + finished = Event() result, error = [], [] callback = self._next_id('callback') self._callbacks[callback] = { - 'callback': result.append, - 'errback': error.append + 'callback': lambda response: (result.append(response), finished.set()), + 'errback': lambda response: (error.append(response), finished.set()) } + params = self.preprocess(method, args or kwargs) try: - self._send(method=method, params=args or kwargs, callback=callback) + self._send(method=method, params=params, callback=callback) except: self._callbacks.pop(callback, None) raise - for i in range(10 * config['ws.call_timeout']): - stopped.wait(0.1) + wait_until = datetime.now() + timedelta(seconds=config['ws.call_timeout']) + while datetime.now() < wait_until: + finished.wait(0.1) if stopped.is_set() or result or error: break self._callbacks.pop(callback, None) @@ -327,15 +363,33 @@ def make_caller(self, method): """ Returns a function which calls the specified method; useful for creating callbacks, e.g. - + >>> authenticate = ws.make_caller('auth.authenticate') >>> authenticate('username', 'password') True + + Sideboard supports "passthrough subscriptions", e.g. + -> a browser makes a subscription for the "foo.bar" method + -> the server has "foo" registered as a remote service + -> the server creates its own subscription to "foo.bar" on the remote + service and passes all results back to the client as they arrive + + This method implements that by checking whether it was called from a + thread with an active websocket as part of a subscription request. If + so then in addition to returning a callable, it also registers the + new subscription with the client websocket so it can be cleaned up when + the client websocket closes and/or when its subscription is canceled. 
""" + client = sideboard.lib.threadlocal.get_client() originating_ws = sideboard.lib.threadlocal.get('websocket') - client = sideboard.lib.threadlocal.get('message', {}).get('client') if client and originating_ws: - return _Subscriber(client=client, src_ws=originating_ws, dest_ws=self, method=method) + sub = originating_ws.passthru_subscriptions.get(client) + if sub: + sub.method = method + else: + sub = _Subscriber(method=method, src_client=client, dst_client=self._next_id('client'), src_ws=originating_ws, dest_ws=self) + originating_ws.passthru_subscriptions[client] = sub + return sub else: return lambda *args, **kwargs: self.call(method, *args, **kwargs) @@ -367,8 +421,8 @@ def query(self): @property def dirty(self): - return {k:v for k,v in self._data.items() if v != self._orig_data.get(k)} - + return {k: v for k, v in self._data.items() if v != self._orig_data.get(k)} + def to_dict(self): data = deepcopy(self._data) serialized = {k: v for k, v in data.pop(self._project_key, {}).items()} @@ -379,7 +433,7 @@ def to_dict(self): serialized[k] = data['extra_data'].pop(k) serialized.update(data) return serialized - + @property def _extra_data(self): return self._data.setdefault('extra_data', {}) @@ -418,7 +472,7 @@ def __delitem__(self, key): def __iter__(self): return iter(k for k in self.to_dict() if k != 'extra_data') - + def __repr__(self): return repr(dict(self.items())) @@ -446,16 +500,16 @@ class Subscription(object): ... def __init__(self): ... self.usernames = [] ... Subscription.__init__(self, 'admin.get_logged_in_users') - ... + ... ... def callback(self, users): ... self.usernames = [user['username'] for user in users] - ... + ... 
class MultiSubscription(object):
    """
    A version of the Subscription utility class which subscribes to an
    arbitrary number of remote servers and aggregates the results from each.
    You invoke this similarly to the Subscription class, with two main
    differences:

    1) The first parameter is a list of hostnames to which we should connect.
       Each hostname will have a websocket registered for it if one does not
       already exist, using the standard config options under [rpc_services].

    2) Unlike the Subscription class, we do not support the
       connect_immediately parameter.  Because this class looks in the
       [rpc_services] config section of every plugin to find the client cert
       settings, we need to wait for all plugins to be loaded before trying
       to connect.

    Like the Subscription class, you can instantiate this class directly, e.g.

    >>> logged_in_users = MultiSubscription(['host1', 'host2'], 'admin.get_logged_in_users')
    >>> logged_in_users.results  # always the latest return values of your rpc method

    The "results" attribute is a dictionary whose keys are the websocket
    objects used to connect to each host, and whose values are the latest
    return values from each of those websockets.  Hosts for which we have not
    yet received a response will have no key/value pair in "results".

    If you want to do postprocessing on the results, subclass this and
    override the "callback" method, e.g.

    >>> class UserList(MultiSubscription):
    ...     def __init__(self):
    ...         self.usernames = set()
    ...         MultiSubscription.__init__(self, ['host1', 'host2'], 'admin.get_logged_in_users')
    ...
    ...     def callback(self, users, ws):
    ...         self.usernames.update(user['username'] for user in users)
    ...
    >>> users = UserList()
    """
    def __init__(self, hostnames, rpc_method, *args, **kwargs):
        from sideboard.lib import listify
        self.hostnames, self.method, self.args, self.kwargs = listify(hostnames), rpc_method, args, kwargs
        self.results, self.websockets, self._client_ids = {}, {}, {}
        on_startup(self._subscribe)
        on_shutdown(self._unsubscribe)

    def _websocket(self, url, ssl_opts):
        # Separated out so tests/subclasses can substitute websocket creation.
        from sideboard.lib import services
        return services._register_websocket(url, ssl_opts=ssl_opts)

    def _subscribe(self):
        from sideboard.lib._services import _ws_url, _rpc_opts, _ssl_opts
        for hostname in self.hostnames:
            rpc_opts = _rpc_opts(hostname)
            self.websockets[hostname] = self._websocket(_ws_url(hostname, rpc_opts), _ssl_opts(rpc_opts))

        for ws in self.websockets.values():
            self._client_ids[ws] = ws.subscribe(self._make_callback(ws), self.method, *self.args, **self.kwargs)

    def _unsubscribe(self):
        for ws in self.websockets.values():
            ws.unsubscribe(self._client_ids.get(ws))

    def _make_callback(self, ws):
        # Bind ws as a default-free closure so each websocket gets its own
        # callback which tags results with the originating websocket.
        return lambda result_data: self._callback(result_data, ws)

    def _callback(self, response_data, ws):
        self.results[ws] = response_data
        self.callback(response_data, ws)

    def callback(self, result_data, ws):
        """override this to define what to do with your rpc method return values"""

    def refresh(self):
        """
        Sometimes we want to manually re-fire all of our subscription methods
        to get the latest data.  This is useful in cases where the remote
        server isn't necessarily programmed to always push the latest data as
        soon as it's available, usually for performance reasons.  This method
        allows the client to get the latest data more often than the server is
        programmed to provide it.
        """
        for ws in self.websockets.values():
            try:
                # BUGFIX: this previously called self.ws.call(...); there is
                # no self.ws attribute, so every refresh raised AttributeError
                # which the except clause silently logged, meaning refresh()
                # never actually refreshed anything.
                self._callback(ws.call(self.method, *self.args, **self.kwargs), ws)
            except Exception:
                log.warn('failed to fetch latest data from {} on {}', self.method, ws.url)
def check_constraint_naming_convention(constraint, table):
    """Creates a unique name for an unnamed CheckConstraint.

    The generated name is the SQL text of the CheckConstraint with
    non-alphanumeric, non-underscore operators converted to text, and all
    other non-alphanumeric, non-underscore substrings replaced by underscores.

    If the generated name is longer than 32 characters, a uuid5 based on the
    generated name will be returned instead.

    >>> check_constraint_naming_convention(CheckConstraint('failed_logins > 3'), Table('account', MetaData()))
    'failed_logins_gt_3'

    See: http://docs.sqlalchemy.org/en/latest/core/constraints.html#configuring-constraint-naming-conventions
    """
    # The text of the replacements doesn't matter, so long as it's unique.
    # Multi-character operators come first so e.g. '!=' is consumed before '!'.
    replacements = [
        ('||/', 'cr'), ('<=', 'le'), ('>=', 'ge'), ('<>', 'nq'), ('!=', 'ne'),
        ('||', 'ct'), ('<<', 'ls'), ('>>', 'rs'), ('!!', 'fa'), ('|/', 'sr'),
        ('@>', 'cn'), ('<@', 'cb'), ('&&', 'an'), ('<', 'lt'), ('=', 'eq'),
        ('>', 'gt'), ('!', 'ex'), ('"', 'qt'), ('#', 'hs'), ('$', 'dl'),
        ('%', 'pc'), ('&', 'am'), ('\'', 'ap'), ('(', 'lpr'), (')', 'rpr'),
        ('*', 'as'), ('+', 'pl'), (',', 'cm'), ('-', 'da'), ('.', 'pd'),
        ('/', 'sl'), (':', 'co'), (';', 'sc'), ('?', 'qn'), ('@', 'at'),
        ('[', 'lbk'), ('\\', 'bs'), (']', 'rbk'), ('^', 'ca'), ('`', 'tk'),
        ('{', 'lbc'), ('|', 'pi'), ('}', 'rbc'), ('~', 'td')]

    constraint_name = str(constraint.sqltext).strip()
    for operator, mnemonic in replacements:
        constraint_name = constraint_name.replace(operator, mnemonic)

    # Raw string here: '[\W\s]+' in a non-raw literal relies on deprecated
    # invalid escape sequences (a SyntaxError in future Python versions).
    constraint_name = re.sub(r'[\W\s]+', '_', constraint_name)
    if len(constraint_name) > 32:
        constraint_name = uuid.uuid5(uuid.NAMESPACE_OID, str(constraint_name)).hex
    return constraint_name
# SQLAlchemy doesn't expose its default constructor as a nicely importable
# function, so we grab it from the function defaults.
if six.PY2:
    _spec_args, _spec_varargs, _spec_kwargs, _spec_defaults = inspect.getargspec(declarative.declarative_base)
else:
    _declarative_spec = inspect.getfullargspec(declarative.declarative_base)
    _spec_args, _spec_defaults = _declarative_spec.args, _declarative_spec.defaults
declarative_base_constructor = dict(zip(reversed(_spec_args), reversed(_spec_defaults)))['constructor']


def declarative_base(*orig_args, **orig_kwargs):
    """
    Replacement for SQLAlchemy's declarative_base, which adds these features:
    1) This is a decorator.
    2) This allows your base class to set a constructor.
    3) This provides a default constructor which automatically sets defaults
       instead of waiting to do that until the object is committed.
    4) Automatically setting __tablename__ to snake-case.
    5) Automatic integration with the SessionManager class.
    """
    orig_args = list(orig_args)

    def _decorate_base_class(klass):

        class Mixed(klass, CrudMixin):
            def __init__(self, *args, **kwargs):
                """
                Variant on SQLAlchemy model __init__ which sets default values
                on initialization instead of immediately before the model is
                saved.
                """
                # A '_model' key may be round-tripped from serialized data;
                # it must match this class and is not a real column.
                if '_model' in kwargs:
                    assert kwargs.pop('_model') == self.__class__.__name__
                declarative_base_constructor(self, *args, **kwargs)
                for attr, col in self.__table__.columns.items():
                    if kwargs.get(attr) is None and col.default:
                        self.__dict__.setdefault(attr, col.default.execute())

        orig_kwargs['cls'] = Mixed
        orig_kwargs.setdefault('name', klass.__name__)
        orig_kwargs.setdefault('constructor', klass.__init__ if '__init__' in klass.__dict__ else Mixed.__init__)

        Mixed = declarative.declarative_base(*orig_args, **orig_kwargs)
        Mixed.BaseClass = _SessionInitializer._base_classes[klass.__module__] = Mixed
        Mixed.__tablename__ = declarative.declared_attr(lambda cls: _camelcase_to_underscore(cls.__name__))
        return Mixed

    # Used bare as a decorator (@declarative_base) we receive exactly one
    # positional class argument; called with an engine or kwargs we instead
    # return the decorator itself.
    is_class_decorator = (
        not orig_kwargs
        and len(orig_args) == 1
        and inspect.isclass(orig_args[0])
        and not isinstance(orig_args[0], sqlalchemy.engine.Connectable))

    return _decorate_base_class(orig_args.pop()) if is_class_decorator else _decorate_base_class
- """ @classmethod - def initialize_db(cls, drop=False): + def initialize_db(cls, drop=False, create=True): configure_mappers() cls.BaseClass.metadata.bind = cls.engine if drop: cls.BaseClass.metadata.drop_all(cls.engine, checkfirst=True) - cls.BaseClass.metadata.create_all(cls.engine, checkfirst=True) + if create: + cls.BaseClass.metadata.create_all(cls.engine, checkfirst=True) @classmethod def all_models(cls): @@ -241,7 +320,7 @@ def resolve_model(cls, name): return subclasses[singular] if name.lower().endswith('ies'): - singular = name[:-3] + 'sy' # TODO: sy looks like a typo, and we need to either make this better or get rid of it + singular = name[:-3] + 'y' if singular in subclasses: return subclasses[singular] diff --git a/sideboard/lib/sa/_crud.py b/sideboard/lib/sa/_crud.py index d225ce1..246f1c7 100644 --- a/sideboard/lib/sa/_crud.py +++ b/sideboard/lib/sa/_crud.py @@ -18,7 +18,7 @@ 'comparison': 'field': , 'value': -}]+ +}] meaning an array of one or more dictionaries (a dictionary is equivalent to an array of length 1) of queries, one for each type of SQLAlchemy model object expected to be returned @@ -163,30 +163,25 @@ from functools import wraps import six -from sqlalchemy import orm +from sqlalchemy import orm, union, select, func +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound from sqlalchemy.orm.mapper import Mapper -from sqlalchemy import union, select, func +from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty from sqlalchemy.orm.util import class_mapper from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import text, ClauseElement -from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound -from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty -from sqlalchemy.types import Boolean, Text, 
Integer, String, UnicodeText, DateTime from sqlalchemy.sql.expression import alias, cast, label, bindparam, and_, or_, asc, desc, literal, text, union, join +from sqlalchemy.types import Boolean, Text, Integer, String, UnicodeText, DateTime -from sideboard.lib import log, notify, listify, threadlocal, serializer, is_listy +from sideboard.lib import log, notify, listify, threadlocal, serializer, is_listy, class_property class CrudException(Exception): pass -class ClassProperty(property): - def __get__(self, cls, owner): - return self.fget.__get__(None, owner)() - - def listify_with_count(x, count=None): x = listify(x) if count and len(x) < count: @@ -207,10 +202,10 @@ def mappify(value): def generate_date_series(startDate=None, endDate=None, interval='1 month', granularity='day'): if granularity: - granularity = '1 %s'%granularity + granularity = '1 %s' % granularity else: granularity = '1 day' - + generate_series = None if startDate: if endDate: @@ -218,11 +213,11 @@ def generate_date_series(startDate=None, endDate=None, interval='1 month', granu generate_series = func.generate_series(startDate, endDate, granularity) elif interval: # If the startDate and the interval are defined then we use those - generate_series = func.generate_series(startDate, + generate_series = func.generate_series(startDate, text("DATE :start_date_param_1 + INTERVAL :interval_param_1", bindparams=[ - bindparam("start_date_param_1",startDate), - bindparam("interval_param_1",interval)]), + bindparam("start_date_param_1", startDate), + bindparam("interval_param_1", interval)]), granularity) else: # If ONLY the startDate is defined then we just use that @@ -233,16 +228,16 @@ def generate_date_series(startDate=None, endDate=None, interval='1 month', granu generate_series = func.generate_series( text("DATE :current_date_param_1 - INTERVAL :interval_param_1", bindparams=[ - bindparam("current_date_param_1",endDate), - bindparam("interval_param_1",interval)]), + bindparam("current_date_param_1", 
endDate), + bindparam("interval_param_1", interval)]), endDate, granularity) else: # If ONLY the endDate is defined then we just use that generate_series = func.generate_series( text("DATE :current_date_param_1 - INTERVAL :interval_param_1", bindparams=[ - bindparam("current_date_param_1",endDate), - bindparam("interval_param_1","1 month")]), + bindparam("current_date_param_1", endDate), + bindparam("interval_param_1", "1 month")]), endDate, granularity) elif interval: # If ONLY the interval is defined then we default to the current date @@ -250,18 +245,18 @@ def generate_date_series(startDate=None, endDate=None, interval='1 month', granu generate_series = func.generate_series( text("DATE :current_date_param_1 - INTERVAL :interval_param_1", bindparams=[ - bindparam("current_date_param_1",datetime.utcnow()), - bindparam("interval_param_1",interval)]), + bindparam("current_date_param_1", datetime.utcnow()), + bindparam("interval_param_1", interval)]), datetime.utcnow(), granularity) else: # If NOTHING is defined then we return the query unmodified generate_series = func.generate_series( text("DATE :current_date_param_1 - INTERVAL :interval_param_1", bindparams=[ - bindparam("current_date_param_1",datetime.utcnow()), - bindparam("interval_param_1","1 month")]), + bindparam("current_date_param_1", datetime.utcnow()), + bindparam("interval_param_1", "1 month")]), datetime.utcnow(), granularity) - + return generate_series @@ -271,45 +266,45 @@ def normalize_date_query(query, dateLabel, reportLabel, startDate=None, endDate= series.label(dateLabel), literal(0).label(reportLabel) ]) - + query = union(query, seriesQuery).alias() query = select([ - text(dateLabel), + text(dateLabel), func.max(text(reportLabel)).label(reportLabel) ], from_obj=query).group_by(dateLabel).order_by(dateLabel) - + return query def normalize_object_graph(graph): """ Returns a normalized object graph given a variety of different inputs. 
- + If graph is a string, we assume it is a single property of an object, and return a dict with just that property set to True. - + If graph is a dict, we assume it is already a normalized graph. - + If graph is iterable (and not a string), we assume that it's simple a - list of properties, and we return a dict with those properties set to + list of properties, and we return a dict with those properties set to True. - + NOTE: This function is NOT recursive. It is intended to be repeatedly called from an external library as it traverses the object graph. We do - this for performance reasons in case the caller decides not to traverse + this for performance reasons in case the caller decides not to traverse the entire graph. - + >>> normalize_object_graph('prop') {u'prop': True} - + >>> normalize_object_graph(['prop_one', 'prop_two']) {'prop_two': True, 'prop_one': True} - + >>> normalize_object_graph({'prop_one':'test_one', 'prop_two':'test_two'}) {u'prop_two': u'test_two', u'prop_one': u'test_one'} """ if isinstance(graph, six.string_types): - return {graph:True} + return {graph: True} elif isinstance(graph, dict): return graph elif isinstance(graph, collections.Iterable): @@ -320,21 +315,21 @@ def normalize_object_graph(graph): def collect_ancestor_classes(cls, terminal_cls=None, module=None): """ - Collects all the classes in the inheritance hierarchy of the given class, + Collects all the classes in the inheritance hierarchy of the given class, including the class itself. - - If module is an object or list, we only return classes that are in one - of the given module/modules.This will exclude base classes that come + + If module is an object or list, we only return classes that are in one + of the given module/modules.This will exclude base classes that come from external libraries. - - If terminal_cls is encountered in the hierarchy, we stop ascending + + If terminal_cls is encountered in the hierarchy, we stop ascending the tree. 
""" if terminal_cls is None: terminal_cls = [] elif not isinstance(terminal_cls, (list, set, tuple)): terminal_cls = [terminal_cls] - + if module is not None: if not isinstance(module, (list, set, tuple)): module = [module] @@ -345,13 +340,13 @@ def collect_ancestor_classes(cls, terminal_cls=None, module=None): else: module_strings.append(m.__name__) module = module_strings - + ancestors = [] if (module is None or cls.__module__ in module) and cls not in terminal_cls: ancestors.append(cls) for base in cls.__bases__: ancestors.extend(collect_ancestor_classes(base, terminal_cls, module)) - + return ancestors @@ -377,11 +372,11 @@ def constrain_date_query(query, column, startDate=None, endDate=None, interval=' elif interval: # If the startDate and the interval are defined then we use those query = query.where(and_( - column >= startDate, + column >= startDate, column <= text("DATE :start_date_param_1 + INTERVAL :interval_param_1", bindparams=[ - bindparam("start_date_param_1",startDate), - bindparam("interval_param_1",interval)]))) + bindparam("start_date_param_1", startDate), + bindparam("interval_param_1", interval)]))) return query else: # If ONLY the startDate is defined then we just use that @@ -391,11 +386,11 @@ def constrain_date_query(query, column, startDate=None, endDate=None, interval=' if interval: # If the endDate and the interval are defined then we use those query = query.where(and_( - column <= endDate, + column <= endDate, column >= text("DATE :end_date_param_1 - INTERVAL :interval_param_1", bindparams=[ - bindparam("end_date_param_1",endDate), - bindparam("interval_param_1",interval)]))) + bindparam("end_date_param_1", endDate), + bindparam("interval_param_1", interval)]))) return query else: # If ONLY the endDate is defined then we just use that @@ -407,13 +402,13 @@ def constrain_date_query(query, column, startDate=None, endDate=None, interval=' query = query.where(and_( column >= text("DATE :current_date_param_1 - INTERVAL :interval_param_1", 
def normalize_sort(model, sort):
    """
    Normalize the accepted sort formats (a field name, a JSON-encoded string,
    a dict with field/direction keys, or a list of any of those) into a list
    of {'field': <name>, 'dir': 'asc'|'desc'} dicts.  Unrecognized input
    falls back to sorting by 'id' ascending.
    """
    # A JSON-encoded list or object may be passed as a string; decode it
    # first.  startswith() on a tuple is safe for whitespace-only strings,
    # where the previous sort.lstrip()[0] raised IndexError.
    if sort and isinstance(sort, six.string_types) and sort.lstrip().startswith(('[', '{')):
        sort = json.loads(sort)

    if isinstance(sort, six.string_types):
        return [{'field': extract_sort_field(model, sort), 'dir': 'asc'}]
    elif is_listy(sort):
        sorters = []
        for s in sort:
            sorters.extend(normalize_sort(model, s))
        return sorters
    elif isinstance(sort, dict):
        # Accept the several key spellings used by different js grid widgets.
        field = sort.get('property', sort.get('fields', sort.get('field', [])))
        direction = sort.get('direction', sort.get('dir', 'asc')).lower()
        return [{
            'field': extract_sort_field(model, field),
            'dir': direction
        }]
    else:
        return [{'field': 'id', 'dir': 'asc'}]
normalize_data(data, count=1): 'attr' ['attr1', 'attr2'] {'attr1':True, 'attr2':True} - + A plural data must be specified as a list of lists or a list of dicts: [['attr1', 'attr2'], ['attr1', 'attr2']] [{'attr1':True, 'attr2':True}, {'attr1':True, 'attr2':True}] - - Note that if data is specified as a list of strings, it is - considered to be singular. Only a list of lists or a list of + + Note that if data is specified as a list of strings, it is + considered to be singular. Only a list of lists or a list of dicts is considered plural. - + Returns the plural form of data as the comprehensive form of a list of dictionaries mapping to True, extended to count length. If a singular data is given, the result will be padded by repeating @@ -508,7 +503,7 @@ def normalize_data(data, count=1): return listify_with_count(None, count) else: if isinstance(data, six.string_types): - data = [{data:True}] + data = [{data: True}] elif isinstance(data, collections.Mapping): data = [data] elif isinstance(data, collections.Iterable): @@ -516,11 +511,11 @@ def normalize_data(data, count=1): # this is the singular list of strings case, so wrap it and # go from there data = [data] - #is this a list of strings? + # is this a list of strings? data = [mappify(v) for v in data] else: raise TypeError('unknown datatype: {}', data) - + if len(data) < count: if len(data) == 1: data.extend([deepcopy(data[0]) for i in range(count - len(data))]) @@ -532,7 +527,7 @@ def normalize_data(data, count=1): def normalize_query(query, top_level=True, supermodel=None): """ Normalizes a variety of query formats to a known standard query format. 
- + The comprehensive form of the query parameter is as follows: {code:python} query = [{ @@ -549,13 +544,13 @@ def normalize_query(query, top_level=True, supermodel=None): """ if query is None: raise ValueError('None passed for query parameter') - + query = listify(deepcopy(query)) - + queries = [] for q in query: if isinstance(q, six.string_types): - queries.append({'_model':q, '_label':q}) + queries.append({'_model': q, '_label': q}) elif isinstance(q, dict): if 'distinct' in q: if isinstance(q['distinct'], six.string_types): @@ -572,9 +567,9 @@ def normalize_query(query, top_level=True, supermodel=None): q[op] = normalize_query(q[op], False, q.get('_model', supermodel)) if len(q[op]) == 1: q = q[op][0] - elif not '_model' in q: + elif '_model' not in q: # Pull the _model up from the sub clauses. Technically the - # query format requires the _model be declared in the + # query format requires the _model be declared in the # clause, but we are going to be liberal in what we accept. model = supermodel for clause in q[op]: @@ -584,7 +579,7 @@ def normalize_query(query, top_level=True, supermodel=None): if model is None: raise ValueError('Clause objects must have a "_model" attribute') q['_model'] = model - + if '_model' in q: queries.append(q) elif supermodel is not None: @@ -632,7 +627,7 @@ def wrapped(*args, **kwargs): return fn(*args, **kwargs) except: a = [x for x in (args or [])] - kw = {k : v for k, v in (kwargs or {}).items()} + kw = {k: v for k, v in (kwargs or {}).items()} log.error('Error calling {}.{} {!r} {!r}'.format(fn.__module__, fn.__name__, a, kw), exc_info=True) exc_class, exc, tb = sys.exc_info() raise six.reraise(CrudException, CrudException(str(exc)), tb) @@ -645,6 +640,7 @@ class Crud(object): @staticmethod def crud_subscribes(func): func = crud_exceptions(func) + class subscriber(object): @property def subscribes(self): @@ -659,7 +655,6 @@ def __call__(self, *args, **kwargs): @staticmethod def crud_notifies(func, **settings): func = 
crud_exceptions(func) - delay = settings.pop('delay', 0) class notifier(object): def __call__(self, *args, **kwargs): @@ -667,7 +662,7 @@ def __call__(self, *args, **kwargs): return func(*args, **kwargs) finally: models = Crud._get_models(args, kwargs) - notify(models, trigger=func.__name__, delay=delay) + notify(models, trigger=func.__name__) return wraps(func)(notifier()) @@ -710,7 +705,7 @@ def _get_models(cls, *args, **kwargs): def _sort_query(cls, query, model, sort): sort = normalize_sort(model, sort) for sorter in sort: - dir = {'asc':asc, 'desc':desc}[sorter['dir']] + dir = {'asc': asc, 'desc': desc}[sorter['dir']] field = sorter['field'] if model: field = getattr(model, field) @@ -770,24 +765,24 @@ def _resolve_comparison(cls, comparison, column, value): value = select([getattr(model_class, field)], cls._resolve_filters(value)) return { - 'eq': lambda field, val : field == val, - 'ne': lambda field, val : field != val, - 'lt': lambda field, val : field < val, - 'le': lambda field, val : field <= val, - 'gt': lambda field, val : field > val, - 'ge': lambda field, val : field >= val, - 'in': lambda field, val : field.in_(val), - 'notin':lambda field, val : ~field.in_(val), - 'isnull' : lambda field, val : field == None, - 'isnotnull' : lambda field, val : field != None, - 'contains': lambda field, val : field.like('%'+val+'%'), - 'icontains': lambda field, val : field.ilike('%'+val+'%'), - 'like': lambda field, val : field.like('%'+val+'%'), - 'ilike': lambda field, val : field.ilike('%'+val+'%'), - 'startswith': lambda field, val : field.startswith(val), - 'endswith': lambda field, val : field.endswith(val), - 'istartswith': lambda field, val : field.ilike(val+'%'), - 'iendswith': lambda field, val : field.ilike('%'+val) + 'eq': lambda field, val: field == val, + 'ne': lambda field, val: field != val, + 'lt': lambda field, val: field < val, + 'le': lambda field, val: field <= val, + 'gt': lambda field, val: field > val, + 'ge': lambda field, val: field >= 
val, + 'in': lambda field, val: field.in_(val), + 'notin': lambda field, val: ~field.in_(val), + 'isnull': lambda field, val: field == None, + 'isnotnull': lambda field, val: field != None, + 'contains': lambda field, val: field.like('%'+val+'%'), + 'icontains': lambda field, val: field.ilike('%'+val+'%'), + 'like': lambda field, val: field.like('%'+val+'%'), + 'ilike': lambda field, val: field.ilike('%'+val+'%'), + 'startswith': lambda field, val: field.startswith(val), + 'endswith': lambda field, val: field.endswith(val), + 'istartswith': lambda field, val: field.ilike(val+'%'), + 'iendswith': lambda field, val: field.ilike('%'+val) }[comparison](column, value) @classmethod @@ -887,7 +882,7 @@ def count(query): @param query: Specifies the model types to count. May be a string, a list of strings, or a list of dicts with a "_model" key specified. - @return: The count of each of the supplied model types, in a list of + @return: The count of each of the supplied model types, in a list of dicts, like so: [{ '_model' : 'Player', @@ -901,8 +896,8 @@ def count(query): with Session() as session: for filter in filters: model = Session.resolve_model(filter['_model']) - result = {'_model' : filter['_model'], - '_label' : filter.get('_label', filter['_model'])} + result = {'_model': filter['_model'], + '_label': filter.get('_label', filter['_model'])} if getattr(model, '_crud_perms', {}).get('read', True): if filter.get('groupby', False): columns = [] @@ -912,7 +907,7 @@ def count(query): rows = Crud._filter_query(session.query(func.count(columns[0]), *columns), model, filter).all() result['count'] = [] for row in rows: - count = {'count' : row[0]} + count = {'count': row[0]} index = 1 for attr in filter['groupby']: count[attr] = row[index] @@ -968,7 +963,7 @@ def read(query, data=None, order=None, limit=None, offset=0): total = Crud._filter_query(session.query(model), model, filter).count() results = Crud._filter_query(session.query(model), model, filter, limit, offset, 
order).all() - return {'total':total, 'results':[r.crud_read(data[0]) for r in results]} + return {'total': total, 'results': [r.crud_read(data[0]) for r in results]} elif len(filters) > 1: queries = [] @@ -991,7 +986,7 @@ def read(query, data=None, order=None, limit=None, offset=0): query = queries[0].union(*(queries[1:])) normalized_sort_fields = normalize_sort(None, order) for sort_index, sort in enumerate(normalized_sort_fields): - dir = {'asc':asc, 'desc':desc}[sort['dir']] + dir = {'asc': asc, 'desc': desc}[sort['dir']] sort_field = 'anon_sort_{}'.format(sort_index) if issubclass(sort_field_types[sort_index], String): sort_field = 'lower({})'.format(sort_field) @@ -1020,9 +1015,9 @@ def read(query, data=None, order=None, limit=None, offset=0): ordered_results[result_order[instance.id]] = instance results = [r for r in ordered_results if r is not None] - return {'total':total, 'results':[r.crud_read(data[query_index_table[r.id]]) for r in results]} + return {'total': total, 'results': [r.crud_read(data[query_index_table[r.id]]) for r in results]} else: - return {'total':0, 'results':[]} + return {'total': 0, 'results': []} @crud_notifies.__func__ def create(data): @@ -1110,7 +1105,7 @@ def delete(query): class memoized(object): """ Decorator. Caches a function's return value each time it is called. - If called later with the same arguments, the cached value is returned + If called later with the same arguments, the cached value is returned (not reevaluated). from http://wiki.python.org/moin/PythonDecoratorLibrary#Memoize @@ -1118,7 +1113,7 @@ class memoized(object): def __init__(self, func): self.func = func self.cache = {} - + def __call__(self, *args): try: return self.cache[args] @@ -1130,9 +1125,11 @@ def __call__(self, *args): # uncachable -- for instance, passing a list as an argument. # Better to not cache than to blow up entirely. 
return self.func(*args) + def __repr__(self): """Return the function's docstring.""" return self.func.__doc__ + def __get__(self, obj, objtype): """Support instance methods.""" return functools.partial(self.__call__, obj) @@ -1222,7 +1219,7 @@ def _create_or_fetch(cls, session, value, **backref_mapping): instance = None if id is not None: try: - instance = session.query(cls).filter(cls.id==id).first() + instance = session.query(cls).filter(cls.id == id).first() except: log.error('Unable to fetch instance based on id value {!r}', value, exc_info=True) raise TypeError('Invalid instance ID type for relation: {0.__name__} (value: {1})'.format(cls, value)) @@ -1269,15 +1266,19 @@ def _type_casts_for_to_dict(self): self._to_dict_type_cast_mapping = defaultdict(lambda: lambda x: x, type_casts) return self._to_dict_type_cast_mapping - @ClassProperty - @classmethod + @class_property def to_dict_default_attrs(cls): attr_names = [] for name in collect_ancestor_attributes(cls, terminal_cls=cls.BaseClass): if not name.startswith('_') or name in cls.extra_defaults: attr = getattr(cls, name) - if isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, ColumnProperty) \ - or not isinstance(attr, (property, InstrumentedAttribute, ClauseElement)) and not callable(attr): + + is_column_property = isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, ColumnProperty) + is_hybrid_property = isinstance(getattr(attr, 'descriptor', None), hybrid_property) + is_property = isinstance(attr, (property, InstrumentedAttribute, ClauseElement)) + is_callable = callable(attr) + + if is_column_property or not (is_hybrid_property or is_property or is_callable): attr_names.append(name) return attr_names @@ -1324,7 +1325,7 @@ def cast_type(value): obj[name] = cast_type(attr) return obj - + def from_dict(self, attrs, validator=lambda self, name, val: True): relations = [] # merge_relations modifies the dictionaries that are passed to it in @@ -1339,13 +1340,13 @@ def 
from_dict(self, attrs, validator=lambda self, name, val: True): relations.append((name, value)) else: setattr(self, name, value) - + def required(kv): cols = list(getattr(self.__class__, kv[0]).property.local_columns) return len(cols) != 1 or cols[0].primary_key or cols[0].nullable - relations.sort(key = required) + relations.sort(key=required) - for name,value in relations: + for name, value in relations: self._merge_relations(name, value, validator) return self @@ -1382,8 +1383,8 @@ def _get_one_to_many_foreign_key_attr_name_if_applicable(cls, name): def _merge_relations(self, name, value, validator=lambda self, name, val: True): attr = getattr(self.__class__, name) - if (not isinstance(attr, InstrumentedAttribute) or - not isinstance(attr.property, RelationshipProperty)): + if (not isinstance(attr, InstrumentedAttribute) or + not isinstance(attr.property, RelationshipProperty)): return session = orm.Session.object_session(self) @@ -1408,7 +1409,7 @@ def _merge_relations(self, name, value, validator=lambda self, name, val: True): for i in value: if backref_id_name is not None and isinstance(i, dict) and not i.get(backref_id_name): i[backref_id_name] = self.id - relation_inst = relation_cls._create_or_fetch(session, i, **{backref_id_name:self.id} if backref_id_name else {}) + relation_inst = relation_cls._create_or_fetch(session, i, **{backref_id_name: self.id} if backref_id_name else {}) if isinstance(i, dict): relation_inst.from_dict(i, _crud_write_validator if relation_inst._sa_instance_state.identity else _crud_create_validator) new_insts.append(relation_inst) @@ -1615,10 +1616,10 @@ class MyModelObject(Base): type: "" } } - - @cvar never_read: a tuple of attribute names that default to being + + @cvar never_read: a tuple of attribute names that default to being not readable - @cvar never_update: a tuple of attribute names that default to being + @cvar never_update: a tuple of attribute names that default to being not updatable @cvar always_create: a tuple of 
attribute names that default to being always creatable @@ -1626,12 +1627,12 @@ class MyModelObject(Base): to simplify setting the same label for each and every instance of an attribute name """ - + never_read = ('metadata',) never_update = ('id',) always_create = ('id',) default_labels = {'addr': 'Address'} # TODO: allow plugins to define this; Sideboard core is not the place to encode addr/Address - + def __init__(self, can_create=True, create=None, no_create=None, read=None, no_read=None, @@ -1704,7 +1705,7 @@ def __init__(self, can_create=True, the value to be used in validation (e.g. 1000, for a max value of 1000). This is intended to support client side validation """ - + self.can_create = can_create self.can_delete = can_delete if no_update is not None and create is None: @@ -1715,37 +1716,37 @@ def __init__(self, can_create=True, self.no_update = no_update or [x for x in self.no_read if x not in self.update] self.create = create or [] self.no_create = no_create or [x for x in self.no_update if x not in self.create] - + self.no_read.extend(self.never_read) self.no_update.extend(self.never_update) - + self.data_spec = data_spec or {} - + def __call__(self, cls): def _get_crud_perms(cls): if getattr(cls, '_cached_crud_perms', False): return cls._cached_crud_perms - + crud_perms = { - 'can_create' : self.can_create, - 'can_delete' : self.can_delete, - 'read' : [], - 'update' : [], - 'create' : [] + 'can_create': self.can_create, + 'can_delete': self.can_delete, + 'read': [], + 'update': [], + 'create': [] } - + read = self.read for name in collect_ancestor_attributes(cls): if not name.startswith('_'): attr = getattr(cls, name) if (isinstance(attr, (InstrumentedAttribute, property, ClauseElement)) or - isinstance(attr, (int, float, bool, datetime, date, time, six.binary_type, six.text_type, uuid.UUID))): + isinstance(attr, (int, float, bool, datetime, date, time, six.binary_type, six.text_type, uuid.UUID))): read.append(name) read = list(set(read)) for name in 
read: if not self.no_read or name not in self.no_read: crud_perms['read'].append(name) - + update = self.update + deepcopy(crud_perms['read']) update = list(set(update)) for name in update: @@ -1756,11 +1757,11 @@ def _get_crud_perms(cls): attr = getattr(cls, name) if isinstance(attr, property) and getattr(attr, 'fset', False): crud_perms['update'].append(name) - elif (isinstance(attr, InstrumentedAttribute) and + elif (isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, RelationshipProperty) and attr.property.viewonly != True): crud_perms['update'].append(name) - + create = self.create + deepcopy(crud_perms['update']) for name in self.always_create: create.append(name) @@ -1770,17 +1771,17 @@ def _get_crud_perms(cls): for name in create: if not self.no_create or name not in self.no_create: crud_perms['create'].append(name) - + cls._cached_crud_perms = crud_perms return cls._cached_crud_perms - + def _get_crud_spec(cls): if getattr(cls, '_cached_crud_spec', False): return cls._cached_crud_spec - + crud_perms = cls._crud_perms - - field_names = list(set(crud_perms['read']) | set(crud_perms['update']) | + + field_names = list(set(crud_perms['read']) | set(crud_perms['update']) | set(crud_perms['create']) | set(self.data_spec.keys())) fields = {} for name in field_names: @@ -1788,9 +1789,9 @@ def _get_crud_spec(cls): # be serialized as json, it's convenient to have it in that # form early - # if using different validation decorators or in the data spec - # causes multiple spec - # kwargs to be specified, we're going to error here for + # if using different validation decorators or in the data spec + # causes multiple spec + # kwargs to be specified, we're going to error here for # duplicate keys in dictionaries. 
Since we don't want to allow # two different expected values for maxLength being sent in a # crud spec for example @@ -1800,10 +1801,10 @@ def _get_crud_spec(cls): for crud_validator_dict in getattr(cls, '_validators', {}).get(name, []) for spec_key_name, spec_value in crud_validator_dict.get('spec_kwargs', {}).items() } - + if field_validator_kwargs: self.data_spec.setdefault(name, {}) - # manually specified crud validator keyword arguments + # manually specified crud validator keyword arguments # overwrite the decorator-supplied keyword arguments field_validator_kwargs.update(self.data_spec[name].get('validators', {})) self.data_spec[name]['validators'] = field_validator_kwargs @@ -1819,20 +1820,20 @@ def _get_crud_spec(cls): # data_spec argument fields[name] = field continue - + field['read'] = name in crud_perms['read'] field['update'] = name in crud_perms['update'] field['create'] = name in crud_perms['create'] - + if field['read'] or field['update'] or field['create']: fields[name] = field elif name in fields: del fields[name] continue - + if 'desc' not in field and not _isdata(attr): # no des specified, and there's a relevant docstring, so use it - + # if there's 2 consecutive newlines, assume that there's a # separator in the docstring and that the top part only # is the description, if there's not, use the whole thing. 
@@ -1842,7 +1843,7 @@ def _get_crud_spec(cls): if doc: doc = doc.partition('\n\n')[0].replace('\n', ' ').strip() field['desc'] = doc - + if 'type' not in field: if isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, ColumnProperty): field['type'] = cls._type_map.get(type(attr.property.columns[0].type), 'auto') @@ -1859,7 +1860,7 @@ def _get_crud_spec(cls): field['defaultValue'] = attr if isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, RelationshipProperty): field['_model'] = attr.property.mapper.class_.__name__ - + crud_spec = {'fields': fields} cls._cached_crud_spec = crud_spec return cls._cached_crud_spec @@ -1867,9 +1868,9 @@ def _get_crud_spec(cls): def _type_map(cls): return dict(cls.type_map_defaults, **cls.type_map) - cls._type_map = ClassProperty(classmethod(_type_map)) - cls._crud_spec = ClassProperty(classmethod(_get_crud_spec)) - cls._crud_perms = ClassProperty(classmethod(_get_crud_perms)) + cls._type_map = class_property(_type_map) + cls._crud_spec = class_property(_get_crud_spec) + cls._crud_perms = class_property(_get_crud_perms) return cls @@ -1902,7 +1903,7 @@ def __call__(self, cls): else: # in case we subclass something with a _validators attribute cls._validators = deepcopy(cls._validators) - + cls._validators.setdefault(self.attribute_name, []).append({ 'model_validator': self.model_validator, 'validator_message': self.validator_message, @@ -1925,13 +1926,13 @@ def model_validator(instance, text): max_length is None or text_length <= max_length]) kwargs = {} - if not min_length is None: + if min_length is not None: kwargs['minLength'] = min_length - if not max_text is None: + if max_text is not None: kwargs['minLengthText'] = min_text - if not max_length is None: + if max_length is not None: kwargs['maxLength'] = max_length - if not max_text is None: + if max_text is not None: kwargs['maxLengthText'] = max_text message = 'Length of value should be between {} and {} (inclusive; None means no 
min/max).'.format(min_length, max_length) diff --git a/sideboard/run_debug_server.py b/sideboard/run_debug_server.py index 9b912e0..9f80019 100644 --- a/sideboard/run_debug_server.py +++ b/sideboard/run_debug_server.py @@ -1,15 +1,9 @@ from __future__ import unicode_literals -from debugger import debugger_helpers_all_init +from sideboard.debugging import debugger_helpers_all_init import cherrypy -import sideboard.server - if __name__ == '__main__': - # import pydevd - # print("running debug server2...", flush=True) - # pydevd.settrace('10.0.0.29', port=5000, stdoutToServer=True, stderrToServer=True) - debugger_helpers_all_init() cherrypy.engine.start() diff --git a/sideboard/run_mainloop.py b/sideboard/run_mainloop.py index a82bf66..664c859 100644 --- a/sideboard/run_mainloop.py +++ b/sideboard/run_mainloop.py @@ -1,6 +1,30 @@ from __future__ import unicode_literals +import os +import argparse -from sideboard.lib import mainloop +from sideboard.lib import mainloop, entry_point, log -if __name__ == '__main__': +parser = argparse.ArgumentParser(description='Run Sideboard as a daemon without starting CherryPy') +parser.add_argument('--pidfile', required=True, help='absolute path of file where process pid will be stored') + + +@entry_point +def mainloop_daemon(): + log.info('starting Sideboard daemon process') + args = parser.parse_args() + if os.fork() == 0: + pid = os.fork() + if pid == 0: + mainloop() + else: + log.debug('writing pid ({}) to pidfile ({})', pid, args.pidfile) + try: + with open(args.pidfile, 'w') as f: + f.write('{}'.format(pid)) + except: + log.error('unexpected error writing pid ({}) to pidfile ({})', pid, args.pidfile, exc_info=True) + + +@entry_point +def mainloop_foreground(): mainloop() diff --git a/sideboard/sep.py b/sideboard/sep.py index 66987d2..066c382 100644 --- a/sideboard/sep.py +++ b/sideboard/sep.py @@ -1,24 +1,29 @@ from __future__ import unicode_literals from sys import argv -from sideboard.lib import _entry_points +from 
sideboard.lib._utils import _entry_points + + +def print_usage(): + print('usage: {} ENTRY_POINT_NAME ...'.format(argv[0])) + print('known entry points:') + print('\n'.join([' {}'.format(ep) for ep in sorted(_entry_points)])) def run_plugin_entry_point(): if len(argv) < 2: - print('usage: {} ENTRY_POINT_NAME ...'.format(argv[0])) + print_usage() exit(1) if len(argv) == 2 and argv[1] in ['-h', '--help']: - print('known entry points:') - print('\n'.join(sorted(_entry_points))) + print_usage() exit(0) del argv[:1] # we want the entry point name to be the first argument ep_name = argv[0] if ep_name not in _entry_points: - print('no entry point exists with name {!r}'.format(ep_name)) + print('no entry point with name {!r}'.format(ep_name)) exit(2) _entry_points[ep_name]() diff --git a/sideboard/server.py b/sideboard/server.py index cdd4b98..31f400e 100755 --- a/sideboard/server.py +++ b/sideboard/server.py @@ -3,61 +3,33 @@ import sys import six -#import ldap import cherrypy import sideboard from sideboard.internal import connection_checker from sideboard.jsonrpc import _make_jsonrpc_handler -from sideboard.websockets import WebSocketDispatcher, WebSocketRoot +from sideboard.websockets import WebSocketDispatcher, WebSocketRoot, WebSocketAuthError from sideboard.lib import log, listify, config, render_with_templates, services, threadlocal +from sideboard.lib._cp import auth_registry +default_auth_checker = auth_registry[config['default_authenticator']]['check'] -def jsonrpc_auth(body): - if 'username' not in cherrypy.session: - raise cherrypy.HTTPError(401, 'not logged in') +def reset_threadlocal(): + threadlocal.reset(**{field: cherrypy.session.get(field) for field in config['ws.session_fields']}) + +cherrypy.tools.reset_threadlocal = cherrypy.Tool('before_handler', reset_threadlocal, priority=51) -def ldap_auth(username, password): - if not username or not password: - return False - - try: - conn = ldap.initialize(config['ldap.url']) - - force_start_tls = False - if 
config['ldap.cacert']: - ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, config['ldap.cacert']) - force_start_tls = True - - if config['ldap.cert']: - ldap.set_option(ldap.OPT_X_TLS_CERTFILE, config['ldap.cert']) - force_start_tls = True - - if config['ldap.key']: - ldap.set_option(ldap.OPT_X_TLS_KEYFILE, config['ldap.key']) - force_start_tls = True - - if force_start_tls: - conn.start_tls_s() - else: - conn.set_option(ldap.OPT_X_TLS_DEMAND, config['ldap.start_tls']) - except: - log.error('Error initializing LDAP connection', exc_info=True) - raise - - for basedn in listify(config['ldap.basedn']): - dn = '{}={},{}'.format(config['ldap.userattr'], username, basedn) - log.debug('attempting to bind with dn {}', dn) - try: - conn.simple_bind_s(dn, password) - except ldap.INVALID_CREDENTIALS as x: - continue - except: - log.warning("Error binding to LDAP server with dn", exc_info=True) - raise - else: - return True + +def jsonrpc_reset(body): + reset_threadlocal() + threadlocal.set('client', body.get('websocket_client')) + + +def jsonrpc_auth(body): + jsonrpc_reset(body) + if not default_auth_checker(): + raise cherrypy.HTTPError(401, 'not logged in') @render_with_templates(config['template_dir']) @@ -70,13 +42,16 @@ def logout(self, return_to='/'): raise cherrypy.HTTPRedirect('login?return_to=%s' % return_to) def login(self, username='', password='', message='', return_to=''): + if not config['debug']: + return 'Login page only available in debug mode.' 
+ if username: - if ldap_auth(username, password): + if config['debug'] and password == config['debug_password']: cherrypy.session['username'] = username - raise cherrypy.HTTPRedirect(return_to) + raise cherrypy.HTTPRedirect(return_to or config['default_url']) else: message = 'Invalid credentials' - + return { 'message': message, 'username': username, @@ -93,7 +68,8 @@ def list_plugins(self): 'paths': [] } for path, app in cherrypy.tree.apps.items(): - if path: # exclude what Sideboard itself mounts + # exclude what Sideboard itself mounts and grafted mount points + if path and hasattr(app, 'root'): plugin = app.root.__module__.split('.')[0] plugin_info[plugin]['paths'].append(path) return { @@ -108,10 +84,7 @@ def connections(self): wsrpc = WebSocketRoot() json = _make_jsonrpc_handler(services.get_services(), precall=jsonrpc_auth) - jsonrpc = _make_jsonrpc_handler(services.get_services(), - precall=lambda body: threadlocal.reset( - username=cherrypy.session.get('username'), - client=body.get('websocket_client'))) + jsonrpc = _make_jsonrpc_handler(services.get_services(), precall=jsonrpc_reset) class SideboardWebSocket(WebSocketDispatcher): @@ -125,15 +98,15 @@ class SideboardWebSocket(WebSocketDispatcher): @classmethod def check_authentication(cls): host, origin = cherrypy.request.headers['host'], cherrypy.request.headers['origin'] - if ('//' + host) not in origin: + if ('//' + host.split(':')[0]) not in origin: log.error('Javascript websocket connections must follow same-origin policy; origin {!r} does not match host {!r}', origin, host) - raise ValueError('Origin and Host headers do not match') + raise WebSocketAuthError('Origin and Host headers do not match') - if config['ws.auth_required'] and 'username' not in cherrypy.session: + if config['ws.auth_required'] and not cherrypy.session.get(config['ws.auth_field']): log.warning('websocket connections to this address must have a valid session') - raise ValueError('you are not logged in') + raise 
WebSocketAuthError('You are not logged in') - return cherrypy.session.get('username', '') + return WebSocketDispatcher.check_authentication() class SideboardRpcWebSocket(SideboardWebSocket): @@ -146,13 +119,7 @@ class SideboardRpcWebSocket(SideboardWebSocket): @classmethod def check_authentication(cls): - return 'rpc' - - -def reset_threadlocal(): - threadlocal.reset(username=cherrypy.session.get('username')) - -cherrypy.tools.reset_threadlocal = cherrypy.Tool('before_handler', reset_threadlocal, priority=51) + return {'username': 'rpc'} app_config = { @@ -206,4 +173,9 @@ def mount(root, script_name='', config=None): cherrypy.tree.mount = mount cherrypy.tree.mount(Root(), '', app_config) -del sys.modules['six.moves.winreg'] # kludgy workaround for CherryPy's autoreloader erroring on winreg +if config['cherrypy']['profiling.on']: + # If profiling is turned on then expose the web UI, otherwise ignore it. + from sideboard.lib import Profiler + cherrypy.tree.mount(Profiler(config['cherrypy']['profiling.path']), '/profiler') + +sys.modules.pop('six.moves.winreg', None) # kludgy workaround for CherryPy's autoreloader erroring on winreg for versions which have this diff --git a/sideboard/templates/connections.html b/sideboard/templates/connections.html index 8c8b7fa..f62ac69 100644 --- a/sideboard/templates/connections.html +++ b/sideboard/templates/connections.html @@ -5,13 +5,13 @@

Sideboard Connection Tests

- ((% for service, results in connections.items()|sort %)) -

$(( service ))$

+ {% for service, results in connections.items()|sort %} +

{{ service }}

    - ((% for line in results %)) -
  • $(( line ))$
  • - ((% endfor %)) + {% for line in results %} +
  • {{ line }}
  • + {% endfor %}
- ((% endfor %)) + {% endfor %} diff --git a/sideboard/templates/list_plugins.html b/sideboard/templates/list_plugins.html index ddbdddd..efc552b 100644 --- a/sideboard/templates/list_plugins.html +++ b/sideboard/templates/list_plugins.html @@ -4,22 +4,22 @@ Sideboard Plugins -

Welcome to Sideboard (version $(( version|default('not specified', true) ))$)

+

Welcome to Sideboard (version {{ version|default('not specified', true) }})

Sideboard documentation

-

There are $(( plugins|length ))$ plugins installed

+

There are {{ plugins|length }} plugins installed

- ((% for path, plugin in plugins.items()|sort %)) + {% for path, plugin in plugins.items()|sort %} - - + + - ((% endfor %)) + {% endfor %}
$(( plugin.name ))$(version $(( plugin.version|default('not specified', true) ))$){{ plugin.name }}(version {{ plugin.version|default('not specified', true) }}) - ((% for path in plugin.paths %)) - $(( path ))$ - ((% endfor %)) + {% for path in plugin.paths %} + {{ path }} + {% endfor %}
diff --git a/sideboard/templates/login.html b/sideboard/templates/login.html index c08e4ca..dcad459 100644 --- a/sideboard/templates/login.html +++ b/sideboard/templates/login.html @@ -9,14 +9,14 @@ -
$(( message ))$
+
{{ message }}
- + - + diff --git a/sideboard/tests/__init__.py b/sideboard/tests/__init__.py index 1f8c357..9b48022 100644 --- a/sideboard/tests/__init__.py +++ b/sideboard/tests/__init__.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals import os +import socket +from contextlib import closing import pytest import sqlalchemy @@ -9,10 +11,34 @@ from sideboard.lib import config, services +def get_available_port(): + """ + Returns an unused port in the ephemeral port range. + + Binding to port 0 with socket.SO_REUSEADDR will give us an unused port + that we can immediately reuse. This is mostly safe, but on heavily used + systems there is a potential race condition if another process uses + the same port in the time between requesting an available port and + actually using it. This is unlikely, and the worst that will happen is + the tests fail on that particular test run. + + See https://eklitzke.org/binding-on-port-zero + + Ideally we could tell cherrypy to listen on port 0, and then inspect + it to determine what port it's using, but cherrypy doesn't support that + yet (other parts of cherrypy will try to use the port defined in + 'cherrypy.server.socket_port' and end up failing on startup). 
+ """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + @pytest.fixture def service_patcher(request): class TestService(object): - def __init__(self, methods): + def __init__(self, methods): self.__dict__.update(methods) def patch(name, service): @@ -25,16 +51,19 @@ def patch(name, service): request.addfinalizer(lambda: services.get_services().update({name: orig_service})) return patch + @pytest.fixture def config_patcher(request): def patch_config(value, *path, **kwargs): conf = kwargs.pop('config', config) for section in path[:-1]: conf = conf[section] - request.addfinalizer(lambda: conf.__setitem__(path[-1], value)) + orig_val = conf[path[-1]] + request.addfinalizer(lambda: conf.__setitem__(path[-1], orig_val)) conf[path[-1]] = value return patch_config + def patch_session(Session, request): orig_engine, orig_factory = Session.engine, Session.session_factory request.addfinalizer(lambda: setattr(Session, 'engine', orig_engine)) diff --git a/sideboard/tests/test_configuration.py b/sideboard/tests/test_configuration.py index 18888bb..e175a42 100644 --- a/sideboard/tests/test_configuration.py +++ b/sideboard/tests/test_configuration.py @@ -1,17 +1,49 @@ from __future__ import unicode_literals import os -from sideboard.config import parse_config +import pytest +from mock import Mock + +from sideboard.lib import config +from sideboard.config import get_config_files, get_config_overrides, get_config_root, get_module_and_root_dirs, parse_config, uniquify + + +def test_uniquify(): + pytest.raises(AssertionError, uniquify, None) + assert [] == uniquify([]) + assert ['a', 'b', 'c'] == uniquify(['a', 'b', 'c']) + assert ['a', 'b', 'c', 'd', 'e'] == uniquify(['a', 'b', 'a', 'c', 'a', 'd', 'a', 'e']) + assert ['a'] == uniquify(['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']) + + +@pytest.mark.skipif( + 'SIDEBOARD_CONFIG_OVERRIDES' not 
in os.environ, + reason='SIDEBOARD_CONFIG_OVERRIDES not set') +def test_test_defaults_ini(): + """ + Verify that the tests were launched using `test-defaults.ini`. + + All of the other sideboard tests will succeed whether `test-defaults.ini` + or `development-defaults.ini` is used. This test is actually a functional + test of sorts; it verifies the test suite itself was launched with + `SIDEBOARD_CONFIG_OVERRIDES="test-defaults.ini"`. + + The test is skipped rather than failing if the tests were launched without + setting `SIDEBOARD_CONFIG_OVERRIDES`. + """ + from sideboard.lib import config + # is_test_running is ONLY set in test-defaults.ini + assert config.get('is_test_running') class SideboardConfigurationTest(object): sideboard_root = os.path.abspath(os.path.join(__file__, '..', '..', '..')) dev_test_plugin_path = os.path.join(sideboard_root, 'plugins', 'configuration-test', 'configuration_test', '__init__.py') - production_test_plugin_path = os.path.join('/', 'opt','sideboard','plugins','configuration-test', - 'env','lib','python{py_version}','site-packages', + production_test_plugin_path = os.path.join('/', 'opt', 'sideboard', 'plugins', 'configuration-test', + 'env', 'lib', 'python{py_version}', 'site-packages', 'configuration_test-{plugin_ver}-{py_version}.egg', - 'configuration_test','__init__.py') + 'configuration_test', '__init__.py') def test_parse_config_adding_root_and_module_root_for_dev(self): test_config = parse_config(self.dev_test_plugin_path) @@ -28,3 +60,142 @@ def test_parse_config_adding_root_and_module_root_for_production(self): expected_root = os.path.join('/', 'opt', 'sideboard', 'plugins', 'configuration-test') assert test_config.get('module_root') == expected_module_root assert test_config.get('root') == expected_root + + +class TestSideboardGetConfigFiles(object): + @pytest.fixture + def config_overrides_unset(self, monkeypatch): + monkeypatch.delenv('SIDEBOARD_CONFIG_OVERRIDES', raising=False) + + @pytest.fixture + def 
config_overrides_set(self, monkeypatch): + monkeypatch.setenv('SIDEBOARD_CONFIG_OVERRIDES', 'test-defaults.ini') + + @pytest.fixture + def plugin_dirs(self): + module_path = '/fake/sideboard/plugins/test-plugin/test_plugin' + root_path = os.path.join(config['plugins_dir'], 'test-plugin') + return (module_path, root_path) + + @pytest.fixture + def sideboard_dirs(self): + module_path = '/fake/sideboard/sideboard' + root_path = '/fake/sideboard' + return (module_path, root_path) + + def test_get_module_and_root_dirs_plugin(self, plugin_dirs): + assert plugin_dirs == get_module_and_root_dirs( + os.path.join(plugin_dirs[0], 'config.py'), is_plugin=True) + + def test_get_module_and_root_dirs_sideboard(self, sideboard_dirs): + assert sideboard_dirs == get_module_and_root_dirs( + os.path.join(sideboard_dirs[0], 'config.py'), is_plugin=False) + + def test_get_config_files_plugin(self, plugin_dirs, config_overrides_unset): + expected = [ + '/etc/sideboard/plugins.d/test-plugin.cfg', + os.path.join(plugin_dirs[1], 'development-defaults.ini'), + os.path.join(plugin_dirs[1], 'development.ini')] + assert expected == get_config_files( + os.path.join(plugin_dirs[0], 'config.py'), is_plugin=True) + + def test_get_config_files_sideboard(self, sideboard_dirs, config_overrides_unset): + expected = [ + '/etc/sideboard/sideboard-core.cfg', + '/etc/sideboard/sideboard-server.cfg', + os.path.join(sideboard_dirs[1], 'development-defaults.ini'), + os.path.join(sideboard_dirs[1], 'development.ini')] + assert expected == get_config_files( + os.path.join(sideboard_dirs[0], 'config.py'), is_plugin=False) + + def test_get_config_files_plugin_with_overrides(self, plugin_dirs, config_overrides_set): + expected = [ + '/etc/sideboard/plugins.d/test-plugin.cfg', + os.path.join(plugin_dirs[1], 'test-defaults.ini'), + os.path.join(plugin_dirs[1], 'test.ini')] + assert expected == get_config_files( + os.path.join(plugin_dirs[0], 'config.py'), is_plugin=True) + + def 
test_get_config_files_sideboard_with_overrides(self, sideboard_dirs, config_overrides_set): + expected = [ + '/etc/sideboard/sideboard-core.cfg', + '/etc/sideboard/sideboard-server.cfg', + os.path.join(sideboard_dirs[1], 'test-defaults.ini'), + os.path.join(sideboard_dirs[1], 'test.ini')] + assert expected == get_config_files( + os.path.join(sideboard_dirs[0], 'config.py'), is_plugin=False) + + +class TestSideboardGetConfigOverrides(object): + @pytest.fixture(params=[ + (None, ['development-defaults.ini', 'development.ini']), + ('test-defaults.ini', ['test-defaults.ini', 'test.ini']), + ('test.ini;development.ini;test.ini', ['test.ini', 'development.ini']), + ('test-defaults.ini;test-defaults.ini', ['test-defaults.ini', 'test.ini']), + (' /absolute/path.ini ', ['/absolute/path.ini']), + ('/absolute/path.cfg', ['/absolute/path.cfg']), + (' relative/path.ini ', ['relative/path.ini']), + ('relative/path.cfg', ['relative/path.cfg']), + ('/absolute/path.cfg;relative/path.ini', ['/absolute/path.cfg', 'relative/path.ini']), + ('relative/path.cfg;/absolute/path.ini', ['relative/path.cfg', '/absolute/path.ini']), + (' /absolute/path.cfg ; relative/path.ini ', ['/absolute/path.cfg', 'relative/path.ini']), + (' /absolute/path-defaults.ini ', ['/absolute/path-defaults.ini', '/absolute/path.ini']), + ('/absolute/path-defaults.cfg', ['/absolute/path-defaults.cfg', '/absolute/path.cfg']), + (' relative/path-defaults.ini ', ['relative/path-defaults.ini', 'relative/path.ini']), + ('relative/path-defaults.cfg', ['relative/path-defaults.cfg', 'relative/path.cfg']), + ('/absolute/path-defaults.cfg;relative/path-defaults.ini', [ + '/absolute/path-defaults.cfg', + '/absolute/path.cfg', + 'relative/path-defaults.ini', + 'relative/path.ini' + ]) + ]) + def config_overrides(self, request, monkeypatch): + if request.param[0] is None: + monkeypatch.delenv('SIDEBOARD_CONFIG_OVERRIDES', raising=False) + else: + monkeypatch.setenv('SIDEBOARD_CONFIG_OVERRIDES', request.param[0]) + return 
request.param[1] + + def test_get_config_overrides(self, config_overrides): + assert get_config_overrides() == config_overrides + + +class TestSideboardGetConfigRoot(object): + @pytest.fixture + def dir_missing(self, monkeypatch): + monkeypatch.setattr(os.path, 'isdir', Mock(return_value=False)) + + @pytest.fixture + def dir_exists(self, monkeypatch): + monkeypatch.setattr(os.path, 'isdir', Mock(return_value=True)) + + @pytest.fixture + def dir_readable(self, monkeypatch): + monkeypatch.setattr(os, 'access', Mock(return_value=True)) + + @pytest.fixture + def dir_unreadable(self, monkeypatch): + monkeypatch.setattr(os, 'access', Mock(return_value=False)) + + @pytest.fixture + def custom_root(self, monkeypatch): + monkeypatch.setitem(os.environ, 'SIDEBOARD_CONFIG_ROOT', '/custom/location') + + def test_valid_etc_sideboard(self, dir_exists, dir_readable): + assert get_config_root() == '/etc/sideboard' + + def test_no_etc_sideboard(self, dir_missing): + assert get_config_root() == '/etc/sideboard' + + def test_etc_sideboard_unreadable(self, dir_exists, dir_unreadable): + pytest.raises(AssertionError, get_config_root) + + def test_overridden_missing(self, custom_root, dir_missing): + pytest.raises(AssertionError, get_config_root) + + def test_overridden_unreadable(self, custom_root, dir_exists, dir_unreadable): + pytest.raises(AssertionError, get_config_root) + + def test_overridden_valid(self, custom_root, dir_exists, dir_readable): + assert get_config_root() == '/custom/location' diff --git a/sideboard/tests/test_imports.py b/sideboard/tests/test_imports.py new file mode 100644 index 0000000..1cb965e --- /dev/null +++ b/sideboard/tests/test_imports.py @@ -0,0 +1,27 @@ +from __future__ import unicode_literals +import os +import shutil +import tempfile +from os.path import join + +import pytest + +from sideboard.config import config +from sideboard.internal.imports import _discover_plugin_dirs + + +class TestDiscoverPluginDirs: + + def test_discover_plugin_dirs(self, 
monkeypatch): + plugins_dir = tempfile.mkdtemp() + try: + monkeypatch.setitem(config, 'plugins_dir', plugins_dir) + plugin_names = ['_u', 'a-1', 'z-2', 'b_3', 'y_4', 'c-6', 'x_7'] + plugin_dirs = {name: join(plugins_dir, name) for name in plugin_names} + for plugin_name, plugin_dir in plugin_dirs.items(): + os.makedirs(plugin_dir) + actual = _discover_plugin_dirs() + expected = [(name.replace('-', '_'), plugin_dirs[name]) for name in sorted(plugin_names) if name != '_u'] + assert actual == expected + finally: + shutil.rmtree(plugins_dir, ignore_errors=True) diff --git a/sideboard/tests/test_jsonrpc.py b/sideboard/tests/test_jsonrpc.py index 3cc3c7f..3848baa 100644 --- a/sideboard/tests/test_jsonrpc.py +++ b/sideboard/tests/test_jsonrpc.py @@ -14,15 +14,18 @@ def precall(): return Mock() + @pytest.fixture def raw_jsonrpc(service_patcher, precall, monkeypatch): service_patcher('test', {'get_message': lambda name: 'Hello {}!'.format(name)}) + def caller(parsed): cherrypy.request.json = parsed result = _make_jsonrpc_handler(services.get_services(), precall=precall)(self=None) return result return caller + @pytest.fixture def jsonrpc(raw_jsonrpc): def caller(method, *args, **kwargs): @@ -32,39 +35,49 @@ def caller(method, *args, **kwargs): }) return caller + def test_precall(jsonrpc, precall): jsonrpc('test.get_message', 'World') assert precall.called + def test_valid_args(jsonrpc): assert jsonrpc('test.get_message', 'World')['result'] == 'Hello World!' + def test_valid_kwargs(jsonrpc): assert jsonrpc('test.get_message', name='World')['result'] == 'Hello World!' 
+ def test_non_object(raw_jsonrpc): response = raw_jsonrpc('not actually json') assert 'invalid json input' in response['error']['message'] + def test_no_method(raw_jsonrpc): assert '"method" field required' in raw_jsonrpc({})['error']['message'] + def test_invalid_method(jsonrpc): assert 'invalid method' in jsonrpc('')['error']['message'] assert 'invalid method' in jsonrpc('no_module')['error']['message'] assert 'invalid method' in jsonrpc('too.many.modules')['error']['message'] + def test_missing_module(jsonrpc): assert 'no module' in jsonrpc('invalid.module')['error']['message'] + def test_missing_function(jsonrpc): assert 'no function' in jsonrpc('test.does_not_exist')['error']['message'] + def test_invalid_params(raw_jsonrpc): assert 'invalid parameter list' in raw_jsonrpc({ 'method': 'test.get_message', 'params': 'not a list or dict' })['error']['message'] + def test_exception(jsonrpc): assert 'unexpected error' in jsonrpc('test.get_message')['error']['message'] diff --git a/sideboard/tests/test_lib.py b/sideboard/tests/test_lib.py index e5f0cab..a1b5a21 100644 --- a/sideboard/tests/test_lib.py +++ b/sideboard/tests/test_lib.py @@ -1,15 +1,20 @@ from __future__ import unicode_literals import json +from time import sleep +from itertools import count from unittest import TestCase from datetime import datetime, date from collections import Sequence, Set +from threading import current_thread, Thread import six import pytest import cherrypy +from mock import Mock from sideboard.lib._services import _Services -from sideboard.lib import Model, serializer, ajax, is_listy, log +from sideboard.websockets import local_broadcast, local_subscriptions, local_broadcaster +from sideboard.lib import Model, serializer, ajax, is_listy, log, notify, locally_subscribes, cached_property, request_cached_property, threadlocal, register_authenticator, restricted, all_restricted, RWGuard class TestServices(TestCase): @@ -30,6 +35,18 @@ def test_service_preregistration_getattr(self): 
self.services.register(self, 'foo') + foo.assertTrue(True) + + def test_method_whitelisting(self): + """ + When __all__ is defined for a service, we should raise an exception if + a client calls a method whose name is not included in __all__. + """ + self.__all__ = ['bar'] + self.bar = self.baz = lambda: 'Hello World' + self.services.register(self, 'foo') + assert 'Hello World' == self.services.foo.bar() + with pytest.raises(AssertionError): + self.services.foo.baz() + + class TestModel(TestCase): def assert_model(self, data, unpromoted=None): @@ -40,11 +57,11 @@ def assert_model(self, data, unpromoted=None): self.assertEqual(5, model['foo']) self.assertEqual({'baz': 'baf'}, model.bar) self.assertEqual({'baz': 'baf'}, model['bar']) - + def test_missing_key(self): model = Model({}, 'test') self.assertIs(None, model.does_not_exist) - + def test_id_unsettable(self): model = Model({'id': 'some_uuid'}, 'test') model.id = 'some_uuid' @@ -54,7 +71,7 @@ def test_id_unsettable(self): model.id = 'another_uuid' with self.assertRaises(Exception): model['id'] = 'another_uuid' - + def test_extra_data_only(self): d = { 'id': 'some_uuid', @@ -167,7 +184,7 @@ class TestModel(Model): self.assertEqual('bar', model.foo) self.assertNotIn('foo', model._data) self.assertEqual('bar', model._data['extra_data']['foo']) - + def test_defaults(self): data = { 'extra_data': { @@ -179,7 +196,7 @@ def test_defaults(self): }, 'baf': -4 } - model = Model(data, 'test', {'bar','baf','fizz'}, { + model = Model(data, 'test', {'bar', 'baf', 'fizz'}, { 'foo': 1, 'bar': 2, 'baz': 3, @@ -200,7 +217,7 @@ def test_defaults(self): self.assertEqual(model.baf, 14) self.assertEqual(model.fizz, 5) self.assertEqual(model.buzz, 6) - + def test_to_dict(self): data = { 'id': 'some_uuid', @@ -222,7 +239,7 @@ def test_to_dict(self): self.assertEqual(model.to_dict(), serialized) serialized.pop('extra_data') self.assertEqual(dict(model), serialized) - + def test_query(self): model = Model({'_model': 'Test', 'id': 
'some_uuid'}, 'test') self.assertEqual(model.query, { @@ -233,7 +250,7 @@ def test_query(self): for data in [{}, {'_model': 'Test'}, {'id': 'some_uuid'}]: with self.assertRaises(Exception): Model(data, 'test').query - + def test_dirty(self): data = { 'id': 'some_uuid', @@ -246,23 +263,23 @@ def test_dirty(self): } } self.assertEqual(Model(data, 'test').dirty, {}) - + model = Model(data, 'test') model.spam = 'nee' self.assertEqual(model.dirty, {'spam': 'nee'}) - + model = Model(data, 'test') model.foo = 6 self.assertEqual(model.dirty, {'extra_data': {}, 'test_data': {'foo': 6, 'bar': {'baz': 'baf'}}}) - + model = Model(data, 'test') model.bar = {'fizz': 'buzz'} self.assertEqual(model.dirty, {'test_data': {'bar': {'fizz': 'buzz'}}}) - + model = Model(data, 'test') model.bar['baz'] = 'zab' self.assertEqual(model.dirty, {'test_data': {'bar': {'baz': 'zab'}}}) - + model = Model(data, 'test') model.foo = 6 model.bar = 'baz' @@ -277,7 +294,7 @@ def test_dirty(self): }, 'extra_data': {} }) - + model = Model({}, 'test') model.foo = 'bar' self.assertEqual(model.dirty, {'extra_data': {'test_foo': 'bar'}}) @@ -287,28 +304,33 @@ class TestSerializer(TestCase): class Foo(object): def __init__(self, x): self.x = x - - class Bar(Foo): pass - + + class Bar(Foo): + pass + def setUp(self): self.addCleanup(setattr, serializer, '_registry', serializer._registry.copy()) def test_date(self): d = date(2001, 2, 3) assert '"2001-02-03"' == json.dumps(d, cls=serializer) - + def test_datetime(self): dt = datetime(2001, 2, 3, 4, 5, 6) assert '"{}"'.format(dt.strftime(serializer._datetime_format)) == json.dumps(dt, cls=serializer) - + + def test_set(self): + st = set(['ya', 'ba', 'da', 'ba', 'da', 'ba', 'doo']) + assert '["ba", "da", "doo", "ya"]' == json.dumps(st, cls=serializer) + def test_duplicate_registration(self): pytest.raises(Exception, serializer.register, datetime, lambda dt: None) - + def test_new_type(self): serializer.register(self.Foo, lambda foo: foo.x) assert '5' == 
json.dumps(self.Foo(5), cls=serializer) assert '6' == json.dumps(self.Foo(6), cls=serializer) - + def test_new_type_subclass(self): serializer.register(self.Foo, lambda foo: 'Hello World!') serializer.register(self.Bar, lambda bar: 'Hello Kitty!') @@ -317,17 +339,17 @@ def test_new_type_subclass(self): """ Here are some cases which are currently undefined (and I'm okay with it): - + class Foo(object): pass class Bar(object): pass class Baz(Foo, Bar): pass class Baf(Foo): pass class Bax(Foo): pass - + serializer.register(Foo, foo_preprocessor) serializer.register(Bar, bar_preprocessor) serializer.register(Baf, baf_preprocessor) - + json.dumps(Baz(), cls=serializer) # undefined which function will be used json.dumps(Bax(), cls=serializer) # undefined which function will be used """ @@ -362,26 +384,31 @@ def test_user_defined_types(self): class AlwaysEmptySequence(Sequence): def __len__(self): return 0 + def __getitem__(self, i): return [][i] assert is_listy(AlwaysEmptySequence()) class AlwaysEmptySet(Set): def __len__(self): return 0 + def __iter__(self): return iter([]) + def __contains__(self, x): return False assert is_listy(AlwaysEmptySet()) def test_miscellaneous(self): - class Foo(object): pass - + class Foo(object): + pass + for x in [0, 1, False, True, Foo, object, object()]: assert not is_listy(x) def test_double_mount(request): - class Root(object): pass + class Root(object): + pass request.addfinalizer(lambda: cherrypy.tree.apps.pop('/test', None)) cherrypy.tree.mount(Root(), '/test') pytest.raises(Exception, cherrypy.tree.mount, Root(), '/test') @@ -397,3 +424,216 @@ def returns_date(self): def test_trace_logging(): log.trace('normally this would be an error') + + +class TestLocallySubscribes(object): + @pytest.yield_fixture(autouse=True) + def counter(self): + _counter = count() + + @locally_subscribes('foo', 'bar') + def counter(): + return next(_counter) + + yield counter + local_subscriptions.clear() + + def test_basic(self, counter): + 
local_broadcast(['foo', 'bar']) + assert 1 == counter() # was only called once even though it matched multiple channels + + def test_exception(self): + errored = Mock(side_effect=ValueError) + working = Mock() + locally_subscribes('foo')(errored) + locally_subscribes('foo')(working) + local_broadcast('foo') + assert errored.called and working.called # exception didn't halt execution + + def test_notify_triggers_local_updates(self, monkeypatch): + monkeypatch.setattr(local_broadcaster, 'defer', Mock()) + notify('foo') + local_broadcaster.defer.assert_called_with(['foo'], trigger='manual', originating_client=None) + + +def test_cached_property(): + class Foo(object): + @cached_property + def bar(self): + return 5 + + foo = Foo() + assert not hasattr(foo, '_cached_bar') + assert 5 == foo.bar + assert 5 == foo._cached_bar + foo._cached_bar = 6 + assert 6 == foo.bar + assert 5 == Foo().bar # per-instance caching + + +def test_request_cached_property(): + class Foo(object): + @request_cached_property + def bar(self): + return 5 + + name = __name__ + '.bar' + foo = Foo() + assert threadlocal.get(name) is None + assert 5 == foo.bar + assert 5 == threadlocal.get(name) + threadlocal.set(name, 6) + assert 6 == foo.bar + assert 6 == Foo().bar # cache is shared between instances + + +class TestPluggableAuth(object): + @pytest.fixture(scope='session', autouse=True) + def mock_authenticator(self): + register_authenticator('test', '/mock_login_page', lambda: 'uid' in cherrypy.session) + + @pytest.fixture(autouse=True) + def mock_session(self, monkeypatch): + monkeypatch.setattr(cherrypy, 'session', {}, raising=False) + + def mock_login(self): + cherrypy.session['uid'] = 123 + + def test_double_registration(self): + pytest.raises(Exception, register_authenticator, 'test', 'already registered', lambda: 'this will not register due to an exception') + + def test_unknown_authenticator(self): + pytest.raises(Exception, all_restricted, 'unknown_authenticator') + + def 
test_all_restricted(self): + self.called = False + + @all_restricted('test') + class AllRestricted(object): + def index(inner_self): + self.called = True + + with pytest.raises(cherrypy.HTTPRedirect) as exc: + AllRestricted().index() + assert not self.called and exc.value.args[0][0].endswith('/mock_login_page') + + self.mock_login() + AllRestricted().index() + assert self.called + + def test_restricted(self): + self.called = False + + class SingleRestricted(object): + @restricted('test') + def index(inner_self): + self.called = True + + with pytest.raises(cherrypy.HTTPRedirect) as exc: + SingleRestricted().index() + assert not self.called and exc.value.args[0][0].endswith('/mock_login_page') + + self.mock_login() + SingleRestricted().index() + assert self.called + + +class TestRWGuard(object): + @pytest.fixture + def guard(self, monkeypatch): + guard = RWGuard() + monkeypatch.setattr(guard.ready_for_writes, 'notify', Mock()) + monkeypatch.setattr(guard.ready_for_reads, 'notify_all', Mock()) + return guard + + def test_read_locked_tracking(self, guard): + assert {} == guard.acquired_readers + with guard.read_locked: + assert {current_thread().ident: 1} == guard.acquired_readers + with guard.read_locked: + assert {current_thread().ident: 2} == guard.acquired_readers + assert {current_thread().ident: 1} == guard.acquired_readers + assert {} == guard.acquired_readers + + def test_write_locked_tracking(self, guard): + assert {} == guard.acquired_writer + with guard.write_locked: + assert {current_thread().ident: 1} == guard.acquired_writer + with guard.write_locked: + assert {current_thread().ident: 2} == guard.acquired_writer + assert {current_thread().ident: 1} == guard.acquired_writer + assert {} == guard.acquired_writer + + def test_multi_read_locking_allowed(self, guard): + guard.acquired_readers['mock-thread-ident'] = 1 + with guard.read_locked: + pass + + def test_read_write_exclusion(self, guard): + with guard.read_locked: + with pytest.raises(AssertionError): + 
with guard.write_locked: + pass + + def test_write_read_exclusion(self, guard): + with guard.write_locked: + with pytest.raises(Exception): + with guard.read_locked: + pass + + def test_release_requires_acquisition(self, guard): + pytest.raises(AssertionError, guard.release) + + def test_wake_readers(self, guard): + with guard.read_locked: + guard.waiting_writer_count = 1 + assert not guard.ready_for_reads.notify_all.called + + guard.waiting_writer_count = 0 + with guard.read_locked: + pass + assert guard.ready_for_reads.notify_all.called + + def test_wake_writers(self, guard): + with guard.write_locked: + guard.acquired_readers['mock-tid'] = 1 + guard.waiting_writer_count = 1 + assert not guard.ready_for_writes.notify.called + + guard.acquired_readers.clear() + with guard.write_locked: + guard.waiting_writer_count = 0 + assert not guard.ready_for_writes.notify.called + + with guard.write_locked: + guard.waiting_writer_count = 1 + assert guard.ready_for_writes.notify.called + + def test_threading(self): + guard = RWGuard() + read, written = [False], [False] + + def reader(): + with guard.read_locked: + read[0] = True + + def writer(): + with guard.write_locked: + written[0] = True + + with guard.write_locked: + Thread(target=reader).start() + Thread(target=writer).start() + sleep(0.1) + assert not read[0] and not written[0] + sleep(0.1) + assert read[0] and written[0] + + read, written = [False], [False] + with guard.read_locked: + Thread(target=reader).start() + Thread(target=writer).start() + sleep(0.1) + assert read[0] and not written[0] + sleep(0.1) + assert read[0] and written[0] diff --git a/sideboard/tests/test_logging.py b/sideboard/tests/test_logging.py index f39d454..f0043ee 100644 --- a/sideboard/tests/test_logging.py +++ b/sideboard/tests/test_logging.py @@ -18,4 +18,3 @@ def test_importing_sideboard_doesnt_break_dummy_logger(self): dummy_logger = self._logger('dummy', stream) dummy_logger.warning('do not break dummy logger') assert stream.getvalue() == 
'do not break dummy logger\n' - diff --git a/sideboard/tests/test_profiler.py b/sideboard/tests/test_profiler.py new file mode 100644 index 0000000..57e4894 --- /dev/null +++ b/sideboard/tests/test_profiler.py @@ -0,0 +1,17 @@ +from __future__ import unicode_literals +from sideboard.lib import cleanup_profiler, profile +from sideboard.config import config + + +def some_function(): + pass + + +def test_profile_is_noop(monkeypatch): + monkeypatch.setitem(config['cherrypy'], 'profiling.on', False) + profiled = profile(some_function) + assert profiled is some_function + + monkeypatch.setitem(config['cherrypy'], 'profiling.on', True) + profiled = profile(some_function) + assert profiled is not some_function diff --git a/sideboard/tests/test_sa.py b/sideboard/tests/test_sa.py index d97dc52..8ec457f 100644 --- a/sideboard/tests/test_sa.py +++ b/sideboard/tests/test_sa.py @@ -6,14 +6,17 @@ import pytest import sqlalchemy +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship from sqlalchemy.types import Boolean, Integer, UnicodeText -from sqlalchemy.schema import Column, ForeignKey, UniqueConstraint +from sqlalchemy.schema import Column, CheckConstraint, ForeignKey, MetaData, Table, UniqueConstraint +from sqlalchemy.sql import case from sideboard.lib import log, listify from sideboard.tests import patch_session from sideboard.lib.sa._crud import normalize_query, collect_ancestor_classes -from sideboard.lib.sa import SessionManager, UUID, JSON, declarative_base, CrudException, crudable, text_length_validation, regex_validation +from sideboard.lib.sa import check_constraint_naming_convention, crudable, declarative_base, \ + regex_validation, text_length_validation, CrudException, JSON, SessionManager, UUID @declarative_base @@ -21,7 +24,7 @@ class Base(object): id = Column(UUID(), primary_key=True, default=uuid.uuid4) -@crudable(update=['tags','employees']) +@crudable(update=['tags', 'employees']) @text_length_validation('name', 1, 100) 
class User(Base): name = Column(UnicodeText(), nullable=False, unique=True) @@ -111,6 +114,18 @@ def settable_property(self): def settable_property(self, thing): pass + @hybrid_property + def string_and_int_hybrid_property(self): + """this is the docstring""" + return '{} {}'.format(self.string_model_attr, self.int_model_attr) + + @string_and_int_hybrid_property.expression + def string_and_int_hybrid_property(cls): + return case([ + (cls.string_model_attr == None, ''), + (cls.int_model_attr == None, '') + ], else_=(cls.string_model_attr + ' ' + cls.int_model_attr)) + @property def unsettable_property(self): """ @@ -164,7 +179,8 @@ def query_from(obj, attr='id'): @pytest.fixture(scope='module') def init_db(request): - class db: pass + class db: + pass patch_session(Session, request) db.turner = create('User', name='Turner') db.hooch = create('User', name='Hooch') @@ -185,6 +201,55 @@ def db(request, init_db): return init_db +class TestNamingConventions(object): + + @pytest.mark.parametrize('sqltext,expected', [ + ('failed_logins >= 3', 'failed_logins_ge_3'), + ('failed_logins > 3', 'failed_logins_gt_3'), + (' failed_logins = 3 ', 'failed_logins_eq_3'), + ('0123456789012345678901234567890123', '1e4008bc148c5486a3c92b2377fa1c45') + ]) + def test_check_constraint_naming_convention(self, sqltext, expected): + check_constraint = CheckConstraint(sqltext) + table = Table('account', MetaData()) + result = check_constraint_naming_convention(check_constraint, table) + assert result == expected + + +class TestDeclarativeBaseConstructor(object): + def test_default_init(self): + assert User().id # default is applied at initialization instead of on save + + def test_overriden_init(self): + @declarative_base + class WithOverriddenInit(object): + id = Column(UUID(), primary_key=True, default=uuid.uuid4) + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class Foo(WithOverriddenInit): + bar = Column(Boolean()) + + assert Foo().id is None + + def 
test_declarative_base_without_parameters(self): + + @declarative_base + class BaseTest: + pass + + assert BaseTest.__tablename__ == 'base_test' + + def test_declarative_base_with_parameters(self): + + @declarative_base(name=str('NameOverride')) + class BaseTest: + pass + + assert BaseTest.__tablename__ == 'name_override' + + class TestCrudCount(object): def assert_counts(self, query, **expected): actual = {count['_label']: count['count'] for count in Session.crud.count(query)} @@ -194,15 +259,15 @@ def assert_counts(self, query, **expected): def test_subquery(self): results = Session.crud.count({ - '_model' : 'Tag', - 'groupby' : ['name'], - 'field' : 'user_id', - 'comparison' : 'in', - 'value' : { - '_model' : 'User', - 'select' : 'id', - 'field' : 'name', - 'value' : 'Turner' + '_model': 'Tag', + 'groupby': ['name'], + 'field': 'user_id', + 'comparison': 'in', + 'value': { + '_model': 'User', + 'select': 'id', + 'field': 'name', + 'value': 'Turner' } }) expected = { @@ -214,19 +279,19 @@ def test_subquery(self): def test_compound_subquery(self): query = { - '_model' : 'Tag', - 'groupby' : ['name'], - 'field' : 'user_id', - 'comparison' : 'in', - 'value' : { - '_model' : 'User', - 'select' : 'id', - 'or' : [{ - 'field' : 'name', - 'value' : 'Turner' + '_model': 'Tag', + 'groupby': ['name'], + 'field': 'user_id', + 'comparison': 'in', + 'value': { + '_model': 'User', + 'select': 'id', + 'or': [{ + 'field': 'name', + 'value': 'Turner' }, { - 'field' : 'name', - 'value' : 'Hooch' + 'field': 'name', + 'value': 'Hooch' }] } } @@ -246,14 +311,14 @@ def test_distinct(self): results = Session.crud.count({ '_model': 'Tag', - 'distinct' : ['name'] + 'distinct': ['name'] }) results[0]['count'] == 3 def test_groupby(self): results = Session.crud.count({ '_model': 'Tag', - 'groupby' : ['name'] + 'groupby': ['name'] }) expected = { 'Male': 2, @@ -267,20 +332,20 @@ def test_single_basic_query_string(self): self.assert_counts('User', User=2) def 
test_single_basic_query_dict(self): - self.assert_counts({'_model' : 'User'}, User=2) + self.assert_counts({'_model': 'User'}, User=2) def test_multi_basic_query_string(self): self.assert_counts(['User', 'Tag'], User=2, Tag=4) def test_multi_basic_query_dict(self): - self.assert_counts([{'_model' : 'User'}, {'_model' : 'Tag'}], User=2, Tag=4) + self.assert_counts([{'_model': 'User'}, {'_model': 'Tag'}], User=2, Tag=4) def test_single_complex_query(self): - self.assert_counts({'_label': 'HoochCount', '_model': 'User', 'field': 'name', 'value' : 'Hooch'}, HoochCount=1) + self.assert_counts({'_label': 'HoochCount', '_model': 'User', 'field': 'name', 'value': 'Hooch'}, HoochCount=1) def test_multi_complex_query(self): - self.assert_counts([{'_label': 'HoochCount', '_model': 'User', 'field': 'name', 'value' : 'Hooch'}, - {'_label': 'MaleCount', '_model': 'Tag', 'field': 'name', 'value' : 'Male'}], + self.assert_counts([{'_label': 'HoochCount', '_model': 'User', 'field': 'name', 'value': 'Hooch'}, + {'_label': 'MaleCount', '_model': 'Tag', 'field': 'name', 'value': 'Male'}], HoochCount=1, MaleCount=2) def test_multi_complex_query_with_same_models(self): @@ -317,7 +382,23 @@ def assert_read_result(self, expected, query, data=None): actual = Session.crud.read(query, data) assert len(expected) == actual['total'] assert sorted(expected, key=lambda m: m.get('id', m.get('_model'))) \ - == sorted(actual['results'], key=lambda m: m.get('id', m.get('_model'))) + == sorted(actual['results'], key=lambda m: m.get('id', m.get('_model'))) + + def test_to_dict_default_attrs(self): + expected = [ + 'bool_attr', + 'bool_model_attr', + 'date_attr', + 'extra_data', + 'float_attr', + 'id', + 'int_attr', + 'int_model_attr', + 'mixed_in_attr', + 'string_attr', + 'string_model_attr'] + actual = CrudableClass.to_dict_default_attrs + assert sorted(expected) == sorted(actual) def test_subquery(self): results = Session.crud.read({ @@ -362,21 +443,21 @@ def test_distinct(self): 
pytest.skip('Query.distinct(*columns) is postgresql-only') results = Session.crud.read({ '_model': 'Tag', - 'distinct' : ['name'] + 'distinct': ['name'] }) assert results['total'] == 3 assert len(results['results']) == 3 results = Session.crud.read({ '_model': 'Tag', - 'distinct' : True + 'distinct': True }) assert results['total'] == 4 assert len(results['results']) == 4 results = Session.crud.read({ '_model': 'Tag', - 'distinct' : ['name', 'id'] + 'distinct': ['name', 'id'] }) assert results['total'] == 4 assert len(results['results']) == 4 @@ -697,51 +778,51 @@ def test_regex(self): class TestNormalizeQuery(object): def test_one_string(self): results = normalize_query('Human') - assert results == [{'_model':'Human', '_label':'Human'}] + assert results == [{'_model': 'Human', '_label': 'Human'}] def test_one_string_in_a_list(self): results = normalize_query(['Human']) - assert results == [{'_model':'Human', '_label':'Human'}] + assert results == [{'_model': 'Human', '_label': 'Human'}] def test_two_strings(self): results = normalize_query(['Human', 'Proxy']) - assert results == [{'_model':'Human', '_label':'Human'}, {'_model':'Proxy', '_label':'Proxy'}] + assert results == [{'_model': 'Human', '_label': 'Human'}, {'_model': 'Proxy', '_label': 'Proxy'}] results = normalize_query(['Proxy', 'Human']) - assert results == [{'_model':'Proxy', '_label':'Proxy'}, {'_model':'Human', '_label':'Human'}] + assert results == [{'_model': 'Proxy', '_label': 'Proxy'}, {'_model': 'Human', '_label': 'Human'}] def test_one_dict(self): - results = normalize_query({'_model':'Human'}) - assert results == [{'_model':'Human'}] + results = normalize_query({'_model': 'Human'}) + assert results == [{'_model': 'Human'}] def test_one_dict_in_a_list(self): - results = normalize_query([{'_model':'Human'}]) - assert results == [{'_model':'Human'}] + results = normalize_query([{'_model': 'Human'}]) + assert results == [{'_model': 'Human'}] def test_two_dicts(self): - results = 
normalize_query([{'_model':'Human'}, {'_model':'Proxy'}]) - assert results == [{'_model':'Human'}, {'_model':'Proxy'}] - results = normalize_query([{'_model':'Proxy'}, {'_model':'Human'}]) - assert results == [{'_model':'Proxy'}, {'_model':'Human'}] + results = normalize_query([{'_model': 'Human'}, {'_model': 'Proxy'}]) + assert results == [{'_model': 'Human'}, {'_model': 'Proxy'}] + results = normalize_query([{'_model': 'Proxy'}, {'_model': 'Human'}]) + assert results == [{'_model': 'Proxy'}, {'_model': 'Human'}] def test_or_clause(self): - results = normalize_query([{'_model':'Human', 'or':[{'_model':'Human', 'field':'nickname', 'value':'Johnny'}, {'_model':'Human', 'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}]) - assert results == [{'_model':'Human', 'or':[{'_model':'Human', 'field':'nickname', 'value':'Johnny'}, {'_model':'Human', 'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}] + results = normalize_query([{'_model': 'Human', 'or': [{'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}]) + assert results == [{'_model': 'Human', 'or': [{'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}] def test_and_clause_push_down_supermodel(self): - results = normalize_query([{'_model':'Human', 'or':[{'field':'nickname', 'value':'Johnny'}, {'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}]) - assert results == [{'_model':'Human', 'or':[{'_model':'Human', 'field':'nickname', 'value':'Johnny'}, {'_model':'Human', 'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}] + results = normalize_query([{'_model': 'Human', 'or': [{'field': 'nickname', 'value': 'Johnny'}, {'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}]) + assert results == [{'_model': 'Human', 'or': [{'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}, {'_model': 'Human', 
'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}] def test_or_clause_no_model(self): - results = normalize_query([{'or':[{'_model':'Human'}, {'_model':'Human', 'field':'nickname', 'value':'Johnny'}]}, {'_model':'Proxy'}]) - assert results == [{'_model':'Human', 'or':[{'_model':'Human'}, {'_model':'Human', 'field':'nickname', 'value':'Johnny'}]}, {'_model':'Proxy'}] + results = normalize_query([{'or': [{'_model': 'Human'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}]}, {'_model': 'Proxy'}]) + assert results == [{'_model': 'Human', 'or': [{'_model': 'Human'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}]}, {'_model': 'Proxy'}] def test_and_clause(self): - results = normalize_query([{'_model':'Human', 'and':[{'_model':'Human', 'field':'nickname', 'value':'Johnny'}, {'_model':'Human', 'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}]) - assert results == [{'_model':'Human', 'and':[{'_model':'Human', 'field':'nickname', 'value':'Johnny'}, {'_model':'Human', 'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}] + results = normalize_query([{'_model': 'Human', 'and': [{'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}]) + assert results == [{'_model': 'Human', 'and': [{'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}] def test_and_clause_no_model(self): - results = normalize_query([{'and':[{'_model':'Human'}, {'_model':'Human', 'field':'nickname', 'value':'Johnny'}]}, {'_model':'Proxy'}]) - assert results == [{'_model':'Human', 'and':[{'_model':'Human'}, {'_model':'Human', 'field':'nickname', 'value':'Johnny'}]}, {'_model':'Proxy'}] + results = normalize_query([{'and': [{'_model': 'Human'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}]}, {'_model': 'Proxy'}]) + assert results == [{'_model': 'Human', 'and': 
[{'_model': 'Human'}, {'_model': 'Human', 'field': 'nickname', 'value': 'Johnny'}]}, {'_model': 'Proxy'}] def test_fails_or_clause_list_of_lists(self): pytest.raises(ValueError, normalize_query, [{'or': [[], []]}, {'_model': 'Proxy', '_label': 'Proxy'}]) @@ -756,31 +837,31 @@ def test_fails_one_empty_dict(self): pytest.raises(ValueError, normalize_query, {}) def test_fails_one_dict_no_model(self): - pytest.raises(ValueError, normalize_query, {'field':'nickname', 'value':'Johnny'}) + pytest.raises(ValueError, normalize_query, {'field': 'nickname', 'value': 'Johnny'}) def test_fails_one_empty_dict_in_a_list(self): pytest.raises(ValueError, normalize_query, [{}]) def test_fails_one_dict_no_model_in_a_list(self): - pytest.raises(ValueError, normalize_query, [{'field':'nickname', 'value':'Johnny'}]) + pytest.raises(ValueError, normalize_query, [{'field': 'nickname', 'value': 'Johnny'}]) def test_fails_two_dicts_one_without_model(self): - pytest.raises(ValueError, normalize_query, [{'_model':'Proxy'}, {'field':'nickname', 'value':'Johnny'}]) + pytest.raises(ValueError, normalize_query, [{'_model': 'Proxy'}, {'field': 'nickname', 'value': 'Johnny'}]) def test_fails_and_clause_no_model(self): - pytest.raises(ValueError, normalize_query, [{'and':[{'field':'nickname', 'value':'Johnny'}, {'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}]) + pytest.raises(ValueError, normalize_query, [{'and': [{'field': 'nickname', 'value': 'Johnny'}, {'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}]) def test_fails_or_clause_no_model(self): - pytest.raises(ValueError, normalize_query, [{'or':[{'field':'nickname', 'value':'Johnny'}, {'field':'nickname', 'value':'Winny'}]}, {'_model':'Proxy'}]) + pytest.raises(ValueError, normalize_query, [{'or': [{'field': 'nickname', 'value': 'Johnny'}, {'field': 'nickname', 'value': 'Winny'}]}, {'_model': 'Proxy'}]) def test_fails_and_clause_list_of_lists(self): - pytest.raises(ValueError, normalize_query, [{'and':[[], []]}, 
{'_model':'Proxy'}]) + pytest.raises(ValueError, normalize_query, [{'and': [[], []]}, {'_model': 'Proxy'}]) def test_fails_and_clause_with_model_list_of_lists(self): - pytest.raises(ValueError, normalize_query, [{'_model':'Human', 'and':[[], []]}, {'_model':'Proxy'}]) + pytest.raises(ValueError, normalize_query, [{'_model': 'Human', 'and': [[], []]}, {'_model': 'Proxy'}]) def test_fails_or_clause_with_model_list_of_lists(self): - pytest.raises(ValueError, normalize_query, [{'_model':'Human', 'or':[[], []]}, {'_model':'Proxy'}]) + pytest.raises(ValueError, normalize_query, [{'_model': 'Human', 'or': [[], []]}, {'_model': 'Proxy'}]) class TestCollectModels(object): @@ -952,7 +1033,7 @@ def test_crud_spec(self): assert self.expected_crud_spec == CrudableClass._crud_spec def test_basic_crud_spec(self): - expected_basic = {'fields': {k: self.expected_crud_spec['fields'][k] + expected_basic = {'fields': {k: self.expected_crud_spec['fields'][k] for k in ('id', 'mixed_in_attr', 'extra_data')}} assert expected_basic == BasicClassMixedIn._crud_spec diff --git a/sideboard/tests/test_sep.py b/sideboard/tests/test_sep.py new file mode 100644 index 0000000..aa557a0 --- /dev/null +++ b/sideboard/tests/test_sep.py @@ -0,0 +1,53 @@ +from __future__ import unicode_literals +import sys + +import pytest +from mock import Mock + +from sideboard import sep +from sideboard.lib import entry_point +from sideboard.lib._utils import _entry_points +from sideboard.sep import run_plugin_entry_point + + +class FakeExit(Exception): + pass + + +class TestSep(object): + @pytest.yield_fixture(autouse=True) + def automocks(self, monkeypatch): + monkeypatch.setattr(sep, 'exit', Mock(side_effect=FakeExit), raising=False) + prev_argv, prev_points = sys.argv[:], _entry_points.copy() + yield + sys.argv[:] = prev_argv + _entry_points.clear() + _entry_points.update(prev_points) + + def test_no_command(self): + sys.argv[:] = ['sep'] + pytest.raises(FakeExit, run_plugin_entry_point) + 
sep.exit.assert_called_with(1) + + def test_help(self): + for flag in ['-h', '--help']: + sys.argv[:] = ['sep', flag] + pytest.raises(FakeExit, run_plugin_entry_point) + sep.exit.assert_called_with(0) + sep.exit.reset_mock() + + def test_invalid(self): + sys.argv[:] = ['sep', 'nonexistent_entry_point'] + pytest.raises(FakeExit, run_plugin_entry_point) + sep.exit.assert_called_with(2) + + def test_valid_entry_point(self): + action = Mock() + + @entry_point + def foobar(): + action(sys.argv) + + sys.argv[:] = ['sep', 'foobar', 'baz', '--baf'] + run_plugin_entry_point() + action.assert_called_with(['foobar', 'baz', '--baf']) diff --git a/sideboard/tests/test_server.py b/sideboard/tests/test_server.py index b828a8b..35576ad 100644 --- a/sideboard/tests/test_server.py +++ b/sideboard/tests/test_server.py @@ -5,7 +5,6 @@ from time import sleep from random import randrange from unittest import TestCase -from contextlib import closing import six from six.moves.queue import Queue, Empty @@ -18,12 +17,23 @@ from ws4py.server.cherrypyserver import WebSocketPlugin import sideboard.websockets -from sideboard.lib import log, config, subscribes, notifies, services, cached_property, WebSocket -from sideboard.tests import service_patcher, config_patcher +from sideboard.lib import log, config, subscribes, notifies, notify, services, cached_property, WebSocket +from sideboard.tests import service_patcher, config_patcher, get_available_port from sideboard.tests.test_sa import Session -@pytest.mark.nonfunctional +if config['cherrypy']['server.socket_port'] == 0: + available_port = get_available_port() + + # The config is updated in two places because by the time this code is + # executed, cherrypy.config will already be populated with the values from + # our config file. The configuration will already be living in two places, + # each of which must be updated. 
+ config['cherrypy']['server.socket_port'] = available_port + cherrypy.config.update({'server.socket_port': available_port}) + + +@pytest.mark.functional class SideboardServerTest(TestCase): port = config['cherrypy']['server.socket_port'] jsonrpc_url = 'http://127.0.0.1:{}/jsonrpc'.format(port) @@ -31,12 +41,6 @@ class SideboardServerTest(TestCase): rsess_username = 'unit_tests' - @staticmethod - def assert_port_open(port): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(('0.0.0.0', port)) - @staticmethod def assert_can_connect_to_localhost(port): for i in range(50): @@ -51,6 +55,8 @@ def assert_can_connect_to_localhost(port): @classmethod def start_cherrypy(cls): + config['thread_wait_interval'] = 0.1 + class Root(object): @cherrypy.expose def index(self): @@ -60,7 +66,6 @@ def index(self): cherrypy.tree.apps.pop('/mock_login', None) cherrypy.tree.mount(Root(), '/mock_login') - cls.assert_port_open(cls.port) cherrypy.config.update({'engine.autoreload_on': False}) cherrypy.engine.start() cherrypy.engine.wait(cherrypy.engine.states.STARTED) @@ -232,9 +237,9 @@ def slow_echo(self, s): def get_names(self): return self.names - @notifies('names') def change_name(self, name=None): self.names[-1] = name or uuid4().hex + notify('names', delay=True) @notifies('names') def change_name_then_error(self): @@ -423,10 +428,11 @@ def test_slow(self): class TestWebsocketsCrudSubscriptions(SideboardServerTest): @pytest.fixture(autouse=True) def override(self, service_patcher): - class MockCrud: pass + class MockCrud: + pass mr = self.mr = MockCrud() for name in ['create', 'update', 'delete']: - setattr(mr, name, Session.crud.crud_notifies(self.make_crud_method(name), delay=0.5)) + setattr(mr, name, Session.crud.crud_notifies(self.make_crud_method(name))) for name in ['read', 'count']: setattr(mr, name, Session.crud.crud_subscribes(self.make_crud_method(name))) service_patcher('crud', 
mr) diff --git a/sideboard/tests/test_websocket.py b/sideboard/tests/test_websocket.py index cc1e9c7..6f87e30 100644 --- a/sideboard/tests/test_websocket.py +++ b/sideboard/tests/test_websocket.py @@ -3,6 +3,9 @@ import pytest from mock import Mock, ANY +import ws4py.websocket + +from sideboard.websockets import WebSocketDispatcher from sideboard.lib import log, WebSocket, threadlocal, stopped from sideboard.tests import config_patcher @@ -11,6 +14,7 @@ def reset_stopped(): stopped.clear() + @pytest.fixture def ws(monkeypatch): ws = WebSocket(connect_immediately=False) @@ -19,6 +23,14 @@ def ws(monkeypatch): ws._next_id = lambda prefix: 'xxx' return ws + +@pytest.fixture +def orig_ws(monkeypatch): + monkeypatch.setattr(ws4py.websocket.WebSocket, '__init__', Mock(return_value=None)) + monkeypatch.setattr(WebSocketDispatcher, 'check_authentication', Mock(return_value={'username': 'mock_username'})) + return WebSocketDispatcher() + + def test_subscribe_basic(ws): callback = Mock() assert 'xxx' == ws.subscribe(callback, 'foo.bar', 'x', 'y') @@ -33,6 +45,7 @@ def test_subscribe_basic(ws): ws._send.assert_called_with(method='foo.bar', params=('x', 'y'), client='xxx') assert not log.warn.called + def test_subscribe_advanced(ws): callback, errback = Mock(), Mock() request = { @@ -51,12 +64,36 @@ def test_subscribe_advanced(ws): ws._send.assert_called_with(method='foo.bar', params=('x', 'y'), client='yyy') assert not log.warn.called + def test_subscribe_error(ws): ws._send = Mock(side_effect=Exception) ws.subscribe(Mock(), 'foo.bar') assert 'xxx' in ws._callbacks assert log.warn.called + +def test_subscribe_paramback(ws): + paramback = lambda: (5, 6) + callback, errback = Mock(), Mock() + request = { + 'client': 'yyy', + 'callback': callback, + 'errback': errback, + 'paramback': paramback + } + assert 'yyy' == ws.subscribe(request, 'foo.bar') + assert ws._callbacks['yyy'] == { + 'client': 'yyy', + 'callback': callback, + 'errback': errback, + 'paramback': paramback, + 
'method': 'foo.bar', + 'params': (5, 6) + } + ws._send.assert_called_with(method='foo.bar', params=(5, 6), client='yyy') + assert not log.warn.called + + def test_unsubscribe(ws): ws._callbacks['xxx'] = {} ws.unsubscribe('xxx') @@ -69,37 +106,41 @@ def returner(ws): ws._send = Mock(side_effect=lambda **kwargs: ws._callbacks['xxx']['callback'](123)) return ws + @pytest.fixture def errorer(ws): ws._send = Mock(side_effect=lambda **kwargs: ws._callbacks['xxx']['errback']('fail')) return ws + def test_call_raises_on_send_error(ws): ws._send = Mock(side_effect=Exception) pytest.raises(Exception, ws.call, 'foo.bar') assert 'xxx' not in ws._callbacks + def test_call_returns_value(returner): assert 123 == returner.call('foo.bar') assert 'xxx' not in returner._callbacks + def test_call_error(errorer): pytest.raises(Exception, errorer.call, 'foo.bar') assert 'xxx' not in errorer._callbacks + def test_call_timeout(ws, config_patcher, monkeypatch): - monkeypatch.setattr(stopped, 'wait', Mock()) + monkeypatch.setattr(stopped, 'is_set', Mock(return_value=False)) config_patcher(1, 'ws.call_timeout') pytest.raises(Exception, ws.call, 'foo.bar') assert 'xxx' not in ws._callbacks - assert stopped.wait.call_count == 10 + assert 9 <= stopped.is_set.call_count <= 11 + def test_call_stopped_set(ws, request, monkeypatch): - monkeypatch.setattr(stopped, 'wait', Mock()) request.addfinalizer(stopped.clear) stopped.set() pytest.raises(Exception, ws.call, 'foo.bar') - assert stopped.wait.call_count == 1 @pytest.fixture @@ -111,12 +152,14 @@ def refirer(ws): }) return ws + def test_refire(refirer): refirer._refire_subscriptions() assert refirer._send.call_count == 2 refirer._send.assert_any_call(method='x.x', params=(1, 2), client='xxx') refirer._send.assert_any_call(method='z.z', params=(5, 6), client='zzz') + def test_refire_error(refirer): refirer._send = Mock(side_effect=Exception) refirer._refire_subscriptions() @@ -129,15 +172,45 @@ def test_make_method_caller(ws): func(1, 2) 
ws.call.assert_called_with('foo.bar', 1, 2) -def test_make_subscription_caller(ws): - orig_ws = Mock(spec=['NO_RESPONSE']) + +def test_make_subscription_caller(ws, orig_ws): threadlocal.reset(message={'client': 'xxx'}, websocket=orig_ws) func = ws.make_caller('foo.bar') assert func(1, 2) == orig_ws.NO_RESPONSE ws._send.assert_called_with(method='foo.bar', params=(1, 2), client=ANY) -def test_make_subscription_unsubscribe(ws): + +def test_make_updated_subscription_caller(ws, orig_ws): + threadlocal.reset(message={'client': 'xxx'}, websocket=orig_ws) + func = ws.make_caller('foo.bar') + assert func is ws.make_caller('foo.baz') + assert func.method == 'foo.baz' + + +def test_make_subscription_unsubscribe(ws, orig_ws): ws.unsubscribe = Mock() - threadlocal.reset(message={'client': 'xxx'}, websocket=Mock(spec=['NO_RESPONSE'])) + ws._next_id = Mock(return_value='xxx') + threadlocal.reset(message={'client': 'yyy'}, websocket=orig_ws) ws.make_caller('foo.bar').unsubscribe() ws.unsubscribe.assert_called_with('xxx') + + +def test_preprocess_call(ws, returner): + ws.preprocess = lambda method, params: ['mock_modified_params'] + assert 123 == ws.call('foo.bar') + ws._send.assert_called_with(method='foo.bar', params=['mock_modified_params'], callback='xxx') + + +def test_preprocess_subscribe(ws): + ws.preprocess = lambda method, params: ['mock_modified_params'] + callback = Mock() + assert 'xxx' == ws.subscribe(callback, 'foo.bar') + registered = ws._callbacks['xxx'].copy() + del registered['errback'] + assert registered == { + 'client': 'xxx', + 'callback': callback, + 'method': 'foo.bar', + 'params': ['mock_modified_params'] + } + ws._send.assert_called_with(method='foo.bar', params=['mock_modified_params'], client='xxx') diff --git a/sideboard/tests/test_websocket_dispatcher.py b/sideboard/tests/test_websocket_dispatcher.py index d7d6402..8a32e7f 100644 --- a/sideboard/tests/test_websocket_dispatcher.py +++ b/sideboard/tests/test_websocket_dispatcher.py @@ -1,38 +1,65 @@ 
from __future__ import unicode_literals +from threading import RLock from collections import namedtuple import pytest from mock import Mock, ANY from ws4py.websocket import WebSocket -from sideboard.lib import log, services, subscribes +from sideboard.lib import log, services, subscribes, threadlocal from sideboard.websockets import WebSocketDispatcher, responder, threadlocal from sideboard.tests import service_patcher from sideboard.tests.test_websocket import ws +mock_session_data = {'username': 'mock_user', 'user_id': 'mock_id'} +mock_header_data = {'REMOTE_USER': 'mock_user', 'REMOTE_USER_ID': 'mock_id'} + + +def mock_wsd(): + wsd = Mock() + wsd.is_closed = False + return wsd + + +@pytest.yield_fixture(autouse=True) +def cleanup(): + yield + threadlocal.reset() + WebSocketDispatcher.instances.clear() + @pytest.fixture def wsd(monkeypatch): + WebSocketDispatcher.instances.clear() monkeypatch.setattr(WebSocket, 'send', Mock()) monkeypatch.setattr(WebSocket, 'closed', Mock()) - monkeypatch.setattr(WebSocketDispatcher, 'check_authentication', lambda self: 'mock_user') + monkeypatch.setattr(WebSocketDispatcher, 'is_closed', False) + monkeypatch.setattr(WebSocketDispatcher, 'check_authentication', lambda cls: mock_session_data) + monkeypatch.setattr(WebSocketDispatcher, 'fetch_headers', lambda cls: mock_header_data) return WebSocketDispatcher(None) + @pytest.fixture -def ws1(): return Mock() +def ws1(): return mock_wsd() + @pytest.fixture -def ws2(): return Mock() +def ws2(): return mock_wsd() + @pytest.fixture -def ws3(): return Mock() +def ws3(): return mock_wsd() + @pytest.fixture def ws4(): class RaisesError: + is_closed = False + unsubscribe_all = Mock() trigger = Mock(side_effect=Exception) return RaisesError + @pytest.fixture(autouse=True) def subscriptions(request, wsd, ws1, ws2, ws3, ws4): def reset_subscriptions(): @@ -48,47 +75,64 @@ def reset_subscriptions(): WebSocketDispatcher.subscriptions['baf'][ws4]['client-3'].add(None) +def test_instances(wsd): + 
assert WebSocketDispatcher.instances == {wsd} + wsd.closed('code', 'reason') + assert WebSocketDispatcher.instances == set() + + def test_get_all_subscribed(wsd, ws1, ws2, ws3, ws4): assert WebSocketDispatcher.get_all_subscribed() == {wsd, ws1, ws2, ws3, ws4} -def test_basic_broadcast(ws1, ws2): - WebSocketDispatcher.broadcast('bar', trigger='manual') - ws1.trigger.assert_called_with(client='client-1', callback='callback-1', trigger='manual') - assert not ws2.trigger.called +class TestBroadcast(object): + def test_basic_broadcast(self, ws1, ws2): + WebSocketDispatcher.broadcast('bar', trigger='manual') + ws1.trigger.assert_called_with(client='client-1', callback='callback-1', trigger='manual') + assert not ws2.trigger.called + assert not ws1.unsubscribe_all.called and not ws2.unsubscribe_all.called -def test_broadcast_with_originating_client(ws1, ws2): - WebSocketDispatcher.broadcast('foo', originating_client='client-1') - assert ws2.trigger.called and not ws1.trigger.called + def test_broadcast_with_originating_client(self, ws1, ws2): + WebSocketDispatcher.broadcast('foo', originating_client='client-1') + assert ws2.trigger.called and not ws1.trigger.called -def test_multi_broadcast(ws1, ws2, ws3, ws4): - WebSocketDispatcher.broadcast(['foo', 'bar']) - assert ws1.trigger.called and ws2.trigger.called and not ws3.trigger.called and not ws4.trigger.called + def test_multi_broadcast(self, ws1, ws2, ws3, ws4): + WebSocketDispatcher.broadcast(['foo', 'bar']) + assert ws1.trigger.called and ws2.trigger.called and not ws3.trigger.called and not ws4.trigger.called -def test_broadcast_error(ws4, monkeypatch): - monkeypatch.setattr(log, 'warn', Mock()) - WebSocketDispatcher.broadcast('foo') - assert not ws4.trigger.called and not log.warn.called - WebSocketDispatcher.broadcast('baf') - assert ws4.trigger.called and log.warn.called + def test_broadcast_error(self, ws4, monkeypatch): + monkeypatch.setattr(log, 'warn', Mock()) + WebSocketDispatcher.broadcast('foo') + assert 
not ws4.trigger.called and not log.warn.called + WebSocketDispatcher.broadcast('baf') + assert ws4.trigger.called and log.warn.called and not ws4.unsubscribe_all.called + + def test_broadcast_closed(self, ws1, ws2): + ws1.is_closed = True + WebSocketDispatcher.broadcast('foo') + assert ws2.trigger.called and not ws1.trigger.called + assert ws1.unsubscribe_all.called and not ws2.unsubscribe_all.called def test_basic_send(wsd): wsd.send(foo='bar', baz=None) WebSocket.send.assert_called_with(ANY, '{"foo":"bar"}') + def test_send_client_caching(wsd): wsd.send(client='xxx', data=123) wsd.send(client='xxx', data=123) wsd.send(client='yyy', data=321) assert WebSocket.send.call_count == 2 + def test_no_send_caching_without_client(wsd): wsd.send(data=123) wsd.send(data=123) wsd.send(data=321) assert WebSocket.send.call_count == 3 + def test_callback_based_send_caching(wsd): wsd.send(client='xxx', callback='yyy', data=123) wsd.send(client='xxx', callback='yyy', data=123) @@ -121,41 +165,51 @@ def test_get_method(wsd, service_patcher): def test_unsubscribe_from_nonexistent(wsd): wsd.unsubscribe('nonexistent') # does not error + def test_unsubscribe(wsd): client = 'client-1' - wsd.client_locks[client] = 'lock' - wsd.cached_queries[client] = {None: (Mock(), (), {})} - wsd.cached_fingerprints[client] = 'fingerprint' + wsd.client_locks[client] = RLock() + wsd.cached_queries[client] = {None: (Mock(), (), {}, {})} + wsd.cached_fingerprints[client][None] = 'fingerprint' WebSocketDispatcher.subscriptions['foo'] = {wsd: {client: 'subscription'}} - wsd.unsubscribe(client) + wsd.handle_message({'action': 'unsubscribe', 'client': client}) for d in [wsd.client_locks, wsd.cached_queries, wsd.cached_fingerprints, WebSocketDispatcher.subscriptions['foo']]: assert client not in d + def test_multi_unsubscribe(wsd): client = ['client-1', 'client-2'] - wsd.client_locks = {'client-1': 'lock', 'client-2': 'lock'} + wsd.client_locks = {'client-1': RLock(), 'client-2': RLock()} 
wsd.cached_fingerprints = {'client-1': 'fingerprint', 'client-2': 'fingerprint'} - wsd.cached_queries = {'client-1': {None: (Mock(), (), {})}, 'client-2': {None: (Mock(), (), {})}} + wsd.cached_queries = {'client-1': {None: (Mock(), (), {}, {})}, 'client-2': {None: (Mock(), (), {}, {})}} WebSocketDispatcher.subscriptions['foo'] = {wsd: {'client-1': 'subscription', 'client-2': 'subscription'}} - wsd.unsubscribe(client) + wsd.handle_message({'action': 'unsubscribe', 'client': client}) for d in [wsd.client_locks, wsd.cached_queries, wsd.cached_fingerprints, WebSocketDispatcher.subscriptions['foo']]: assert 'client-1' not in d assert 'client-2' not in d -def test_unsubscribe_all(wsd): + +def test_unsubscribe_all(wsd, subscriptions): assert wsd in WebSocketDispatcher.subscriptions['foo'] assert wsd in WebSocketDispatcher.subscriptions['bar'] + sub1 = wsd.passthru_subscriptions['client-0'] = Mock() + sub2 = wsd.passthru_subscriptions['client-x'] = Mock() + wsd.unsubscribe_all() + assert wsd not in WebSocketDispatcher.subscriptions['foo'] assert wsd not in WebSocketDispatcher.subscriptions['bar'] + assert 'client-0' not in wsd.passthru_subscriptions and 'client-x' not in wsd.passthru_subscriptions + assert sub1.unsubscribe.called and sub2.unsubscribe.called + def test_remote_unsubscribe(wsd, ws): ws.unsubscribe = Mock() - threadlocal.reset(websocket=ws, message={'client': 'xxx'}) - wsd.cached_queries['xxx'] = {None: (ws.make_caller('remote.foo'), (), {})} + ws._next_id = Mock(return_value='yyy') + threadlocal.reset(websocket=wsd, message={'client': 'xxx'}) + wsd.cached_queries['xxx'] = {None: (ws.make_caller('remote.foo'), (), {}, {})} wsd.unsubscribe('xxx') - ws.unsubscribe.assert_called_with('xxx') - + ws.unsubscribe.assert_called_with('yyy') def test_update_subscriptions_with_new_callback(wsd): @@ -163,17 +217,20 @@ def test_update_subscriptions_with_new_callback(wsd): assert WebSocketDispatcher.subscriptions['foo'][wsd]['client-0'] == {'callback-0', 'xxx'} assert 
WebSocketDispatcher.subscriptions['bar'][wsd]['client-0'] == {None} + def test_update_subscriptions_with_existing_null_callback(wsd): wsd.update_subscriptions(client='client-0', callback=None, channels='foo') assert WebSocketDispatcher.subscriptions['foo'][wsd]['client-0'] == {'callback-0', None} assert WebSocketDispatcher.subscriptions['bar'][wsd]['client-0'] == set() + def test_update_subscriptions_with_existing_callback(wsd): wsd.update_subscriptions(client='client-0', callback='callback-0', channels='baz') assert WebSocketDispatcher.subscriptions['foo'][wsd]['client-0'] == set() assert WebSocketDispatcher.subscriptions['bar'][wsd]['client-0'] == {None} assert WebSocketDispatcher.subscriptions['baz'][wsd]['client-0'] == {'callback-0'} + def test_update_subscriptions_with_multiple_channels(wsd): wsd.update_subscriptions(client='client-0', callback='callback-0', channels=['foo', 'baz']) assert WebSocketDispatcher.subscriptions['foo'][wsd]['client-0'] == {'callback-0'} @@ -183,49 +240,74 @@ def test_update_subscriptions_with_multiple_channels(wsd): @pytest.fixture def trig(wsd): - wsd.cached_queries['xxx']['yyy'] = (lambda *args, **kwargs: [args, kwargs], ('a', 'b'), {'c': 'd'}) + wsd.cached_queries['xxx']['yyy'] = (lambda *args, **kwargs: [args, kwargs], ('a', 'b'), {'c': 'd'}, {}) wsd.send = Mock() return wsd + +def increment(): + count = threadlocal.client_data.setdefault('count', 0) + count += 1 + threadlocal.client_data['count'] = count + return count + + def test_trigger(trig): trig.trigger(client='xxx', callback='yyy', trigger='zzz') trig.send.assert_called_with(client='xxx', callback='yyy', trigger='zzz', data=[('a', 'b'), {'c': 'd'}]) + def test_trigger_without_id(trig): trig.trigger(client='xxx', callback='yyy') trig.send.assert_called_with(client='xxx', callback='yyy', trigger=None, data=[('a', 'b'), {'c': 'd'}]) + def test_trigger_without_known_client(trig): trig.trigger(client='doesNotExist', callback='yyy') assert not trig.send.called + def 
test_trigger_without_known_callback(trig): trig.trigger(client='xxx', callback='doesNotExist') assert not trig.send.called +def test_trigger_with_client_data(wsd, trig, monkeypatch): + client = 'client-1' + monkeypatch.setitem(wsd.subscriptions['foo'][wsd], client, [None]) + monkeypatch.setitem(wsd.cached_fingerprints, client, {None: 'fingerprint'}) + monkeypatch.setitem(wsd.cached_queries, client, {None: (increment, (), {}, {'count': 7})}) + + wsd.trigger(client=client, callback=None) + wsd.send.assert_called_with(client=client, callback=None, trigger=None, data=8) + + @pytest.fixture def up(wsd): wsd.send = Mock() wsd.update_subscriptions = Mock() return wsd + @subscribes('foo') def foosub(): return 'e' + def test_update_triggers_client_and_callback(up): up.update_triggers('xxx', 'yyy', foosub, ('a', 'b'), {'c': 'd'}, 'e', 123) up.update_subscriptions.assert_called_with('xxx', 'yyy', ['foo']) - assert up.cached_queries['xxx']['yyy'] == (foosub, ('a', 'b'), {'c': 'd'}) + assert up.cached_queries['xxx']['yyy'] == (foosub, ('a', 'b'), {'c': 'd'}, {}) assert not up.send.called + def test_update_triggers_client_no_callback(up): up.update_triggers('xxx', None, foosub, ('a', 'b'), {'c': 'd'}, 'e', 123) up.update_subscriptions.assert_called_with('xxx', None, ['foo']) - assert up.cached_queries['xxx'][None] == (foosub, ('a', 'b'), {'c': 'd'}) + assert up.cached_queries['xxx'][None] == (foosub, ('a', 'b'), {'c': 'd'}, {}) up.send.assert_called_with(trigger='subscribe', client='xxx', data='e', _time=123) + def test_update_triggers_no_client(up): for callback in [None, 'yyy']: up.update_triggers(None, 'yyy', foosub, ('a', 'b'), {'c': 'd'}, 'e', 123) @@ -233,10 +315,11 @@ def test_update_triggers_no_client(up): assert 'yyy' not in up.cached_queries[None] assert not up.send.called + def test_update_triggers_with_error(up): up.update_triggers('xxx', None, foosub, ('a', 'b'), {'c': 'd'}, up.NO_RESPONSE, 123) up.update_subscriptions.assert_called_with('xxx', None, ['foo']) - 
assert up.cached_queries['xxx'][None] == (foosub, ('a', 'b'), {'c': 'd'}) + assert up.cached_queries['xxx'][None] == (foosub, ('a', 'b'), {'c': 'd'}, {}) assert not up.send.called @@ -246,17 +329,20 @@ def act(wsd, monkeypatch): monkeypatch.setattr(log, 'warn', Mock()) return wsd + def test_unsubscribe_action(act): act.unsubscribe = Mock() act.internal_action('unsubscribe', 'xxx', 'yyy') act.unsubscribe.assert_called_with('xxx') assert not log.warn.called + def test_unknown_action(act): act.internal_action('does_not_exist', 'xxx', 'yyy') assert not act.unsubscribe.called assert log.warn.called + def test_no_action(act): act.internal_action(None, 'xxx', 'yyy') assert not act.unsubscribe.called @@ -272,18 +358,21 @@ def receiver(wsd, monkeypatch): Message = namedtuple('Message', ['data']) + def test_received_message(receiver): receiver.received_message(Message('{}')) responder.defer.assert_called_with(ANY, {}) assert not receiver.send.called assert not log.error.called + def test_received_invalid_message(receiver): receiver.received_message(Message('not valid json')) assert not responder.defer.called receiver.send.assert_called_with(error=ANY) assert log.error.called + def test_received_non_dict(receiver): receiver.received_message(Message('"valid json but not a dict"')) assert not responder.defer.called @@ -307,6 +396,7 @@ def handler(ws, wsd, service_patcher, monkeypatch): wsd.update_triggers = Mock() return wsd + def test_handle_message_with_callback(handler): message = { 'method': 'foo.bar', @@ -314,12 +404,14 @@ def test_handle_message_with_callback(handler): 'callback': 'xxx' } handler.handle_message(message) - threadlocal.reset.assert_called_with(websocket=handler, message=message, username=handler.username) + threadlocal.reset.assert_called_with(websocket=handler, message=message, headers=mock_header_data, + **mock_session_data) handler.internal_action.assert_called_with(None, None, 'xxx') handler.update_triggers.assert_called_with(None, 'xxx', 
services.foo.bar, ['baf'], {}, 'baz', ANY) handler.send.assert_called_with(data='baz', callback='xxx', client=None, _time=ANY) assert not log.error.called + def test_handle_method_with_client(handler): message = { 'method': 'foo.bar', @@ -327,32 +419,38 @@ def test_handle_method_with_client(handler): 'client': 'xxx' } handler.handle_message(message) - threadlocal.reset.assert_called_with(websocket=handler, message=message, username=handler.username) + threadlocal.reset.assert_called_with(websocket=handler, message=message, headers=mock_header_data, + **mock_session_data) handler.internal_action.assert_called_with(None, 'xxx', None) handler.update_triggers.assert_called_with('xxx', None, services.foo.bar, [], {'baf': 1}, 'baz', ANY) assert not handler.send.called assert not log.error.called + def test_handle_message_client_error(handler): message = {'method': 'foo.err', 'client': 'xxx'} handler.handle_message(message) - threadlocal.reset.assert_called_with(websocket=handler, message=message, username=handler.username) + threadlocal.reset.assert_called_with(websocket=handler, message=message, headers=mock_header_data, + **mock_session_data) handler.internal_action.assert_called_with(None, 'xxx', None) handler.update_triggers.assert_called_with('xxx', None, services.foo.err, [], {}, handler.NO_RESPONSE, ANY) assert log.error.called handler.send.assert_called_with(error=ANY, client='xxx', callback=None) assert handler.send.call_count == 1 + def test_handle_message_callback_error(handler): message = {'method': 'foo.err', 'callback': 'xxx'} handler.handle_message(message) - threadlocal.reset.assert_called_with(websocket=handler, message=message, username=handler.username) + threadlocal.reset.assert_called_with(websocket=handler, message=message, headers=mock_header_data, + **mock_session_data) handler.internal_action.assert_called_with(None, None, 'xxx') handler.update_triggers.assert_called_with(None, 'xxx', services.foo.err, [], {}, handler.NO_RESPONSE, ANY) assert 
log.error.called handler.send.assert_called_with(error=ANY, callback='xxx', client=None) assert handler.send.call_count == 1 + def test_handle_message_remote_call(handler, ws): message = {'method': 'remote.method', 'callback': 'xxx', 'params': [1, 2]} handler.handle_message(message) @@ -360,9 +458,29 @@ def test_handle_message_remote_call(handler, ws): assert not ws.subscribe.called handler.send.assert_called_with(callback='xxx', data=12345, client=None, _time=ANY) + def test_handle_message_remote_subscribe(handler, ws): message = {'method': 'remote.method', 'client': 'xxx', 'params': [1, 2]} handler.handle_message(message) ws.subscribe.assert_called_with(ANY, 'remote.method', 1, 2) assert not ws.call.called assert not handler.send.called + + +def test_skip_send_if_closed(monkeypatch, wsd): + wsd.send() + monkeypatch.setattr(WebSocketDispatcher, 'is_closed', True) + wsd.send() + assert WebSocket.send.call_count == 1 + + +def test_explicit_call_resets_cache(service_patcher, wsd): + service_patcher('foo', { + 'bar': lambda: 'Hello World' + }) + message = {'method': 'foo.bar', 'client': 'client-1', 'callback': 'callback-2'} + wsd.handle_message(message) + assert 'callback-2' in wsd.cached_fingerprints['client-1'] + assert WebSocket.send.call_count == 1 + wsd.handle_message(message) + assert WebSocket.send.call_count == 2 diff --git a/sideboard/websockets.py b/sideboard/websockets.py index 4a74b7f..382b0f6 100755 --- a/sideboard/websockets.py +++ b/sideboard/websockets.py @@ -17,11 +17,38 @@ from ws4py.server.cherrypyserver import WebSocketPlugin, WebSocketTool import sideboard.lib -from sideboard.lib import log, Caller +from sideboard.lib import log, class_property, Caller from sideboard.config import config +local_subscriptions = defaultdict(list) +DELAYED_NOTIFICATIONS_KEY = 'sideboard.delayed_notifications' + class threadlocal(object): + """ + This class exposes a dict-like interface on top of the threading.local + utility class; the "get", "set", "setdefault", and 
"clear" methods work the + same as for a dict except that each thread gets its own keys and values. + + Sideboard clears out all existing values and then initializes some specific + values in the following situations: + + 1) CherryPy page handlers have the 'username' key set to whatever value is + returned by cherrypy.session['username']. + + 2) Service methods called via JSON-RPC have the following two fields set: + -> username: as above + -> websocket_client: if the JSON-RPC request has a "websocket_client" + field, it's value is set here; this is used internally as the + "originating_client" value in notify() and plugins can ignore this + + 3) Service methods called via websocket have the following three fields set: + -> username: as above + -> websocket: the WebSocketDispatcher instance receiving the RPC call + -> client_data: see the client_data property below for an explanation + -> message: the RPC request body; this is present on the initial call + but not on subscription triggers in the broadcast thread + """ _threadlocal = local() @classmethod @@ -32,20 +59,45 @@ def get(cls, key, default=None): def set(cls, key, val): return setattr(cls._threadlocal, key, val) + @classmethod + def setdefault(cls, key, val): + val = cls.get(key, val) + cls.set(key, val) + return val + @classmethod def clear(cls): cls._threadlocal.__dict__.clear() @classmethod def get_client(cls): + """ + If called as part of an initial websocket RPC request, this returns the + client id if one exists, and otherwise returns None. Plugins probably + shouldn't need to call this method themselves. + """ return cls.get('client') or cls.get('message', {}).get('client') @classmethod def reset(cls, **kwargs): + """ + Plugins should never call this method directly without a good reason; it + clears out all existing values and replaces them with the key-value + pairs passed as keyword arguments to this function. 
+ """ cls.clear() for key, val in kwargs.items(): cls.set(key, val) + @class_property + def client_data(cls): + """ + This propery is basically the websocket equivalent of cherrypy.session; + it's a dictionary where your service methods can place data which you'd + like to use in subsequent method calls. + """ + return cls.setdefault('client_data', {}) + def _normalize_channels(*channels): """ @@ -101,9 +153,45 @@ def _normalize_channels(*channels): return list(set(normalized_channels)) -def notify(channels, trigger="manual", delay=0, originating_client=None): - broadcaster.delayed(delay, _normalize_channels(*sideboard.lib.listify(channels)), - trigger=trigger, originating_client=originating_client or threadlocal.get_client()) +def notify(channels, trigger="manual", delay=False, originating_client=None): + """ + Manually trigger all subscriptions on the given channels. The following + optional parameters may be specified: + + trigger: Used in log messages if you want to distinguish between triggers. + delay: Boolean indicating whether the notification should happen immediately + or after the current WebSocket RPC method has completed. Note that + if this parameter is set when notify is called outside of a WebSocket + RPC request, no notification will ever happen. + originating_client: Websocket subscriptions will NOT fire if they have the + same client as the trigger. + """ + channels = _normalize_channels(*sideboard.lib.listify(channels)) + context = { + 'trigger': trigger, + 'originating_client': originating_client or threadlocal.get_client() + } + if delay: + threadlocal.setdefault(DELAYED_NOTIFICATIONS_KEY, []).append([channels, context]) + else: + broadcaster.defer(channels, **context) + local_broadcaster.defer(channels, **context) + + +def trigger_delayed_notifications(): + """ + Sometimes plugins might want to call notify() and have it trigger after + their RPC method has completed its call. 
For example, a plugin might call + notify() in the middle of a database transaction and want the notification + to happen after a commit has occurred. When notify() is called with + delay=True then it appends to a list, and this method goes through that list + and triggers broadcasts for those notifications. + """ + if threadlocal.get(DELAYED_NOTIFICATIONS_KEY): + for channels, context in threadlocal.get(DELAYED_NOTIFICATIONS_KEY): + broadcaster.defer(channels, **context) + local_broadcaster.defer(channels, **context) + threadlocal.set(DELAYED_NOTIFICATIONS_KEY, []) def notifies(*args, **kwargs): @@ -123,7 +211,6 @@ def notifies(*args, **kwargs): >>> getattr(fn_dict, 'notifies') ['dict'] """ - delay = kwargs.pop("delay", 0) channels = _normalize_channels(*args) def decorated_func(func): @@ -132,7 +219,7 @@ def notifier_func(*args, **kwargs): try: return func(*args, **kwargs) finally: - notify(channels, trigger=func.__name__, delay=delay) + notify(channels, trigger=func.__name__) notifier_func.notifies = channels return notifier_func @@ -166,6 +253,41 @@ def decorated_func(func): return decorated_func +def locally_subscribes(*args): + """ + The @subscribes decorator registers a function as being one which clients + may subscribe to via websocket. This decorator may be used to register a + function which shall be called locally anytime a notify occurs, e.g. 
+ + @locally_subscribes('example.channel') + def f(): + print('f was called') + + notify('example.channel') # causes f() to be called in a separate thread + """ + def decorated_func(func): + for channel in _normalize_channels(*args): + local_subscriptions[channel].append(func) + return func + + return decorated_func + + +def local_broadcast(channels, trigger=None, originating_client=None): + """Triggers callbacks registered via @locally_subscribes""" + triggered = set() + for channel in sideboard.lib.listify(channels): + for callback in local_subscriptions[channel]: + triggered.add(callback) + + for callback in triggered: + threadlocal.reset(trigger=trigger, originating_client=originating_client) + try: + callback() + except: + log.error('unexpected error on local broadcast callback', exc_info=True) + + def _fingerprint(x): """ Calculates the md5 sum of the given argument. @@ -217,23 +339,135 @@ def get_params(params): class WebSocketDispatcher(WebSocket): - username = None + """ + This class is instantiated for each incoming websocket connection. Each + instance of this class has its own socket object and its own thread. This + class is where we respond to RPC requests. + """ + NO_RESPONSE = object() + """ + This object is used as a sentinel value for situations where we want to + avoid double-sending a response. For example, when an RPC request for a + subscription arrives, we "trigger" a subscription response immediately, so + there's no need to actually call "send" on the return value. + + This is an internal implementation detail and plugins shouldn't need to know + or care that this field exists. + """ + subscriptions = defaultdict(lambda: defaultdict(lambda: defaultdict(set))) + """ + This tracks all subscriptions for all incoming websocket connections. The + structure looks like this: + + { + 'channel_id': { + : { + 'client_id': {'callback_id_one', 'callback_id_two', ...}, + }, + ... + }, + ... 
+ } + + This allows us to do things like trigger a broadcast to all websockets + subscribed on a channel. Instances of this class are responsible for + adding and removing their subscriptions from this data structure. + """ + + instances = set() + """ + When debugging Sideboard, it can be useful to introspect a list of all + open websocket connections which have been made to this server. Instances + of this class add themselves to this set when instantiated and remove + themselves when closed. + """ def __init__(self, *args, **kwargs): + """ + This passes all arguments to the parent constructor. In addition, it + defines the following instance variables: + + send_lock: Used to guarantee thread-safety when sending RPC responses. + + client_locks: A dict mapping client ids to locks used by those clients. + + passthru_subscriptions: When we receive a subscription request for a + service method registered on a remote service, we pass that request + along to the remote service and send back the responses. This + dictionary maps client ids to those subscription objects. + + session_fields: We copy session data for the currently-authenticated + user who made the incoming websocket connection; by default we only + copy the username, but this can be overridden in configuration. + Remember that Sideboard exposes two websocket handlers at /ws and + /wsrpc, with /ws being auth-protected (so the username field will be + meaningful) and /wsrpc being client-cert protected (so the username + will always be 'rpc'). + + header_fields: We copy header fields from the request that initiated the + websocket connection. + + cached_queries and cached_fingerprints: When we receive a subscription + update, Sideboard re-runs all of the subscription methods to see if + new data needs to be pushed out. We do this by storing all of the + rpc methods and an MD5 hash of their return values.
We store a hash + rather than the return values themselves to save on memory, since + return values may be very large. + + The cached_queries dict has this structure: + { + 'client_id': { + 'callback_id': (func, args, kwargs, client_data), + ... + }, + ... + } + + The cached_fingerprints dict has this structure: + { + 'client_id': { + 'callback_id': 'md5_hash_of_return_value', + ... + }, + ... + } + """ WebSocket.__init__(self, *args, **kwargs) + self.instances.add(self) self.send_lock = RLock() - self.client_locks, self.cached_queries, self.cached_fingerprints = \ - defaultdict(RLock), defaultdict(dict), defaultdict(dict) - self.username = self.check_authentication() + self.passthru_subscriptions = {} + self.client_locks = defaultdict(RLock) + self.cached_queries, self.cached_fingerprints = defaultdict(dict), defaultdict(dict) + self.session_fields = self.check_authentication() + self.header_fields = self.fetch_headers() + + @classmethod + def fetch_headers(cls): + """ + This method returns a dict with all of the header fields we want to + store for this websocket so that we can set them as threadlocal global + variables for all subsequent websocket RPC requests. + """ + return {field: cherrypy.request.headers.get(field) for field in config['ws.header_fields']} @classmethod def check_authentication(cls): - return cherrypy.session['username'] + """ + This method raises an exception if the user is not currently logged in, + and otherwise returns a dict with all of the session fields we want to + store for this websocket so that we can set them as threadlocal global + variables for all subsequent websocket RPC requests. By default we only + return the username of the currently-logged-in user in this manner, but + this can be changed via the ws.session_fields config option. Subclasses + can also override this method to change how we handle authentication. 
+ """ + return {field: cherrypy.session.get(field) for field in config['ws.session_fields']} @classmethod def get_all_subscribed(cls): + """Returns a set of all instances of this class with active subscriptions.""" websockets = set() for channel, subscriptions in cls.subscriptions.items(): for websocket, clients in subscriptions.items(): @@ -242,13 +476,34 @@ def get_all_subscribed(cls): @classmethod def broadcast(cls, channels, trigger=None, originating_client=None): + """ + Trigger all subscriptions on the given channel(s). This method is + called in the "broadcaster" thread, which means that all subscription + updates happen in the same thread. + + Callers can pass an "originating_client" id, which will prevent data + from being pushed to those clients. This is useful in cases like this: + -> a Javascipt application makes a call like "ecard.delete" + -> not wanting to wait for a subscription update, the Javascript app + preemptively updates its local data store to remove the item + -> the response to the delete call comes back as a success + -> because the local data store was already updated, there's no need + for this client to get a subscription update + + Callers can pass a "trigger" field, which will be included in the + subscription update message as the reason for the update. This doesn't + affect anything, but might be useful for logging. 
+ """ triggered = set() for channel in sideboard.lib.listify(channels): - for websocket, clients in cls.subscriptions[channel].items(): - for client, callbacks in clients.copy().items(): - if client != originating_client: - for callback in callbacks: - triggered.add((websocket, client, callback)) + for websocket, clients in list(cls.subscriptions[channel].items()): + if websocket.is_closed: + websocket.unsubscribe_all() + else: + for client, callbacks in clients.copy().items(): + if client != originating_client: + for callback in callbacks: + triggered.add((websocket, client, callback)) for websocket, client, callback in triggered: try: @@ -256,19 +511,67 @@ def broadcast(cls, channels, trigger=None, originating_client=None): except: log.warn('ignoring unexpected trigger error', exc_info=True) + @property + def is_closed(self): + """ + The "terminated" attribute tells us whether the socket was explictly + closed, this property performs a more rigorous check to let us know + if any of the fields which indicate the socket has been closed have been + set; this allows us to avoid spurious error messages by not attempting + to send messages on a socket which is in the process of closing. + """ + return not self.stream or self.client_terminated or self.server_terminated or not self.sock + def client_lock(self, client): + """ + Sideboard has a pool of background threads which simultaneously executes + method calls, but it performs per-subscription locking to ensure thread + safety for our subscription-related data structures. Thus, if the same + connected websocket sends two method calls with the same client id, + those calls will be handled sequentially rather than concurrently. + + This utility method supports this by returning a context manager which + acquires the necessary locks on entrance and releases them on exit. It + takes either a client id or list of client ids. 
+ """ ordered_clients = sorted(sideboard.lib.listify(client or [])) + ordered_locks = [self.client_locks[oc] for oc in ordered_clients] + class MultiLock(object): def __enter__(inner_self): - for client in ordered_clients: - self.client_locks[client].acquire() - + for lock in ordered_locks: + lock.acquire() + def __exit__(inner_self, *args, **kwargs): - for client in reversed(ordered_clients): - self.client_locks[client].release() + for lock in reversed(ordered_locks): + lock.release() + return MultiLock() def send(self, **message): + """ + This overrides the ws4py-provided send to implement three new features: + + 1) Instead of taking a string, this method treats its keyword arguments + as the message, serializes them to JSON, and sends that. + + 2) For subscription responses, we keep track of the most recent response + we sent for the given subscription. If neither the request or + response have changed since the last time we pushed data back to the + client for this subscription, we don't send anything. + + 3) We lock when sending to ensure that our sends are thread-safe. + Surprisingly, the "ws4py.threadedclient" class isn't thread-safe! + + 4) Subscriptions firing will sometimes trigger a send on a websocket + which has already been marked as closed. When this happens we log a + debug message and then exit without error. 
+ """ + if self.is_closed: + log.debug('ignoring send on an already closed websocket: {}', message) + self.unsubscribe_all() + return + message = {k: v for k, v in message.items() if v is not None} if 'data' in message and 'client' in message: fingerprint = _fingerprint(message['data']) @@ -283,35 +586,63 @@ def send(self, **message): message = json.dumps(message, cls=sideboard.lib.serializer, separators=(',', ':'), sort_keys=True) with self.send_lock: - WebSocket.send(self, message) + if not self.is_closed: + WebSocket.send(self, message) def closed(self, code, reason=''): + """ + This overrides the default closed handler to first clean up all of our + subscriptions, remove this websocket from the registry of instances, + and log a message before closing. + """ log.info('closing: code={!r} reason={!r}', code, reason) + self.instances.discard(self) self.unsubscribe_all() WebSocket.closed(self, code, reason) + def teardown_passthru(self, client): + """ + Given a client id, check whether there's a "passthrough subscription" + for that client and clean it up if one exists. + """ + subscription = self.passthru_subscriptions.pop(client, None) + if subscription: + subscription.unsubscribe() + def get_method(self, action): + """ + Given a method string in the format "module_name.function_name", + return a callable object representing that function, raising an + exception if the format is invalid or no such method exists. + """ service_name, method_name = action.split('.') service = getattr(sideboard.lib.services, service_name) method = getattr(service, method_name) return method def unsubscribe(self, clients): - for client in sideboard.lib.listify(clients): + """ + Given a client id or list of client ids, clean up those subscriptions + from the internal data structures of this class. 
+ """ + for client in sideboard.lib.listify(clients or []): + self.teardown_passthru(client) self.client_locks.pop(client, None) - self.cached_fingerprints.pop(client, None) - for func, args, kwargs in self.cached_queries[client].values(): - if hasattr(func, 'unsubscribe'): - func.unsubscribe() # remote subscriptions self.cached_queries.pop(client, None) + self.cached_fingerprints.pop(client, None) for clients in self.subscriptions.values(): clients[self].pop(client, None) def unsubscribe_all(self): + """Called on close to tear down all of this websocket's subscriptions.""" for clients in self.subscriptions.values(): - clients.pop(self, None) + clients.pop(self, {}) + + for passthru_client in list(self.passthru_subscriptions.keys()): + self.teardown_passthru(passthru_client) def update_subscriptions(self, client, callback, channels): + """Updates WebSocketDispatcher.subscriptions for the given client/channels.""" for clients in self.subscriptions.values(): clients[self][client].discard(callback) @@ -319,25 +650,62 @@ def update_subscriptions(self, client, callback, channels): self.subscriptions[channel][self][client].add(callback) def trigger(self, client, callback, trigger=None): + """ + This is the method called by the global broadcaster thread when a + notification is posted to a channel this client is subscribed to. It + re-calls the function and sends the result back to the client. 
+ """ if callback in self.cached_queries[client]: - function, args, kwargs = self.cached_queries[client][callback] + function, args, kwargs, client_data = self.cached_queries[client][callback] + threadlocal.reset(websocket=self, client_data=client_data, headers=self.header_fields, **self.session_fields) result = function(*args, **kwargs) self.send(trigger=trigger, client=client, callback=callback, data=result) def update_triggers(self, client, callback, function, args, kwargs, result, duration=None): + """ + This is called after an RPC function is invoked; it takes the function + and its return value and updates our internal data structures then sends + the response back to the client. + """ if hasattr(function, 'subscribes') and client is not None: - self.cached_queries[client][callback] = (function, args, kwargs) + self.cached_queries[client][callback] = (function, args, kwargs, threadlocal.client_data) self.update_subscriptions(client, callback, function.subscribes) if client is not None and callback is None and result is not self.NO_RESPONSE: self.send(trigger='subscribe', client=client, data=result, _time=duration) def internal_action(self, action, client, callback): + """ + Sideboard currently supports both method calls and "internal actions" + which affect the state of the websocket connection itself. This + implements the command-dispatch pattern to perform the given action and + raises an exception if that action doesn't exist. + + The only action currently implemented is "unsubscribe". + """ if action == 'unsubscribe': self.unsubscribe(client) elif action is not None: log.warn('unknown action {!r}', action) + def clear_cached_response(self, client, callback): + """ + As explained above, Sideboard caches the most recent response to a + subscription so that when we check the subscription we can see if new + data needs to be sent. 
However, if the user makes a series of requests + with the same client/callback ids which return the same response, they + probably still expect to get data back. This method is therefore called + every time we receive an explicit RPC call for a subscription to discard + the cached value, ensuring that an explicit RPC call to a service + exposed via websocket always receives a response. + """ + self.cached_fingerprints[client].pop(callback, None) + def received_message(self, message): + """ + This overrides the default ws4py event handler to parse the incoming + message and pass it off to our pool of background threads, which call + this class' handle_message function to perform the relevant RPC actions. + """ try: data = message.data if isinstance(message.data, six.text_type) else message.data.decode('utf-8') fields = json.loads(data) @@ -351,14 +719,20 @@ def received_message(self, message): responder.defer(self, fields) def handle_message(self, message): + """ + Given a message dictionary, perform the relevant RPC actions and send + out the response. 
This function is called from a pool of background + threads + """ before = time.time() duration, result = None, None - threadlocal.reset(websocket=self, message=message, username=self.username) + threadlocal.reset(websocket=self, message=message, headers=self.header_fields, **self.session_fields) action, callback, client, method = message.get('action'), message.get('callback'), message.get('client'), message.get('method') try: with self.client_lock(client): self.internal_action(action, client, callback) if method: + self.clear_cached_response(client, callback) func = self.get_method(method) args, kwargs = get_params(message.get('params')) result = self.NO_RESPONSE @@ -366,6 +740,7 @@ def handle_message(self, message): result = func(*args, **kwargs) duration = (time.time() - before) if config['debug'] else None finally: + trigger_delayed_notifications() self.update_triggers(client, callback, func, args, kwargs, result, duration) except: log.error('unexpected websocket dispatch error', exc_info=True) @@ -378,7 +753,17 @@ def handle_message(self, message): self.send(data=result, callback=callback, client=client, _time=duration) def __repr__(self): - return '<%s username=%s>' % (self.__class__.__name__, self.username) + return '<{} {}>'.format( + self.__class__.__name__, + ' '.join('{}={}'.format(k, v) for k, v in self.session_fields.items()) + ) + + +class WebSocketAuthError(Exception): + """ + Exception raised by WebSocketDispatcher subclasses to indicate that there is + not a currently-logged-in user able to make a websocket connection. 
+ """ class WebSocketRoot(object): @@ -390,12 +775,16 @@ def default(self): class WebSocketChecker(WebSocketTool): def __init__(self): cherrypy.Tool.__init__(self, 'before_handler', self.upgrade) + self._priority = cherrypy.tools.sessions._priority + 1 # must be initialized after the sessions tool def upgrade(self, **kwargs): try: kwargs['handler_cls'].check_authentication() - except: + except WebSocketAuthError: raise cherrypy.HTTPError(401, 'You must be logged in to establish a websocket connection.') + except: + log.error('unexpected websocket authentication error', exc_info=True) + raise cherrypy.HTTPError(401, 'unexpected authentication error') else: return WebSocketTool.upgrade(self, **kwargs) @@ -408,5 +797,6 @@ def upgrade(self, **kwargs): WebSocketPlugin.start.priority = 66 websocket_plugin.subscribe() +local_broadcaster = Caller(local_broadcast) broadcaster = Caller(WebSocketDispatcher.broadcast) responder = Caller(WebSocketDispatcher.handle_message, threads=config['ws.thread_pool']) diff --git a/test-defaults.ini b/test-defaults.ini new file mode 100644 index 0000000..970fdb0 --- /dev/null +++ b/test-defaults.ini @@ -0,0 +1,30 @@ + +# The settings in this file (test-defaults.ini) can be overridden in test.ini. + +debug = True +ws.auth_required = False + +is_test_running = True + +[cherrypy] +engine.autoreload.on = False +profiling.on = False +server.socket_host = "127.0.0.1" + +# By default the test server runs on port 8282. If you are already using +# port 8282, you'll receive errors like: +# OSError: Port 8282 not free on '127.0.0.1' +# +# You can change this setting in test.ini to either another free port, or +# to port 0 (meaning a free port will be chosen by the OS automatically). +# Using port 0 is mostly safe, but on heavily used systems there is a +# potential race condition if another process uses the same port in the time +# between requesting an available port and actually using it. 
+# +# See https://eklitzke.org/binding-on-port-zero +# +server.socket_port = 8282 + + +[loggers] +root = "DEBUG" diff --git a/test_requirements.txt b/test_requirements.txt new file mode 100644 index 0000000..3c0c982 --- /dev/null +++ b/test_requirements.txt @@ -0,0 +1,5 @@ +pytest>=3.0.1 +mock>=1.0.1,<1.1 +Sphinx>=1.2.1 +coverage>=3.6 +pep8>=1.7.0 diff --git a/tox.ini b/tox.ini index a048275..1712758 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,38 @@ -# content of: tox.ini, put in same dir as setup.py +# QUICK TIPS +# ========== +# +# Run all tests for all environments from the command line: +# $ tox +# +# +# Run all tests for a single environment from the command line: +# $ tox -e pep8 +# or: +# $ tox -e py34 +# +# +# Run only tests that match a substring expression, for a single environment: +# $ tox -e py34 -- -k expression +# +# +# In general, everything after the "--" is passed as arguments to py.test: +# $ tox -- -s -v -k expression +# [tox] -envlist = py27 +envlist=pep8,py27,py33,py34,py35 +skipsdist=True + [testenv] +setenv= + SIDEBOARD_CONFIG_OVERRIDES=test-defaults.ini +deps= + -rrequirements.txt + -rtest_requirements.txt commands= - python setup.py develop - coverage run -m py.test - coverage report + coverage run --source sideboard -m py.test {posargs} sideboard + coverage report --show-missing + +[testenv:pep8] +deps=pep8 +commands= + pep8 sideboard/

Login

Username:
Password: