Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions lib/py/src/transport/TTransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,19 @@ def open(self):
self.transport.open()

self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii'))
self.send_sasl_msg(self.OK, self.sasl.process())
initial_response = self.sasl.process()
self.send_sasl_msg(self.OK, initial_response if initial_response is not None else b'')

while True:
status, challenge = self.recv_sasl_msg()
if status == self.OK:
self.send_sasl_msg(self.OK, self.sasl.process(challenge))
elif status == self.COMPLETE:
if challenge:
# Process server's final challenge (e.g. DIGEST-MD5 rspauth
# verification). Return value intentionally unused; puresasl
# raises on verification failure.
self.sasl.process(challenge)
if not self.sasl.complete:
raise TTransportException(
TTransportException.NOT_OPEN,
Expand All @@ -403,6 +409,8 @@ def isOpen(self):
return self.transport.isOpen()

def send_sasl_msg(self, status, body):
if body is None:
body = b''
header = pack(">BI", status, len(body))
self.transport.write(header + body)
self.transport.flush()
Expand All @@ -413,7 +421,7 @@ def recv_sasl_msg(self):
if length > 0:
payload = self.transport.readAll(length)
else:
payload = ""
payload = b""
return status, payload

def write(self, data):
Expand Down
160 changes: 160 additions & 0 deletions lib/py/test/test_sasl_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import os
import sys
import types
import unittest
from unittest.mock import MagicMock, PropertyMock, call

# Register 'thrift' as a package alias for the src directory so that
# tests can run without a build step. This mirrors setup.py's
# package_dir={'thrift': 'src'} configuration.
_src_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', 'src'))
if 'thrift' not in sys.modules:
_thrift_pkg = types.ModuleType('thrift')
_thrift_pkg.__path__ = [_src_dir]
_thrift_pkg.__package__ = 'thrift'
sys.modules['thrift'] = _thrift_pkg

# Stub puresasl so TSaslClientTransport can be imported without it installed.
sys.modules.setdefault('puresasl', types.ModuleType('puresasl'))
sys.modules.setdefault('puresasl.client', types.ModuleType('puresasl.client'))

from thrift.transport.TTransport import TSaslClientTransport
from thrift.transport.TTransport import TTransportException


class TSaslClientTransportTest(unittest.TestCase):

def _make_transport(self, process_side_effect, recv_messages, complete_value=True):
transport = object.__new__(TSaslClientTransport)

mock_inner = MagicMock()
mock_inner.isOpen.return_value = True
transport.transport = mock_inner

mock_sasl = MagicMock()
mock_sasl.mechanism = 'DIGEST-MD5'
mock_sasl.process.side_effect = process_side_effect
type(mock_sasl).complete = PropertyMock(return_value=complete_value)
transport.sasl = mock_sasl

transport.send_sasl_msg = MagicMock()
transport.recv_sasl_msg = MagicMock(side_effect=recv_messages)

return transport, mock_sasl

def test_open_with_none_initial_response(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[None, b'response'],
recv_messages=[
(TSaslClientTransport.OK, b'server-challenge'),
(TSaslClientTransport.COMPLETE, b''),
],
)

transport.open()

transport.send_sasl_msg.assert_any_call(
TSaslClientTransport.START, b'DIGEST-MD5'
)
transport.send_sasl_msg.assert_any_call(TSaslClientTransport.OK, b'')
mock_sasl.process.assert_any_call(b'server-challenge')

def test_open_with_bytes_initial_response(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[b'initial-token'],
recv_messages=[
(TSaslClientTransport.COMPLETE, b''),
],
)

transport.open()

transport.send_sasl_msg.assert_any_call(
TSaslClientTransport.OK, b'initial-token'
)

def test_open_complete_with_challenge(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[b'initial', b'response', None],
recv_messages=[
(TSaslClientTransport.OK, b'challenge1'),
(TSaslClientTransport.COMPLETE, b'rspauth=abc123'),
],
)

transport.open()

mock_sasl.process.assert_any_call(b'rspauth=abc123')

def test_open_complete_without_challenge(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[b'initial'],
recv_messages=[
(TSaslClientTransport.COMPLETE, b''),
],
)

transport.open()

process_calls = mock_sasl.process.call_args_list
self.assertNotIn(call(b''), process_calls)

def test_open_bad_status_raises(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[b'initial'],
recv_messages=[
(0xFF, b'error message'),
],
)

with self.assertRaises(TTransportException) as ctx:
transport.open()
self.assertIn('Bad SASL negotiation status', str(ctx.exception))

def test_open_incomplete_after_complete_status_raises(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[b'initial'],
recv_messages=[
(TSaslClientTransport.COMPLETE, b''),
],
complete_value=False,
)

with self.assertRaises(TTransportException) as ctx:
transport.open()
self.assertIn('erroneously indicated', str(ctx.exception))

def test_open_process_raises_during_complete(self):
transport, mock_sasl = self._make_transport(
process_side_effect=[b'initial', Exception('rspauth verification failed')],
recv_messages=[
(TSaslClientTransport.COMPLETE, b'rspauth=bad'),
],
)

with self.assertRaises(Exception) as ctx:
transport.open()
self.assertIn('rspauth verification failed', str(ctx.exception))


if __name__ == '__main__':
unittest.main()
Loading