Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions jsonrpc/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
log = logging.getLogger(__name__)


class _StreamError(Exception):
"""Raised on stream errors."""


class JsonRpcStreamReader(object):

def __init__(self, rfile):
Expand All @@ -21,9 +25,10 @@ def listen(self, message_consumer):
message_consumer (fn): function that is passed each message as it is read off the socket.
"""
while not self._rfile.closed:
request_str = self._read_message()

if request_str is None:
try:
request_str = self._read_message()
except _StreamError:
log.exception("Failed to read message.")
break

try:
Expand All @@ -36,37 +41,52 @@ def _read_message(self):
"""Reads the contents of a message.

Returns:
body of message if parsable else None
body of message

Raises:
_StreamError: If message was not parsable.
"""
line = self._rfile.readline()
# Read the headers
headers = self._read_headers()

if not line:
return None
try:
content_length = int(headers[b"Content-Length"])
except (ValueError, KeyError):
raise _StreamError("Invalid or missing Content-Length headers: {}".format(headers))

content_length = self._content_length(line)
# Grab the body
body = self._rfile.read(content_length)
if not body:
raise _StreamError("Got EOF when reading from stream")

# Blindly consume all header lines
while line and line.strip():
line = self._rfile.readline()
return body

if not line:
return None
def _read_headers(self):
"""Read the headers from a LSP base message.

Returns:
dict: A dict containing the headers and their values.

Raises:
_StreamError: If headers are not parsable.
"""
headers = {}
while True:
line = self._rfile.readline()
if not line:
raise _StreamError("Got EOF when reading from stream")
if not line.strip():
# Finished reading headers break while loop
break

# Grab the body
return self._rfile.read(content_length)

@staticmethod
def _content_length(line):
"""Extract the content length from an input line."""
if line.startswith(b'Content-Length: '):
_, value = line.split(b'Content-Length: ')
value = value.strip()
try:
return int(value)
key, value = line.split(b":")
except ValueError:
raise ValueError("Invalid Content-Length header: {}".format(value))
raise _StreamError("Invalid header {}: ".format(line))

headers[key.strip()] = value.strip()

return None
return headers


class JsonRpcStreamWriter(object):
Expand Down
42 changes: 26 additions & 16 deletions test/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,22 @@ def writer(wfile):
return JsonRpcStreamWriter(wfile, sort_keys=True)


def test_reader(rfile, reader):
rfile.write(
@pytest.mark.parametrize("data", [
(
b'Content-Length: 49\r\n'
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'\r\n'
b'{"id": "hello", "method": "method", "params": {}}'
)
),
(
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'Content-Length: 49\r\n'
b'\r\n'
b'{"id": "hello", "method": "method", "params": {}}'
),
], ids=["Content-Length first", "Content-Length middle"])
def test_reader(rfile, reader, data):
rfile.write(data)
rfile.seek(0)

consumer = mock.Mock()
Expand All @@ -46,23 +55,24 @@ def test_reader(rfile, reader):
})


def test_reader_bad_message(rfile, reader):
rfile.write(b'Hello world')
rfile.seek(0)

# Ensure the listener doesn't throw
consumer = mock.Mock()
reader.listen(consumer)
consumer.assert_not_called()


def test_reader_bad_json(rfile, reader):
rfile.write(
@pytest.mark.parametrize("data", [
(
b'hello'
),
(
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'Content-Length: NOT_AN_INT\r\n'
b'\r\n'
),
(
b'Content-Length: 8\r\n'
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
b'\r\n'
b'{hello}}'
)
),
], ids=["hello", "Invalid Content-Length", "Bad json"])
def test_reader_bad_message(rfile, reader, data):
rfile.write(data)
rfile.seek(0)

# Ensure the listener doesn't throw
Expand Down