diff --git a/juju/client/connection.py b/juju/client/connection.py index 6f2f2a2a0..745739187 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -8,6 +8,7 @@ import ssl import string import subprocess +import weakref import websockets from concurrent.futures import CancelledError from http.client import HTTPSConnection @@ -40,14 +41,15 @@ class Monitor: """ ERROR = 'error' CONNECTED = 'connected' + DISCONNECTING = 'disconnecting' DISCONNECTED = 'disconnected' - UNKNOWN = 'unknown' def __init__(self, connection): - self.connection = connection - self.close_called = asyncio.Event(loop=self.connection.loop) - self.receiver_stopped = asyncio.Event(loop=self.connection.loop) - self.pinger_stopped = asyncio.Event(loop=self.connection.loop) + self.connection = weakref.ref(connection) + self.reconnecting = asyncio.Lock(loop=connection.loop) + self.close_called = asyncio.Event(loop=connection.loop) + self.receiver_stopped = asyncio.Event(loop=connection.loop) + self.pinger_stopped = asyncio.Event(loop=connection.loop) self.receiver_stopped.set() self.pinger_stopped.set() @@ -63,35 +65,27 @@ def status(self): isn't usable until that receiver has been started. """ + connection = self.connection() - # DISCONNECTED: connection not yet open - if not self.connection.ws: + # the connection instance was destroyed but someone kept + # a separate reference to the monitor for some reason + if not connection: return self.DISCONNECTED - if self.receiver_stopped.is_set(): - return self.DISCONNECTED - - # ERROR: Connection closed (or errored), but we didn't call - # connection.close - if not self.close_called.is_set() and self.receiver_stopped.is_set(): - return self.ERROR - if not self.close_called.is_set() and not self.connection.ws.open: - # The check for self.receiver_stopped existing above guards - # against the case where we're not open because we simply - # haven't setup the connection yet. - return self.ERROR - # DISCONNECTED: cleanly disconnected. - if self.close_called.is_set() and not self.connection.ws.open: + # connection cleanly disconnected or not yet opened + if not connection.ws: return self.DISCONNECTED - # CONNECTED: everything is fine! - if self.connection.ws.open: - return self.CONNECTED + # close called but not yet complete + if self.close_called.is_set(): + return self.DISCONNECTING + + # connection closed uncleanly (we didn't call connection.close) + if self.receiver_stopped.is_set() or not connection.ws.open: + return self.ERROR - # UNKNOWN: We should never hit this state -- if we do, - # something went wrong with the logic above, and we do not - # know what state the connection is in. - return self.UNKNOWN + # everything is fine! + return self.CONNECTED class Connection: @@ -120,6 +114,7 @@ def __init__( self, endpoint, uuid, username, password, cacert=None, macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE): self.endpoint = endpoint + self._endpoint = endpoint self.uuid = uuid if macaroons: self.macaroons = macaroons @@ -130,6 +125,7 @@ def __init__( self.username = username self.password = password self.cacert = cacert + self._cacert = cacert self.loop = loop or asyncio.get_event_loop() self.__request_id__ = 0 @@ -144,9 +140,7 @@ def __init__( @property def is_open(self): - if self.ws: - return self.ws.open - return False + return self.monitor.status == Monitor.CONNECTED def _get_ssl(self, cert=None): return ssl.create_default_context( @@ -171,12 +165,13 @@ async def open(self): return self async def close(self): - if not self.is_open: + if not self.ws: return self.monitor.close_called.set() await self.monitor.pinger_stopped.wait() await self.monitor.receiver_stopped.wait() await self.ws.close() + self.ws = None async def recv(self, request_id): if not self.is_open: @@ -197,13 +192,18 @@ async def receiver(self): await self.messages.put(result['request-id'], result) except CancelledError: pass - except Exception as e: + except websockets.ConnectionClosed as e: + log.warning('Receiver: Connection closed, reconnecting') await self.messages.put_all(e) - if isinstance(e, websockets.ConnectionClosed): - # ConnectionClosed is not really exceptional for us, - # but it may be for any pending message listeners - return + # the reconnect has to be done as a task because the receiver will + # be cancelled by the reconnect and we don't want the reconnect + # to be aborted half-way through + self.loop.create_task(self.reconnect()) + return + except Exception as e: log.exception("Error in receiver") + # make pending listeners aware of the error + await self.messages.put_all(e) raise finally: self.monitor.receiver_stopped.set() @@ -225,7 +225,7 @@ async def _do_ping(): pinger_facade = client.PingerFacade.from_connection(self) try: - while self.is_open: + while True: await utils.run_with_interrupt( _do_ping(), self.monitor.close_called, @@ -234,6 +234,7 @@ async def _do_ping(): break finally: self.monitor.pinger_stopped.set() + return async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -243,7 +244,19 @@ async def rpc(self, msg, encoder=None): if "version" not in msg: msg['version'] = self.facades[msg['type']] outgoing = json.dumps(msg, indent=2, cls=encoder) - await self.ws.send(outgoing) + for attempt in range(3): + try: + await self.ws.send(outgoing) + break + except websockets.ConnectionClosed: + if attempt == 2: + raise + log.warning('RPC: Connection closed, reconnecting') + # the reconnect has to be done in a separate task because, + # if it is triggered by the pinger, then this RPC call will + # be cancelled when the pinger is cancelled by the reconnect, + # and we don't want the reconnect to be aborted halfway through + await asyncio.wait([self.reconnect()], loop=self.loop) result = await self.recv(msg['request-id']) if not result: @@ -379,36 +392,49 @@ async def _try_endpoint(self, endpoint, cacert): await self.close() return success, result, new_endpoints - @classmethod - async def connect( - cls, endpoint, uuid, username, password, cacert=None, - macaroons=None, loop=None, max_frame_size=None): - """Connect to the websocket. - - If uuid is None, the connection will be to the controller. Otherwise it - will be to the model. - + async def reconnect(self): + """ Force a reconnection. """ - client = cls(endpoint, uuid, username, password, cacert, macaroons, - loop, max_frame_size) - endpoints = [(endpoint, cacert)] + monitor = self.monitor + if monitor.reconnecting.locked() or monitor.close_called.is_set(): + return + async with monitor.reconnecting: + await self.close() + await self._connect() + + async def _connect(self): + endpoints = [(self._endpoint, self._cacert)] while endpoints: _endpoint, _cacert = endpoints.pop(0) - success, result, new_endpoints = await client._try_endpoint( + success, result, new_endpoints = await self._try_endpoint( _endpoint, _cacert) if success: break endpoints.extend(new_endpoints) else: # ran out of endpoints without a successful login - raise Exception("Couldn't authenticate to {}".format(endpoint)) + raise Exception("Couldn't authenticate to {}".format( + self._endpoint)) response = result['response'] - client.info = response.copy() - client.build_facades(response.get('facades', {})) - client.loop.create_task(client.pinger()) - client.monitor.pinger_stopped.clear() + self.info = response.copy() + self.build_facades(response.get('facades', {})) + self.loop.create_task(self.pinger()) + self.monitor.pinger_stopped.clear() + + @classmethod + async def connect( + cls, endpoint, uuid, username, password, cacert=None, + macaroons=None, loop=None, max_frame_size=None): + """Connect to the websocket. + + If uuid is None, the connection will be to the controller. Otherwise it + will be to the model. + """ + client = cls(endpoint, uuid, username, password, cacert, macaroons, + loop, max_frame_size) + await client._connect() return client @classmethod diff --git a/juju/loop.py b/juju/loop.py index 3720159df..4abedfcc3 100644 --- a/juju/loop.py +++ b/juju/loop.py @@ -20,7 +20,14 @@ def abort(): task.cancel() run._sigint = True - loop.add_signal_handler(signal.SIGINT, abort) + added = False + try: + loop.add_signal_handler(signal.SIGINT, abort) + added = True + except ValueError as e: + # add_signal_handler doesn't work in a thread + if 'main thread' not in str(e): + raise try: for step in steps: task = loop.create_task(step) @@ -31,4 +38,5 @@ def abort(): raise task.exception() return task.result() finally: - loop.remove_signal_handler(signal.SIGINT) + if added: + loop.remove_signal_handler(signal.SIGINT) diff --git a/juju/model.py b/juju/model.py index 61905c921..7b86ba3b5 100644 --- a/juju/model.py +++ b/juju/model.py @@ -14,6 +14,7 @@ from functools import partial from pathlib import Path +import websockets import yaml import theblues.charmstore import theblues.errors @@ -550,6 +551,8 @@ async def block_until(self, *conditions, timeout=None, wait_period=0.5): """ async def _block(): while not all(c() for c in conditions): + if not (self.connection and self.connection.is_open): + raise websockets.ConnectionClosed(1006, 'no reason') await asyncio.sleep(wait_period, loop=self.loop) await asyncio.wait_for(_block(), timeout, loop=self.loop) @@ -643,16 +646,45 @@ def _watch(self): See :meth:`add_observer` to register an onchange callback. """ - async def _start_watch(): + async def _all_watcher(): try: allwatcher = client.AllWatcherFacade.from_connection( self.connection) while not self._watch_stopping.is_set(): - results = await utils.run_with_interrupt( - allwatcher.Next(), - self._watch_stopping, - self.loop) + try: + results = await utils.run_with_interrupt( + allwatcher.Next(), + self._watch_stopping, + self.loop) + except JujuAPIError as e: + if 'watcher was stopped' not in str(e): + raise + if self._watch_stopping.is_set(): + # this shouldn't ever actually happen, because + # the event should trigger before the controller + # has a chance to tell us the watcher is stopped + # but handle it gracefully, just in case + break + # controller stopped our watcher for some reason + # but we're not actually stopping, so just restart it + log.warning( + 'Watcher: watcher stopped, restarting') + del allwatcher.Id + continue + except websockets.ConnectionClosed: + monitor = self.connection.monitor + if monitor.status == monitor.ERROR: + # closed unexpectedly, try to reopen + log.warning( + 'Watcher: connection closed, reopening') + await self.connection.reconnect() + del allwatcher.Id + continue + else: + # closed on request, go ahead and shutdown + break if self._watch_stopping.is_set(): + await allwatcher.Stop() break for delta in results.deltas: delta = get_entity_delta(delta) @@ -671,7 +703,7 @@ async def _start_watch(): self._watch_received.clear() self._watch_stopping.clear() self._watch_stopped.clear() - self.loop.create_task(_start_watch()) + self.loop.create_task(_all_watcher()) async def _notify_observers(self, delta, old_obj, new_obj): """Call observing callbacks, notifying them of a change in model state diff --git a/tests/base.py b/tests/base.py index 8ea51092d..e1ec45238 100644 --- a/tests/base.py +++ b/tests/base.py @@ -44,6 +44,9 @@ async def __aenter__(self): model_name = 'model-{}'.format(uuid.uuid4()) self.model = await self.controller.add_model(model_name) + # save the model UUID in case test closes model + self.model_uuid = self.model.info.uuid + # Ensure that we connect to the new model by default. This also # prevents failures if test was started with no current model. self._patch_cm = mock.patch.object(JujuData, 'current_model', @@ -55,7 +58,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc, tb): self._patch_cm.stop() await self.model.disconnect() - await self.controller.destroy_model(self.model.info.uuid) + await self.controller.destroy_model(self.model_uuid) await self.controller.disconnect() diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index 67dfb2e3c..290203d47 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -1,3 +1,4 @@ +import asyncio import pytest from juju.client.connection import Connection @@ -47,7 +48,7 @@ async def test_monitor_catches_error(event_loop): @pytest.mark.asyncio async def test_full_status(event_loop): async with base.CleanModel() as model: - app = await model.deploy( + await model.deploy( 'ubuntu-0', application_name='ubuntu', series='trusty', @@ -56,4 +57,27 @@ async def test_full_status(event_loop): c = client.ClientFacade.from_connection(model.connection) - status = await c.FullStatus(None) + await c.FullStatus(None) + + +@base.bootstrapped +@pytest.mark.asyncio +async def test_reconnect(event_loop): + async with base.CleanModel() as model: + conn = await Connection.connect( + model.connection.endpoint, + model.connection.uuid, + model.connection.username, + model.connection.password, + model.connection.cacert, + model.connection.macaroons, + model.connection.loop, + model.connection.max_frame_size) + try: + await asyncio.sleep(0.1) + assert conn.is_open + await conn.ws.close() + assert not conn.is_open + await model.block_until(lambda: conn.is_open, timeout=3) + finally: + await conn.close() diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 088dcd576..37f51c0dc 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -212,6 +212,15 @@ async def test_get_machines(event_loop): assert isinstance(result, list) +@base.bootstrapped +@pytest.mark.asyncio +async def test_watcher_reconnect(event_loop): + async with base.CleanModel() as model: + await model.connection.ws.close() + await asyncio.sleep(0.1) + assert model.connection.is_open + + # @base.bootstrapped # @pytest.mark.asyncio # async def test_grant(event_loop) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 340264ead..f69b8d6bc 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -1,3 +1,4 @@ +import asyncio import json import mock import pytest @@ -20,6 +21,7 @@ async def send(self, message): async def recv(self): if not self.responses: + await asyncio.sleep(1) # delay to give test time to finish raise ConnectionClosed(0, 'ran out of responses') return json.dumps(self.responses.popleft())