Skip to content
140 changes: 83 additions & 57 deletions juju/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ssl
import string
import subprocess
import weakref
import websockets
from concurrent.futures import CancelledError
from http.client import HTTPSConnection
Expand Down Expand Up @@ -40,14 +41,15 @@ class Monitor:
"""
ERROR = 'error'
CONNECTED = 'connected'
DISCONNECTING = 'disconnecting'
DISCONNECTED = 'disconnected'
UNKNOWN = 'unknown'

def __init__(self, connection):
    """Set up the state used to track one connection's lifecycle.

    :param connection: the connection to monitor; it must expose a
        ``loop`` attribute, which every primitive below is bound to.
    """
    # Keep only a weak reference so the monitor does not keep the
    # connection alive; callers dereference it with ``self.connection()``.
    self.connection = weakref.ref(connection)
    # Held for the duration of a reconnect; also checked via ``locked()``
    # so that concurrent reconnect requests can be skipped.
    self.reconnecting = asyncio.Lock(loop=connection.loop)
    self.close_called = asyncio.Event(loop=connection.loop)
    self.receiver_stopped = asyncio.Event(loop=connection.loop)
    self.pinger_stopped = asyncio.Event(loop=connection.loop)
    # Neither background task has been started yet, so both begin in
    # the "stopped" state.
    self.receiver_stopped.set()
    self.pinger_stopped.set()

Expand All @@ -63,35 +65,27 @@ def status(self):
isn't usable until that receiver has been started.

"""
connection = self.connection()

# DISCONNECTED: connection not yet open
if not self.connection.ws:
# the connection instance was destroyed but someone kept
# a separate reference to the monitor for some reason
if not connection:
return self.DISCONNECTED
if self.receiver_stopped.is_set():
return self.DISCONNECTED

# ERROR: Connection closed (or errored), but we didn't call
# connection.close
if not self.close_called.is_set() and self.receiver_stopped.is_set():
return self.ERROR
if not self.close_called.is_set() and not self.connection.ws.open:
# The check for self.receiver_stopped existing above guards
# against the case where we're not open because we simply
# haven't setup the connection yet.
return self.ERROR

# DISCONNECTED: cleanly disconnected.
if self.close_called.is_set() and not self.connection.ws.open:
# connection cleanly disconnected or not yet opened
if not connection.ws:
return self.DISCONNECTED

# CONNECTED: everything is fine!
if self.connection.ws.open:
return self.CONNECTED
# close called but not yet complete
if self.close_called.is_set():
return self.DISCONNECTING

# connection closed uncleanly (we didn't call connection.close)
if self.receiver_stopped.is_set() or not connection.ws.open:
return self.ERROR

# UNKNOWN: We should never hit this state -- if we do,
# something went wrong with the logic above, and we do not
# know what state the connection is in.
return self.UNKNOWN
# everything is fine!
return self.CONNECTED


class Connection:
Expand Down Expand Up @@ -120,6 +114,7 @@ def __init__(
self, endpoint, uuid, username, password, cacert=None,
macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
self.endpoint = endpoint
self._endpoint = endpoint
self.uuid = uuid
if macaroons:
self.macaroons = macaroons
Expand All @@ -130,6 +125,7 @@ def __init__(
self.username = username
self.password = password
self.cacert = cacert
self._cacert = cacert
self.loop = loop or asyncio.get_event_loop()

self.__request_id__ = 0
Expand All @@ -144,9 +140,7 @@ def __init__(

@property
def is_open(self):
    """Whether the connection is currently usable.

    Delegates to the monitor: the result is ``True`` only when the
    monitor's status is ``Monitor.CONNECTED``.
    """
    return self.monitor.status == Monitor.CONNECTED

def _get_ssl(self, cert=None):
return ssl.create_default_context(
Expand All @@ -171,12 +165,13 @@ async def open(self):
return self

async def close(self):
    """Shut the connection down and release the websocket.

    Does nothing when there is no websocket. Otherwise it signals the
    close request, waits for the pinger and receiver to report that they
    have stopped, closes the websocket and clears ``self.ws``.
    """
    # Guard on the websocket itself rather than on ``is_open`` so that a
    # websocket which is no longer open is still closed and cleared.
    if not self.ws:
        return
    self.monitor.close_called.set()
    await self.monitor.pinger_stopped.wait()
    await self.monitor.receiver_stopped.wait()
    await self.ws.close()
    # Drop the reference so a repeated close() is a no-op.
    self.ws = None

async def recv(self, request_id):
if not self.is_open:
Expand All @@ -197,13 +192,18 @@ async def receiver(self):
await self.messages.put(result['request-id'], result)
except CancelledError:
pass
except Exception as e:
except websockets.ConnectionClosed as e:
log.warning('Receiver: Connection closed, reconnecting')
await self.messages.put_all(e)
if isinstance(e, websockets.ConnectionClosed):
# ConnectionClosed is not really exceptional for us,
# but it may be for any pending message listeners
return
# the reconnect has to be done as a task because the receiver will
# be cancelled by the reconnect and we don't want the reconnect
# to be aborted half-way through
self.loop.create_task(self.reconnect())
return
except Exception as e:
log.exception("Error in receiver")
# make pending listeners aware of the error
await self.messages.put_all(e)
raise
finally:
self.monitor.receiver_stopped.set()
Expand All @@ -225,7 +225,7 @@ async def _do_ping():

pinger_facade = client.PingerFacade.from_connection(self)
try:
while self.is_open:
while True:
await utils.run_with_interrupt(
_do_ping(),
self.monitor.close_called,
Expand All @@ -234,6 +234,7 @@ async def _do_ping():
break
finally:
self.monitor.pinger_stopped.set()
return

async def rpc(self, msg, encoder=None):
self.__request_id__ += 1
Expand All @@ -243,7 +244,19 @@ async def rpc(self, msg, encoder=None):
if "version" not in msg:
msg['version'] = self.facades[msg['type']]
outgoing = json.dumps(msg, indent=2, cls=encoder)
await self.ws.send(outgoing)
for attempt in range(3):
try:
await self.ws.send(outgoing)
break
except websockets.ConnectionClosed:
if attempt == 2:
raise
log.warning('RPC: Connection closed, reconnecting')
# the reconnect has to be done in a separate task because,
# if it is triggered by the pinger, then this RPC call will
# be cancelled when the pinger is cancelled by the reconnect,
# and we don't want the reconnect to be aborted halfway through
await asyncio.wait([self.reconnect()], loop=self.loop)
result = await self.recv(msg['request-id'])

if not result:
Expand Down Expand Up @@ -379,36 +392,49 @@ async def _try_endpoint(self, endpoint, cacert):
await self.close()
return success, result, new_endpoints

async def reconnect(self):
    """Force a reconnection.

    Returns without doing anything when a reconnect is already in
    progress or when a close has been requested; otherwise closes the
    current connection and connects again while holding the monitor's
    ``reconnecting`` lock.
    """
    monitor = self.monitor
    # Skip rather than queue: one reconnect at a time is enough, and a
    # connection that is being closed on request must not be reopened.
    if monitor.reconnecting.locked() or monitor.close_called.is_set():
        return
    async with monitor.reconnecting:
        await self.close()
        await self._connect()

async def _connect(self):
    """Log in, trying candidate endpoints until one succeeds.

    Starts with the endpoint and CA cert stored on the instance and
    appends any further endpoints reported by a failed attempt. On
    success it stores the login response in ``self.info``, builds the
    facades, schedules the pinger task and clears the monitor's
    ``pinger_stopped`` event.

    :raises Exception: if every candidate endpoint has been tried
        without a successful login.
    """
    endpoints = [(self._endpoint, self._cacert)]
    while endpoints:
        _endpoint, _cacert = endpoints.pop(0)
        success, result, new_endpoints = await self._try_endpoint(
            _endpoint, _cacert)
        if success:
            break
        endpoints.extend(new_endpoints)
    else:
        # ran out of endpoints without a successful login
        raise Exception("Couldn't authenticate to {}".format(
            self._endpoint))

    response = result['response']
    self.info = response.copy()
    self.build_facades(response.get('facades', {}))
    self.loop.create_task(self.pinger())
    self.monitor.pinger_stopped.clear()

@classmethod
async def connect(
        cls, endpoint, uuid, username, password, cacert=None,
        macaroons=None, loop=None, max_frame_size=None):
    """Create a connection object and connect it to the websocket.

    A ``uuid`` of None targets the controller; any other value targets
    the model with that UUID.

    :return: the newly created instance, after its ``_connect()`` has
        completed.
    """
    conn = cls(
        endpoint, uuid, username, password, cacert, macaroons, loop,
        max_frame_size)
    await conn._connect()
    return conn

@classmethod
Expand Down
12 changes: 10 additions & 2 deletions juju/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@ def abort():
task.cancel()
run._sigint = True

loop.add_signal_handler(signal.SIGINT, abort)
added = False
try:
loop.add_signal_handler(signal.SIGINT, abort)
added = True
except ValueError as e:
# add_signal_handler doesn't work in a thread
if 'main thread' not in str(e):
raise
try:
for step in steps:
task = loop.create_task(step)
Expand All @@ -31,4 +38,5 @@ def abort():
raise task.exception()
return task.result()
finally:
loop.remove_signal_handler(signal.SIGINT)
if added:
loop.remove_signal_handler(signal.SIGINT)
44 changes: 38 additions & 6 deletions juju/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from functools import partial
from pathlib import Path

import websockets
import yaml
import theblues.charmstore
import theblues.errors
Expand Down Expand Up @@ -550,6 +551,8 @@ async def block_until(self, *conditions, timeout=None, wait_period=0.5):
"""
async def _block():
    # Poll until every condition holds; fail fast instead of spinning
    # until the timeout when the connection is gone or no longer open.
    while True:
        if all(c() for c in conditions):
            return
        if not self.connection or not self.connection.is_open:
            raise websockets.ConnectionClosed(1006, 'no reason')
        await asyncio.sleep(wait_period, loop=self.loop)
await asyncio.wait_for(_block(), timeout, loop=self.loop)

Expand Down Expand Up @@ -643,16 +646,45 @@ def _watch(self):
See :meth:`add_observer` to register an onchange callback.

"""
async def _start_watch():
async def _all_watcher():
try:
allwatcher = client.AllWatcherFacade.from_connection(
self.connection)
while not self._watch_stopping.is_set():
results = await utils.run_with_interrupt(
allwatcher.Next(),
self._watch_stopping,
self.loop)
try:
results = await utils.run_with_interrupt(
allwatcher.Next(),
self._watch_stopping,
self.loop)
except JujuAPIError as e:
if 'watcher was stopped' not in str(e):
raise
if self._watch_stopping.is_set():
# this shouldn't ever actually happen, because
# the event should trigger before the controller
# has a chance to tell us the watcher is stopped
# but handle it gracefully, just in case
break
# controller stopped our watcher for some reason
# but we're not actually stopping, so just restart it
log.warning(
'Watcher: watcher stopped, restarting')
del allwatcher.Id
continue
except websockets.ConnectionClosed:
monitor = self.connection.monitor
if monitor.status == monitor.ERROR:
# closed unexpectedly, try to reopen
log.warning(
'Watcher: connection closed, reopening')
await self.connection.reconnect()
del allwatcher.Id
continue
else:
# closed on request, go ahead and shutdown
break
if self._watch_stopping.is_set():
await allwatcher.Stop()
break
for delta in results.deltas:
delta = get_entity_delta(delta)
Expand All @@ -671,7 +703,7 @@ async def _start_watch():
self._watch_received.clear()
self._watch_stopping.clear()
self._watch_stopped.clear()
self.loop.create_task(_start_watch())
self.loop.create_task(_all_watcher())

async def _notify_observers(self, delta, old_obj, new_obj):
"""Call observing callbacks, notifying them of a change in model state
Expand Down
5 changes: 4 additions & 1 deletion tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ async def __aenter__(self):
model_name = 'model-{}'.format(uuid.uuid4())
self.model = await self.controller.add_model(model_name)

# save the model UUID in case test closes model
self.model_uuid = self.model.info.uuid

# Ensure that we connect to the new model by default. This also
# prevents failures if test was started with no current model.
self._patch_cm = mock.patch.object(JujuData, 'current_model',
Expand All @@ -55,7 +58,7 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc, tb):
    """Tear down the temporary model and the controller connection.

    Stops the ``current_model`` patch, disconnects the model, destroys
    the model on the controller and finally disconnects the controller.
    """
    self._patch_cm.stop()
    await self.model.disconnect()
    # Use the UUID saved in __aenter__ rather than self.model.info, so
    # the model is destroyed exactly once even if the test closed it.
    await self.controller.destroy_model(self.model_uuid)
    await self.controller.disconnect()


Expand Down
Loading