diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py index 4f6b67fe123..14832270f10 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/transport/TTransport.py @@ -379,13 +379,19 @@ def open(self): self.transport.open() self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii')) - self.send_sasl_msg(self.OK, self.sasl.process()) + initial_response = self.sasl.process() + self.send_sasl_msg(self.OK, initial_response if initial_response is not None else b'') while True: status, challenge = self.recv_sasl_msg() if status == self.OK: self.send_sasl_msg(self.OK, self.sasl.process(challenge)) elif status == self.COMPLETE: + if challenge: + # Process server's final challenge (e.g. DIGEST-MD5 rspauth + # verification). Return value intentionally unused; puresasl + # raises on verification failure. + self.sasl.process(challenge) if not self.sasl.complete: raise TTransportException( TTransportException.NOT_OPEN, @@ -403,6 +409,8 @@ def isOpen(self): return self.transport.isOpen() def send_sasl_msg(self, status, body): + if body is None: + body = b'' header = pack(">BI", status, len(body)) self.transport.write(header + body) self.transport.flush() @@ -413,7 +421,7 @@ def recv_sasl_msg(self): if length > 0: payload = self.transport.readAll(length) else: - payload = "" + payload = b"" return status, payload def write(self, data): diff --git a/lib/py/test/test_sasl_transport.py b/lib/py/test/test_sasl_transport.py new file mode 100644 index 00000000000..d71949b5995 --- /dev/null +++ b/lib/py/test/test_sasl_transport.py @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +import sys +import types +import unittest +from unittest.mock import MagicMock, PropertyMock, call + +# Register 'thrift' as a package alias for the src directory so that +# tests can run without a build step. This mirrors setup.py's +# package_dir={'thrift': 'src'} configuration. +_src_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', 'src')) +if 'thrift' not in sys.modules: + _thrift_pkg = types.ModuleType('thrift') + _thrift_pkg.__path__ = [_src_dir] + _thrift_pkg.__package__ = 'thrift' + sys.modules['thrift'] = _thrift_pkg + +# Stub puresasl so TSaslClientTransport can be imported without it installed. +sys.modules.setdefault('puresasl', types.ModuleType('puresasl')) +sys.modules.setdefault('puresasl.client', types.ModuleType('puresasl.client')) + +from thrift.transport.TTransport import TSaslClientTransport +from thrift.transport.TTransport import TTransportException + + +class TSaslClientTransportTest(unittest.TestCase): + + def _make_transport(self, process_side_effect, recv_messages, complete_value=True): + transport = object.__new__(TSaslClientTransport) + + mock_inner = MagicMock() + mock_inner.isOpen.return_value = True + transport.transport = mock_inner + + mock_sasl = MagicMock() + mock_sasl.mechanism = 'DIGEST-MD5' + mock_sasl.process.side_effect = process_side_effect + type(mock_sasl).complete = PropertyMock(return_value=complete_value) + transport.sasl = mock_sasl + + transport.send_sasl_msg = MagicMock() + transport.recv_sasl_msg = MagicMock(side_effect=recv_messages) + + return transport, mock_sasl + + def test_open_with_none_initial_response(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[None, b'response'], + recv_messages=[ + (TSaslClientTransport.OK, b'server-challenge'), + (TSaslClientTransport.COMPLETE, b''), + ], + ) + + transport.open() + + transport.send_sasl_msg.assert_any_call( + TSaslClientTransport.START, b'DIGEST-MD5' + ) + transport.send_sasl_msg.assert_any_call(TSaslClientTransport.OK, b'') + mock_sasl.process.assert_any_call(b'server-challenge') + + def test_open_with_bytes_initial_response(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[b'initial-token'], + recv_messages=[ + (TSaslClientTransport.COMPLETE, b''), + ], + ) + + transport.open() + + transport.send_sasl_msg.assert_any_call( + TSaslClientTransport.OK, b'initial-token' + ) + + def test_open_complete_with_challenge(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[b'initial', b'response', None], + recv_messages=[ + (TSaslClientTransport.OK, b'challenge1'), + (TSaslClientTransport.COMPLETE, b'rspauth=abc123'), + ], + ) + + transport.open() + + mock_sasl.process.assert_any_call(b'rspauth=abc123') + + def test_open_complete_without_challenge(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[b'initial'], + recv_messages=[ + (TSaslClientTransport.COMPLETE, b''), + ], + ) + + transport.open() + + process_calls = mock_sasl.process.call_args_list + self.assertNotIn(call(b''), process_calls) + + def test_open_bad_status_raises(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[b'initial'], + recv_messages=[ + (0xFF, b'error message'), + ], + ) + + with self.assertRaises(TTransportException) as ctx: + transport.open() + self.assertIn('Bad SASL negotiation status', str(ctx.exception)) + + def test_open_incomplete_after_complete_status_raises(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[b'initial'], + recv_messages=[ + (TSaslClientTransport.COMPLETE, b''), + ], + complete_value=False, + ) + + with self.assertRaises(TTransportException) as ctx: + transport.open() + self.assertIn('erroneously indicated', str(ctx.exception)) + + def test_open_process_raises_during_complete(self): + transport, mock_sasl = self._make_transport( + process_side_effect=[b'initial', Exception('rspauth verification failed')], + recv_messages=[ + (TSaslClientTransport.COMPLETE, b'rspauth=bad'), + ], + ) + + with self.assertRaises(Exception) as ctx: + transport.open() + self.assertIn('rspauth verification failed', str(ctx.exception)) + + +if __name__ == '__main__': + unittest.main()