diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst index 7b761d7ec58d0d..9cc9a6869fc9ca 100644 --- a/Doc/library/socket.rst +++ b/Doc/library/socket.rst @@ -636,6 +636,23 @@ correspond to Unix system calls applicable to sockets. to decode C structures encoded as strings). +.. method:: socket.getblocking() + + Return ``True`` if socket is in blocking mode, ``False`` if in + non-blocking. + + This is equivalent to checking ``socket.gettimeout() == 0``. + + .. versionadded:: 3.7 + + +.. method:: socket.gettimeout() + + Return the timeout in seconds (float) associated with socket operations, + or ``None`` if no timeout is set. This reflects the last call to + :meth:`setblocking` or :meth:`settimeout`. + + .. method:: socket.ioctl(control, option) :platform: Windows diff --git a/Include/Python.h b/Include/Python.h index a9327b02c81f09..384c86a745a0df 100644 --- a/Include/Python.h +++ b/Include/Python.h @@ -78,6 +78,7 @@ #error "PYMALLOC_DEBUG requires WITH_PYMALLOC" #endif #include "pymath.h" +#include "pytime.h" #include "pymem.h" #include "object.h" diff --git a/Include/pytime.h b/Include/pytime.h new file mode 100644 index 00000000000000..bdda1da2e6b8f2 --- /dev/null +++ b/Include/pytime.h @@ -0,0 +1,246 @@ +#ifndef Py_LIMITED_API +#ifndef Py_PYTIME_H +#define Py_PYTIME_H + +#include "pyconfig.h" /* include for defines */ +#include "object.h" + +/************************************************************************** +Symbols and macros to supply platform-independent interfaces to time related +functions and constants +**************************************************************************/ +#ifdef __cplusplus +extern "C" { +#endif + +/* _PyTime_t: Python timestamp with subsecond precision. It can be used to + store a duration, and so indirectly a date (related to another date, like + UNIX epoch). */ +typedef int64_t _PyTime_t; +#define _PyTime_MIN INT64_MIN +#define _PyTime_MAX INT64_MAX + +typedef enum { + /* Round towards minus infinity (-inf). + For example, used to read a clock. */ + _PyTime_ROUND_FLOOR=0, + /* Round towards infinity (+inf). + For example, used for timeout to wait "at least" N seconds. */ + _PyTime_ROUND_CEILING=1, + /* Round to nearest with ties going to nearest even integer. + For example, used to round from a Python float. */ + _PyTime_ROUND_HALF_EVEN=2, + /* Round away from zero + For example, used for timeout. _PyTime_ROUND_CEILING rounds + -1e-9 to 0 milliseconds which causes bpo-31786 issue. + _PyTime_ROUND_UP rounds -1e-9 to -1 millisecond which keeps + the timeout sign as expected. select.poll(timeout) must block + for negative values." */ + _PyTime_ROUND_UP=3, + /* _PyTime_ROUND_TIMEOUT (an alias for _PyTime_ROUND_UP) should be + used for timeouts. */ + _PyTime_ROUND_TIMEOUT = _PyTime_ROUND_UP +} _PyTime_round_t; + + +/* Convert a time_t to a PyLong. */ +PyAPI_FUNC(PyObject *) _PyLong_FromTime_t( + time_t sec); + +/* Convert a PyLong to a time_t. */ +PyAPI_FUNC(time_t) _PyLong_AsTime_t( + PyObject *obj); + +/* Convert a number of seconds, int or float, to time_t. */ +PyAPI_FUNC(int) _PyTime_ObjectToTime_t( + PyObject *obj, + time_t *sec, + _PyTime_round_t); + +/* Convert a number of seconds, int or float, to a timeval structure. + usec is in the range [0; 999999] and rounded towards zero. + For example, -1.2 is converted to (-2, 800000). */ +PyAPI_FUNC(int) _PyTime_ObjectToTimeval( + PyObject *obj, + time_t *sec, + long *usec, + _PyTime_round_t); + +/* Convert a number of seconds, int or float, to a timespec structure. + nsec is in the range [0; 999999999] and rounded towards zero. + For example, -1.2 is converted to (-2, 800000000). */ +PyAPI_FUNC(int) _PyTime_ObjectToTimespec( + PyObject *obj, + time_t *sec, + long *nsec, + _PyTime_round_t); + + +/* Create a timestamp from a number of seconds. */ +PyAPI_FUNC(_PyTime_t) _PyTime_FromSeconds(int seconds); + +/* Macro to create a timestamp from a number of seconds, no integer overflow. + Only use the macro for small values, prefer _PyTime_FromSeconds(). */ +#define _PYTIME_FROMSECONDS(seconds) \ + ((_PyTime_t)(seconds) * (1000 * 1000 * 1000)) + +/* Create a timestamp from a number of nanoseconds. */ +PyAPI_FUNC(_PyTime_t) _PyTime_FromNanoseconds(_PyTime_t ns); + +/* Create a timestamp from nanoseconds (Python int). */ +PyAPI_FUNC(int) _PyTime_FromNanosecondsObject(_PyTime_t *t, + PyObject *obj); + +/* Convert a number of seconds (Python float or int) to a timetamp. + Raise an exception and return -1 on error, return 0 on success. */ +PyAPI_FUNC(int) _PyTime_FromSecondsObject(_PyTime_t *t, + PyObject *obj, + _PyTime_round_t round); + +/* Convert a number of milliseconds (Python float or int, 10^-3) to a timetamp. + Raise an exception and return -1 on error, return 0 on success. */ +PyAPI_FUNC(int) _PyTime_FromMillisecondsObject(_PyTime_t *t, + PyObject *obj, + _PyTime_round_t round); + +/* Convert a timestamp to a number of seconds as a C double. */ +PyAPI_FUNC(double) _PyTime_AsSecondsDouble(_PyTime_t t); + +/* Convert timestamp to a number of milliseconds (10^-3 seconds). */ +PyAPI_FUNC(_PyTime_t) _PyTime_AsMilliseconds(_PyTime_t t, + _PyTime_round_t round); + +/* Convert timestamp to a number of microseconds (10^-6 seconds). */ +PyAPI_FUNC(_PyTime_t) _PyTime_AsMicroseconds(_PyTime_t t, + _PyTime_round_t round); + +/* Convert timestamp to a number of nanoseconds (10^-9 seconds) as a Python int + object. */ +PyAPI_FUNC(PyObject *) _PyTime_AsNanosecondsObject(_PyTime_t t); + +/* Create a timestamp from a timeval structure. + Raise an exception and return -1 on overflow, return 0 on success. */ +PyAPI_FUNC(int) _PyTime_FromTimeval(_PyTime_t *tp, struct timeval *tv); + +/* Convert a timestamp to a timeval structure (microsecond resolution). + tv_usec is always positive. + Raise an exception and return -1 if the conversion overflowed, + return 0 on success. */ +PyAPI_FUNC(int) _PyTime_AsTimeval(_PyTime_t t, + struct timeval *tv, + _PyTime_round_t round); + +/* Similar to _PyTime_AsTimeval(), but don't raise an exception on error. */ +PyAPI_FUNC(int) _PyTime_AsTimeval_noraise(_PyTime_t t, + struct timeval *tv, + _PyTime_round_t round); + +/* Convert a timestamp to a number of seconds (secs) and microseconds (us). + us is always positive. This function is similar to _PyTime_AsTimeval() + except that secs is always a time_t type, whereas the timeval structure + uses a C long for tv_sec on Windows. + Raise an exception and return -1 if the conversion overflowed, + return 0 on success. */ +PyAPI_FUNC(int) _PyTime_AsTimevalTime_t( + _PyTime_t t, + time_t *secs, + int *us, + _PyTime_round_t round); + +#if defined(HAVE_CLOCK_GETTIME) || defined(HAVE_KQUEUE) +/* Create a timestamp from a timespec structure. + Raise an exception and return -1 on overflow, return 0 on success. */ +PyAPI_FUNC(int) _PyTime_FromTimespec(_PyTime_t *tp, struct timespec *ts); + +/* Convert a timestamp to a timespec structure (nanosecond resolution). + tv_nsec is always positive. + Raise an exception and return -1 on error, return 0 on success. */ +PyAPI_FUNC(int) _PyTime_AsTimespec(_PyTime_t t, struct timespec *ts); +#endif + +/* Compute ticks * mul / div. + The caller must ensure that ((div - 1) * mul) cannot overflow. */ +PyAPI_FUNC(_PyTime_t) _PyTime_MulDiv(_PyTime_t ticks, + _PyTime_t mul, + _PyTime_t div); + +/* Get the current time from the system clock. + + The function cannot fail. _PyTime_Init() ensures that the system clock + works. */ +PyAPI_FUNC(_PyTime_t) _PyTime_GetSystemClock(void); + +/* Get the time of a monotonic clock, i.e. a clock that cannot go backwards. + The clock is not affected by system clock updates. The reference point of + the returned value is undefined, so that only the difference between the + results of consecutive calls is valid. + + The function cannot fail. _PyTime_Init() ensures that a monotonic clock + is available and works. */ +PyAPI_FUNC(_PyTime_t) _PyTime_GetMonotonicClock(void); + + +/* Structure used by time.get_clock_info() */ +typedef struct { + const char *implementation; + int monotonic; + int adjustable; + double resolution; +} _Py_clock_info_t; + +/* Get the current time from the system clock. + * Fill clock information if info is not NULL. + * Raise an exception and return -1 on error, return 0 on success. + */ +PyAPI_FUNC(int) _PyTime_GetSystemClockWithInfo( + _PyTime_t *t, + _Py_clock_info_t *info); + +/* Get the time of a monotonic clock, i.e. a clock that cannot go backwards. + The clock is not affected by system clock updates. The reference point of + the returned value is undefined, so that only the difference between the + results of consecutive calls is valid. + + Fill info (if set) with information of the function used to get the time. + + Return 0 on success, raise an exception and return -1 on error. */ +PyAPI_FUNC(int) _PyTime_GetMonotonicClockWithInfo( + _PyTime_t *t, + _Py_clock_info_t *info); + + +/* Initialize time. + Return 0 on success, raise an exception and return -1 on error. */ +PyAPI_FUNC(int) _PyTime_Init(void); + +/* Converts a timestamp to the Gregorian time, using the local time zone. + Return 0 on success, raise an exception and return -1 on error. */ +PyAPI_FUNC(int) _PyTime_localtime(time_t t, struct tm *tm); + +/* Converts a timestamp to the Gregorian time, assuming UTC. + Return 0 on success, raise an exception and return -1 on error. */ +PyAPI_FUNC(int) _PyTime_gmtime(time_t t, struct tm *tm); + +/* Get the performance counter: clock with the highest available resolution to + measure a short duration. + + The function cannot fail. _PyTime_Init() ensures that the system clock + works. */ +PyAPI_FUNC(_PyTime_t) _PyTime_GetPerfCounter(void); + +/* Get the performance counter: clock with the highest available resolution to + measure a short duration. + + Fill info (if set) with information of the function used to get the time. + + Return 0 on success, raise an exception and return -1 on error. */ +PyAPI_FUNC(int) _PyTime_GetPerfCounterWithInfo( + _PyTime_t *t, + _Py_clock_info_t *info); + +#ifdef __cplusplus +} +#endif + +#endif /* Py_PYTIME_H */ +#endif /* Py_LIMITED_API */ diff --git a/Lib/socket.py b/Lib/socket.py index 437634cc3b58ce..55e446cd63f76f 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -154,7 +154,7 @@ def getfqdn(name=''): _socketmethods = ( 'bind', 'connect', 'connect_ex', 'fileno', 'listen', 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', - 'sendall', 'setblocking', + 'sendall', 'getblocking', 'setblocking', 'settimeout', 'gettimeout', 'shutdown') if os.name == "nt": diff --git a/Lib/ssl.py b/Lib/ssl.py index 0bb43a4a4de11e..e295aa5939df83 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -94,6 +94,10 @@ import os from collections import namedtuple from contextlib import closing +import socket + +# import logging + import _ssl # if we can't import it, let the error propagate @@ -523,6 +527,7 @@ def __init__(self, sock=None, keyfile=None, certfile=None, server_hostname=None, _context=None): + self._sslobj = None self._makefile_refs = 0 if _context: self._context = _context @@ -568,6 +573,7 @@ def __init__(self, sock=None, keyfile=None, certfile=None, "in client mode") if self._context.check_hostname and not server_hostname: raise ValueError("check_hostname requires server_hostname") + sock_timeout = sock.gettimeout() self.server_side = server_side self.server_hostname = server_hostname self.do_handshake_on_connect = do_handshake_on_connect @@ -580,11 +586,42 @@ def __init__(self, sock=None, keyfile=None, certfile=None, if e.errno != errno.ENOTCONN: raise connected = False + blocking = self.getblocking() + self.setblocking(False) + try: + # We are not connected so this is not supposed to block, but + # testing revealed otherwise on macOS and Windows so we do + # the non-blocking dance regardless. Our raise when any data + # is found means consuming the data is harmless. + notconn_pre_handshake_data = self.recv(1) + except (OSError, socket_error) as e: + + # EINVAL occurs for recv(1) on non-connected on unix sockets. + if e.errno not in (errno.ENOTCONN, errno.EINVAL): + raise + + notconn_pre_handshake_data = b'' + self.setblocking(blocking) + if notconn_pre_handshake_data: + # This prevents pending data sent to the socket before it was + # closed from escaping to the caller who could otherwise + # presume it came through a successful TLS connection. + reason = "Closed before TLS handshake with data in recv buffer." + notconn_pre_handshake_data_error = SSLError(e.errno, reason) + # Add the SSLError attributes that _ssl.c always adds. + notconn_pre_handshake_data_error.reason = reason + notconn_pre_handshake_data_error.library = None + try: + self.close() + except (OSError, socket_error): + pass + raise notconn_pre_handshake_data_error else: connected = True self._closed = False self._sslobj = None + self.settimeout(sock_timeout) # Must come after setblocking() calls. self._connected = connected if connected: # create the SSL object @@ -598,7 +635,7 @@ def __init__(self, sock=None, keyfile=None, certfile=None, raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") self.do_handshake() - except (OSError, ValueError): + except (OSError, ValueError, socket_error): self.close() raise @@ -814,6 +851,12 @@ def unwrap(self): else: raise ValueError("No SSL wrapper around " + str(self)) + def verify_client_post_handshake(self): + if self._sslobj: + return self._sslobj.verify_client_post_handshake() + else: + raise ValueError("No SSL wrapper around " + str(self)) + def _real_close(self): self._sslobj = None socket._real_close(self) @@ -915,7 +958,6 @@ def version(self): return None return self._sslobj.version() - def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_TLS, ca_certs=None, diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 988e12a813d5fb..2c1b48c698ae99 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -50,6 +50,13 @@ def try_address(host, port=0, family=socket.AF_INET): HOST = test_support.HOST MSG = 'Michael Gilfix was here\n' +def _is_fd_in_blocking_mode(sock): + return not bool( + fcntl.fcntl(sock, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK) + + +HAVE_SOCKET_CAN = _have_socket_can() + class SocketTCPTest(unittest.TestCase): def setUp(self): @@ -957,8 +964,44 @@ def testSetBlocking(self): # Testing whether set blocking works self.serv.setblocking(True) self.assertIsNone(self.serv.gettimeout()) + self.assertTrue(self.serv.getblocking()) + if fcntl: + self.assertTrue(_is_fd_in_blocking_mode(self.serv)) + self.serv.setblocking(False) self.assertEqual(self.serv.gettimeout(), 0.0) + self.assertFalse(self.serv.getblocking()) + if fcntl: + self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + + self.serv.settimeout(None) + self.assertTrue(self.serv.getblocking()) + if fcntl: + self.assertTrue(_is_fd_in_blocking_mode(self.serv)) + + self.serv.settimeout(0) + self.assertFalse(self.serv.getblocking()) + self.assertEqual(self.serv.gettimeout(), 0) + if fcntl: + self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + + self.serv.settimeout(10) + self.assertTrue(self.serv.getblocking()) + self.assertEqual(self.serv.gettimeout(), 10) + if fcntl: + # When a Python socket has a non-zero timeout, it's + # switched internally to a non-blocking mode. + # Later, sock.sendall(), sock.recv(), and other socket + # operations use a `select()` call and handle EWOULDBLOCK/EGAIN + # on all socket operations. That's how timeouts are + # enforced. + self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + + self.serv.settimeout(0) + self.assertFalse(self.serv.getblocking()) + if fcntl: + self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + start = time.time() try: self.serv.accept() @@ -983,6 +1026,47 @@ def testSetBlocking_overflow(self): _testSetBlocking_overflow = test_support.cpython_only(_testSetBlocking) + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'test needs socket.SOCK_NONBLOCK') + @support.requires_linux_version(2, 6, 28) + def testInitNonBlocking(self): + # reinit server socket + self.serv.close() + self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM | + socket.SOCK_NONBLOCK) + self.assertFalse(self.serv.getblocking()) + self.assertEqual(self.serv.gettimeout(), 0) + self.port = support.bind_port(self.serv) + self.serv.listen() + # actual testing + start = time.time() + try: + self.serv.accept() + except OSError: + pass + end = time.time() + self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.") + + def _testInitNonBlocking(self): + pass + + def testInheritFlags(self): + # Issue #7995: when calling accept() on a listening socket with a + # timeout, the resulting socket should not be non-blocking. + self.serv.settimeout(10) + try: + conn, addr = self.serv.accept() + message = conn.recv(len(MSG)) + finally: + conn.close() + self.serv.settimeout(None) + + def _testInheritFlags(self): + time.sleep(0.1) + self.cli.connect((HOST, self.port)) + time.sleep(0.5) + self.cli.send(MSG) + def testAccept(self): # Testing non-blocking accept self.serv.setblocking(0) @@ -1805,6 +1889,728 @@ def _testStream(self): self.cli.close() +class ContextManagersTest(ThreadedTCPSocketTest): + + def _testSocketClass(self): + # base test + with socket.socket() as sock: + self.assertFalse(sock._closed) + self.assertTrue(sock._closed) + # close inside with block + with socket.socket() as sock: + sock.close() + self.assertTrue(sock._closed) + # exception inside with block + with socket.socket() as sock: + self.assertRaises(OSError, sock.sendall, b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionBase(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionBase(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + self.assertFalse(sock._closed) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionClose(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionClose(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + sock.close() + self.assertTrue(sock._closed) + self.assertRaises(OSError, sock.sendall, b'foo') + + +class InheritanceTest(unittest.TestCase): + @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), + "SOCK_CLOEXEC not defined") + @support.requires_linux_version(2, 6, 28) + def test_SOCK_CLOEXEC(self): + with socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: + self.assertEqual(s.type, socket.SOCK_STREAM) + self.assertFalse(s.get_inheritable()) + + def test_default_inheritable(self): + sock = socket.socket() + with sock: + self.assertEqual(sock.get_inheritable(), False) + + def test_dup(self): + sock = socket.socket() + with sock: + newsock = sock.dup() + sock.close() + with newsock: + self.assertEqual(newsock.get_inheritable(), False) + + def test_set_inheritable(self): + sock = socket.socket() + with sock: + sock.set_inheritable(True) + self.assertEqual(sock.get_inheritable(), True) + + sock.set_inheritable(False) + self.assertEqual(sock.get_inheritable(), False) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_get_inheritable_cloexec(self): + sock = socket.socket() + with sock: + fd = sock.fileno() + self.assertEqual(sock.get_inheritable(), False) + + # clear FD_CLOEXEC flag + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + flags &= ~fcntl.FD_CLOEXEC + fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + self.assertEqual(sock.get_inheritable(), True) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_set_inheritable_cloexec(self): + sock = socket.socket() + with sock: + fd = sock.fileno() + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + fcntl.FD_CLOEXEC) + + sock.set_inheritable(True) + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + 0) + + + def test_socketpair(self): + s1, s2 = socket.socketpair() + self.addCleanup(s1.close) + self.addCleanup(s2.close) + self.assertEqual(s1.get_inheritable(), False) + self.assertEqual(s2.get_inheritable(), False) + + +@unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"), + "SOCK_NONBLOCK not defined") +class NonblockConstantTest(unittest.TestCase): + def checkNonblock(self, s, nonblock=True, timeout=0.0): + if nonblock: + self.assertEqual(s.type, socket.SOCK_STREAM) + self.assertEqual(s.gettimeout(), timeout) + self.assertTrue( + fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK) + if timeout == 0: + # timeout == 0: means that getblocking() must be False. + self.assertFalse(s.getblocking()) + else: + # If timeout > 0, the socket will be in a "blocking" mode + # from the standpoint of the Python API. For Python socket + # object, "blocking" means that operations like 'sock.recv()' + # will block. Internally, file descriptors for + # "blocking" Python sockets *with timeouts* are in a + # *non-blocking* mode, and 'sock.recv()' uses 'select()' + # and handles EWOULDBLOCK/EAGAIN to enforce the timeout. + self.assertTrue(s.getblocking()) + else: + self.assertEqual(s.type, socket.SOCK_STREAM) + self.assertEqual(s.gettimeout(), None) + self.assertFalse( + fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK) + self.assertTrue(s.getblocking()) + + @support.requires_linux_version(2, 6, 28) + def test_SOCK_NONBLOCK(self): + # a lot of it seems silly and redundant, but I wanted to test that + # changing back and forth worked ok + with socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s: + self.checkNonblock(s) + s.setblocking(1) + self.checkNonblock(s, nonblock=False) + s.setblocking(0) + self.checkNonblock(s) + s.settimeout(None) + self.checkNonblock(s, nonblock=False) + s.settimeout(2.0) + self.checkNonblock(s, timeout=2.0) + s.setblocking(1) + self.checkNonblock(s, nonblock=False) + # defaulttimeout + t = socket.getdefaulttimeout() + socket.setdefaulttimeout(0.0) + with socket.socket() as s: + self.checkNonblock(s) + socket.setdefaulttimeout(None) + with socket.socket() as s: + self.checkNonblock(s, False) + socket.setdefaulttimeout(2.0) + with socket.socket() as s: + self.checkNonblock(s, timeout=2.0) + socket.setdefaulttimeout(None) + with socket.socket() as s: + self.checkNonblock(s, False) + socket.setdefaulttimeout(t) + + +@unittest.skipUnless(os.name == "nt", "Windows specific") +@unittest.skipUnless(multiprocessing, "need multiprocessing") +class TestSocketSharing(SocketTCPTest): + # This must be classmethod and not staticmethod or multiprocessing + # won't be able to bootstrap it. + @classmethod + def remoteProcessServer(cls, q): + # Recreate socket from shared data + sdata = q.get() + message = q.get() + + s = socket.fromshare(sdata) + s2, c = s.accept() + + # Send the message + s2.sendall(message) + s2.close() + s.close() + + def testShare(self): + # Transfer the listening server socket to another process + # and service it from there. + + # Create process: + q = multiprocessing.Queue() + p = multiprocessing.Process(target=self.remoteProcessServer, args=(q,)) + p.start() + + # Get the shared socket data + data = self.serv.share(p.pid) + + # Pass the shared socket to the other process + addr = self.serv.getsockname() + self.serv.close() + q.put(data) + + # The data that the server will send us + message = b"slapmahfro" + q.put(message) + + # Connect + s = socket.create_connection(addr) + # listen for the data + m = [] + while True: + data = s.recv(100) + if not data: + break + m.append(data) + s.close() + received = b"".join(m) + self.assertEqual(received, message) + p.join() + + def testShareLength(self): + data = self.serv.share(os.getpid()) + self.assertRaises(ValueError, socket.fromshare, data[:-1]) + self.assertRaises(ValueError, socket.fromshare, data+b"foo") + + def compareSockets(self, org, other): + # socket sharing is expected to work only for blocking socket + # since the internal python timeout value isn't transferred. + self.assertEqual(org.gettimeout(), None) + self.assertEqual(org.gettimeout(), other.gettimeout()) + + self.assertEqual(org.family, other.family) + self.assertEqual(org.type, other.type) + # If the user specified "0" for proto, then + # internally windows will have picked the correct value. + # Python introspection on the socket however will still return + # 0. For the shared socket, the python value is recreated + # from the actual value, so it may not compare correctly. + if org.proto != 0: + self.assertEqual(org.proto, other.proto) + + def testShareLocal(self): + data = self.serv.share(os.getpid()) + s = socket.fromshare(data) + try: + self.compareSockets(self.serv, s) + finally: + s.close() + + def testTypes(self): + families = [socket.AF_INET, socket.AF_INET6] + types = [socket.SOCK_STREAM, socket.SOCK_DGRAM] + for f in families: + for t in types: + try: + source = socket.socket(f, t) + except OSError: + continue # This combination is not supported + try: + data = source.share(os.getpid()) + shared = socket.fromshare(data) + try: + self.compareSockets(source, shared) + finally: + shared.close() + finally: + source.close() + + +class SendfileUsingSendTest(ThreadedTCPSocketTest): + """ + Test the send() implementation of socket.sendfile(). + """ + + FILESIZE = (10 * 1024 * 1024) # 10 MiB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + def meth_from_sock(self, sock): + # Depending on the mixin class being run return either send() + # or sendfile() method implementation. + return getattr(sock, "_sendfile_use_send") + + # regular file + + def _testRegularFile(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + + def testRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + self.assertRaises(socket._GiveupOnSendfile, + sock._sendfile_use_sendfile, file) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # empty file + + def _testEmptyFileSend(self): + address = self.serv.getsockname() + filename = support.TESTFN + "2" + with open(filename, 'wb'): + self.addCleanup(support.unlink, filename) + file = open(filename, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, 0) + self.assertEqual(file.tell(), 0) + + def testEmptyFileSend(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(data, b"") + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file, offset=5000) + self.assertEqual(sent, self.FILESIZE - 5000) + self.assertEqual(file.tell(), self.FILESIZE) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # count + + def _testCount(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 5000007 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCount(self): + count = 5000007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count small + + def _testCountSmall(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 1 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCountSmall(self): + count = 1 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count + offset + + def _testCountWithOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 100007 + meth = self.meth_from_sock(sock) + sent = meth(file, offset=2007, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count + 2007) + + def testCountWithOffset(self): + count = 100007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[2007:count+2007]) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + meth = self.meth_from_sock(sock) + self.assertRaises(ValueError, meth, file) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + + def testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggeredSend(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + meth = self.meth_from_sock(sock) + self.assertRaises(socket.timeout, meth, file) + + def testWithTimeoutTriggeredSend(self): + conn = self.accept_conn() + conn.recv(88192) + + # errors + + def _test_errors(self): + pass + + def test_errors(self): + with open(support.TESTFN, 'rb') as file: + with socket.socket(type=socket.SOCK_DGRAM) as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "SOCK_STREAM", meth, file) + with open(support.TESTFN, 'rt') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "binary mode", meth, file) + with open(support.TESTFN, 'rb') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count='2') + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count=0.1) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=0) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=-1) + + +@unittest.skipUnless(hasattr(os, "sendfile"), + 'os.sendfile() required for this test.') +class SendfileUsingSendfileTest(SendfileUsingSendTest): + """ + Test the sendfile() implementation of socket.sendfile(). + """ + def meth_from_sock(self, sock): + return getattr(sock, "_sendfile_use_sendfile") + + +@unittest.skipUnless(HAVE_SOCKET_ALG, 'AF_ALG required') +class LinuxKernelCryptoAPI(unittest.TestCase): + # tests for AF_ALG + def create_alg(self, typ, name): + sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0) + try: + sock.bind((typ, name)) + except FileNotFoundError as e: + # type / algorithm is not available + sock.close() + raise unittest.SkipTest(str(e), typ, name) + else: + return sock + + # bpo-31705: On kernel older than 4.5, sendto() failed with ENOKEY, + # at least on ppc64le architecture + @support.requires_linux_version(4, 5) + def test_sha256(self): + expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396" + "177a9cb410ff61f20015ad") + with self.create_alg('hash', 'sha256') as algo: + op, _ = algo.accept() + with op: + op.sendall(b"abc") + self.assertEqual(op.recv(512), expected) + + op, _ = algo.accept() + with op: + op.send(b'a', socket.MSG_MORE) + op.send(b'b', socket.MSG_MORE) + op.send(b'c', socket.MSG_MORE) + op.send(b'') + self.assertEqual(op.recv(512), expected) + + def test_hmac_sha1(self): + expected = bytes.fromhex("effcdf6ae5eb2fa2d27416d5f184df9c259a7c79") + with self.create_alg('hash', 'hmac(sha1)') as algo: + algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, b"Jefe") + op, _ = algo.accept() + with op: + op.sendall(b"what do ya want for nothing?") + self.assertEqual(op.recv(512), expected) + + # Although it should work with 3.19 and newer the test blocks on + # Ubuntu 15.10 with Kernel 4.2.0-19. + @support.requires_linux_version(4, 3) + def test_aes_cbc(self): + key = bytes.fromhex('06a9214036b8a15b512e03d534120006') + iv = bytes.fromhex('3dafba429d9eb430b422da802c9fac41') + msg = b"Single block msg" + ciphertext = bytes.fromhex('e353779c1079aeb82708942dbe77181a') + msglen = len(msg) + with self.create_alg('skcipher', 'cbc(aes)') as algo: + algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key) + op, _ = algo.accept() + with op: + op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv, + flags=socket.MSG_MORE) + op.sendall(msg) + self.assertEqual(op.recv(msglen), ciphertext) + + op, _ = algo.accept() + with op: + op.sendmsg_afalg([ciphertext], + op=socket.ALG_OP_DECRYPT, iv=iv) + self.assertEqual(op.recv(msglen), msg) + + # long message + multiplier = 1024 + longmsg = [msg] * multiplier + op, _ = algo.accept() + with op: + op.sendmsg_afalg(longmsg, + op=socket.ALG_OP_ENCRYPT, iv=iv) + enc = op.recv(msglen * multiplier) + self.assertEqual(len(enc), msglen * multiplier) + self.assertTrue(enc[:msglen], ciphertext) + + op, _ = algo.accept() + with op: + op.sendmsg_afalg([enc], + op=socket.ALG_OP_DECRYPT, iv=iv) + dec = op.recv(msglen * multiplier) + self.assertEqual(len(dec), msglen * multiplier) + self.assertEqual(dec, msg * multiplier) + + @support.requires_linux_version(4, 9) # see issue29324 + def test_aead_aes_gcm(self): + key = bytes.fromhex('c939cc13397c1d37de6ae0e1cb7c423c') + iv = bytes.fromhex('b3d8cc017cbb89b39e0f67e2') + plain = bytes.fromhex('c3b3c41f113a31b73d9a5cd432103069') + assoc = bytes.fromhex('24825602bd12a984e0092d3e448eda5f') + expected_ct = bytes.fromhex('93fe7d9e9bfd10348a5606e5cafa7354') + expected_tag = bytes.fromhex('0032a1dc85f1c9786925a2e71d8272dd') + + taglen = len(expected_tag) + assoclen = len(assoc) + + with self.create_alg('aead', 'gcm(aes)') as algo: + algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key) + algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_AEAD_AUTHSIZE, + None, taglen) + + # send assoc, plain and tag buffer in separate steps + op, _ = algo.accept() + with op: + op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv, + assoclen=assoclen, flags=socket.MSG_MORE) + op.sendall(assoc, socket.MSG_MORE) + op.sendall(plain) + res = op.recv(assoclen + len(plain) + taglen) + self.assertEqual(expected_ct, res[assoclen:-taglen]) + self.assertEqual(expected_tag, res[-taglen:]) + + # now with msg + op, _ = algo.accept() + with op: + msg = assoc + plain + op.sendmsg_afalg([msg], op=socket.ALG_OP_ENCRYPT, iv=iv, + assoclen=assoclen) + res = op.recv(assoclen + len(plain) + taglen) + self.assertEqual(expected_ct, res[assoclen:-taglen]) + self.assertEqual(expected_tag, res[-taglen:]) + + # create anc data manually + pack_uint32 = struct.Struct('I').pack + op, _ = algo.accept() + with op: + msg = assoc + plain + op.sendmsg( + [msg], + ([socket.SOL_ALG, socket.ALG_SET_OP, pack_uint32(socket.ALG_OP_ENCRYPT)], + [socket.SOL_ALG, socket.ALG_SET_IV, pack_uint32(len(iv)) + iv], + [socket.SOL_ALG, socket.ALG_SET_AEAD_ASSOCLEN, pack_uint32(assoclen)], + ) + ) + res = op.recv(len(msg) + taglen) + self.assertEqual(expected_ct, res[assoclen:-taglen]) + self.assertEqual(expected_tag, res[-taglen:]) + + # decrypt and verify + op, _ = algo.accept() + with op: + msg = assoc + expected_ct + expected_tag + op.sendmsg_afalg([msg], op=socket.ALG_OP_DECRYPT, iv=iv, + assoclen=assoclen) + res = op.recv(len(msg) - taglen) + self.assertEqual(plain, res[assoclen:]) + + @support.requires_linux_version(4, 3) # see test_aes_cbc + def test_drbg_pr_sha256(self): + # deterministic random bit generator, prediction resistance, sha256 + with self.create_alg('rng', 'drbg_pr_sha256') as algo: + extra_seed = os.urandom(32) + algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, extra_seed) + op, _ = algo.accept() + with op: + rn = op.recv(32) + self.assertEqual(len(rn), 32) + + def test_sendmsg_afalg_args(self): + sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0) + with sock: + with self.assertRaises(TypeError): + sock.sendmsg_afalg() + + with self.assertRaises(TypeError): + sock.sendmsg_afalg(op=None) + + with self.assertRaises(TypeError): + sock.sendmsg_afalg(1) + + with self.assertRaises(TypeError): + sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=None) + + with self.assertRaises(TypeError): + sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=-1) + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index ef2e59c1d15d06..ff4adb46b0ea80 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -6,8 +6,13 @@ from test import test_support as support from test.script_helper import assert_python_ok import asyncore +# import unittest.mock +from mock import Mock +from test import support +import re import socket import select +import struct import time import datetime import gc @@ -85,6 +90,35 @@ def data_file(*name): OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0) OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0) +SIGNED_CERTFILE_HOSTNAME = 'localhost' +SIGNED_CERTFILE2_HOSTNAME = 'fakehostname' +NOSANFILE = data_file("nosan.pem") +NOSAN_HOSTNAME = 'localhost' + + +def testing_context(server_cert=SIGNED_CERTFILE): + """Create context + + client_context, server_context, hostname = testing_context() + """ + if server_cert == SIGNED_CERTFILE: + hostname = SIGNED_CERTFILE_HOSTNAME + elif server_cert == SIGNED_CERTFILE2: + hostname = SIGNED_CERTFILE2_HOSTNAME + elif server_cert == NOSANFILE: + hostname = NOSAN_HOSTNAME + else: + raise ValueError(server_cert) + + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + client_context.load_verify_locations(SIGNING_CA) + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + server_context.load_cert_chain(server_cert) + server_context.load_verify_locations(SIGNING_CA) + + return client_context, server_context, hostname + def handle_error(prefix): exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) @@ -1753,7 +1787,7 @@ def wrap_conn(self): return False else: # OSError may occur with wrong protocols, e.g. both - # sides use PROTOCOL_TLS_SERVER. + # sides use PROTOCOL_TLS. # # XXX Various errors can have happened here, for example # a mismatching protocol version, an invalid certificate, @@ -2153,7 +2187,7 @@ def try_protocol_combo(server_protocol, client_protocol, expect_success, class ThreadedTests(unittest.TestCase): - + # FIX: client_context, server_context, hostname = testing_context() @skip_if_broken_ubuntu_ssl def test_echo(self): """Basic test of an SSL client connecting to a server""" @@ -2232,8 +2266,7 @@ def test_crl_check(self): server = ThreadedEchoServer(context=server_context, chatty=True) with server: with closing(context.wrap_socket(socket.socket())) as s: - with self.assertRaisesRegexp(ssl.SSLError, - "certificate verify failed"): + with self.assertRaisesRegexp(ssl.SSLError, "certificate verify failed"): s.connect((HOST, server.port)) # now load a CRL file. The CRL file is signed by the CA. @@ -2261,8 +2294,7 @@ def test_check_hostname(self): # correct hostname should verify server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with closing(context.wrap_socket(socket.socket(), - server_hostname="localhost")) as s: + with closing(context.wrap_socket(socket.socket(), server_hostname="localhost")) as s: s.connect((HOST, server.port)) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") @@ -2270,8 +2302,7 @@ def test_check_hostname(self): # incorrect hostname should raise an exception server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with closing(context.wrap_socket(socket.socket(), - server_hostname="invalid")) as s: + with closing(context.wrap_socket(socket.socket(), server_hostname="invalid")) as s: with self.assertRaisesRegexp(ssl.CertificateError, "hostname 'invalid' doesn't match u?'localhost'"): s.connect((HOST, server.port)) @@ -3240,6 +3271,481 @@ def test_read_write_after_close_raises_valuerror(self): self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_pha_no_pha_client(self): + client_context, server_context, hostname = testing_context() + server_context.post_handshake_auth = True + server_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_cert_chain(SIGNED_CERTFILE) + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing (client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + with self.assertRaisesRegexp(ssl.SSLError, 'not server'): + s.verify_client_post_handshake() + s.write(b'PHA') + self.assertIn(b'extension not received', s.recv(1024)) + + def test_pha_no_pha_server(self): + # server doesn't have PHA enabled, cert is requested in handshake + client_context, server_context, hostname = testing_context() + server_context.verify_mode = ssl.CERT_REQUIRED + client_context.post_handshake_auth = True + client_context.load_cert_chain(SIGNED_CERTFILE) + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + s.write(b'HASCERT') + self.assertEqual(s.recv(1024), b'TRUE\n') + # PHA doesn't fail if there is already a cert + s.write(b'PHA') + self.assertEqual(s.recv(1024), b'OK\n') + s.write(b'HASCERT') + self.assertEqual(s.recv(1024), b'TRUE\n') + + def test_pha_not_tls13(self): + # TLS 1.2 + client_context, server_context, hostname = testing_context() + server_context.verify_mode = ssl.CERT_REQUIRED + client_context.maximum_version = ssl.TLSVersion.TLSv1_2 + client_context.post_handshake_auth = True + client_context.load_cert_chain(SIGNED_CERTFILE) + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + # PHA fails for TLS != 1.3 + s.write(b'PHA') + self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024)) + + def test_bpo37428_pha_cert_none(self): + # verify that post_handshake_auth does not implicitly enable cert + # validation. + hostname = SIGNED_CERTFILE_HOSTNAME + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS) + client_context.post_handshake_auth = True + client_context.load_cert_chain(SIGNED_CERTFILE) + # no cert validation and CA on client side + client_context.check_hostname = False + client_context.verify_mode = ssl.CERT_NONE + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS) + server_context.load_cert_chain(SIGNED_CERTFILE) + server_context.load_verify_locations(SIGNING_CA) + server_context.post_handshake_auth = True + server_context.verify_mode = ssl.CERT_REQUIRED + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + s.write(b'HASCERT') + self.assertEqual(s.recv(1024), b'FALSE\n') + s.write(b'PHA') + self.assertEqual(s.recv(1024), b'OK\n') + s.write(b'HASCERT') + self.assertEqual(s.recv(1024), b'TRUE\n') + # server cert has not been validated + self.assertEqual(s.getpeercert(), {}) + + +HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename') +requires_keylog = unittest.skipUnless( + HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback') + +class TestSSLDebug(unittest.TestCase): + + def keylog_lines(self, fname=support.TESTFN): + with open(fname) as f: + return len(list(f)) + + @requires_keylog + def test_keylog_defaults(self): + self.addCleanup(support.unlink, support.TESTFN) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) + self.assertEqual(ctx.keylog_filename, None) + + self.assertFalse(os.path.isfile(support.TESTFN)) + ctx.keylog_filename = support.TESTFN + self.assertEqual(ctx.keylog_filename, support.TESTFN) + self.assertTrue(os.path.isfile(support.TESTFN)) + self.assertEqual(self.keylog_lines(), 1) + + ctx.keylog_filename = None + self.assertEqual(ctx.keylog_filename, None) + + with self.assertRaises((IsADirectoryError, PermissionError)): + # Windows raises PermissionError + ctx.keylog_filename = os.path.dirname( + os.path.abspath(support.TESTFN)) + + with self.assertRaises(TypeError): + ctx.keylog_filename = 1 + + @requires_keylog + def test_keylog_filename(self): + self.addCleanup(support.unlink, support.TESTFN) + client_context, server_context, hostname = testing_context() + + client_context.keylog_filename = support.TESTFN + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + # header, 5 lines for TLS 1.3 + self.assertEqual(self.keylog_lines(), 6) + + client_context.keylog_filename = None + server_context.keylog_filename = support.TESTFN + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + self.assertGreaterEqual(self.keylog_lines(), 11) + + client_context.keylog_filename = support.TESTFN + server_context.keylog_filename = support.TESTFN + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + self.assertGreaterEqual(self.keylog_lines(), 21) + + client_context.keylog_filename = None + server_context.keylog_filename = None + + @requires_keylog + @unittest.skipIf(sys.flags.ignore_environment, + "test is not compatible with ignore_environment") + def test_keylog_env(self): + self.addCleanup(support.unlink, support.TESTFN) + with mock.patch.dict(os.environ): + os.environ['SSLKEYLOGFILE'] = support.TESTFN + self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) + self.assertEqual(ctx.keylog_filename, None) + + ctx = ssl.create_default_context() + self.assertEqual(ctx.keylog_filename, support.TESTFN) + + ctx = ssl._create_stdlib_context() + self.assertEqual(ctx.keylog_filename, support.TESTFN) + + def test_msg_callback(self): + client_context, server_context, hostname = testing_context() + + def msg_cb(conn, direction, version, content_type, msg_type, data): + pass + + self.assertIs(client_context._msg_callback, None) + client_context._msg_callback = msg_cb + self.assertIs(client_context._msg_callback, msg_cb) + with self.assertRaises(TypeError): + client_context._msg_callback = object() + + def test_msg_callback_tls12(self): + client_context, server_context, hostname = testing_context() + client_context.options |= ssl.OP_NO_TLSv1_3 + + msg = [] + + def msg_cb(conn, direction, version, content_type, msg_type, data): + self.assertIsInstance(conn, ssl.SSLSocket) + self.assertIsInstance(data, bytes) + self.assertIn(direction, {'read', 'write'}) + msg.append((direction, version, content_type, msg_type)) + + client_context._msg_callback = msg_cb + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + + self.assertIn( + ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE, + _TLSMessageType.SERVER_KEY_EXCHANGE), + msg + ) + self.assertIn( + ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC, + _TLSMessageType.CHANGE_CIPHER_SPEC), + msg + ) + + def test_msg_callback_deadlock_bpo43577(self): + client_context, server_context, hostname = testing_context() + server_context2 = testing_context()[1] + + def msg_cb(conn, direction, version, content_type, msg_type, data): + pass + + def sni_cb(sock, servername, ctx): + sock.context = server_context2 + + server_context._msg_callback = msg_cb + server_context.sni_callback = sni_cb + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + + +def set_socket_so_linger_on_with_zero_timeout(sock): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) + + +class TestPreHandshakeClose(unittest.TestCase): + """Verify behavior of close sockets with received data before to the handshake. + """ + + class SingleConnectionTestServerThread(threading.Thread): + + def __init__(self, _dummy=object(), name="name", call_after_accept=False): + self.call_after_accept = call_after_accept + self.received_data = b'' # set by .run() + self.wrap_error = None # set by .run() + self.listener = None # set by .start() + self.port = None # set by .start() + super().__init__(name=name) + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + try: + if self.listener: + self.listener.close() + except (OSError, socket_error): + pass + self.join() + self.wrap_error = None # avoid dangling references + + def start(self): + self.ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + self.ssl_ctx.verify_mode = ssl.CERT_REQUIRED + self.ssl_ctx.load_verify_locations(cafile=ONLYCERT) + self.ssl_ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY) + self.listener = socket.socket() + self.port = support.bind_port(self.listener) + self.listener.settimeout(2.0) + self.listener.listen(1) + super().start() + + def run(self): + conn, address = self.listener.accept() + self.listener.close() + with conn: + if self.call_after_accept(conn): + return + try: + tls_socket = self.ssl_ctx.wrap_socket(conn, server_side=True) + except OSError as err: # ssl.SSLError inherits from OSError + self.wrap_error = err + else: + try: + self.received_data = tls_socket.recv(400) + except OSError: + pass # closed, protocol error, etc. + + def non_linux_skip_if_other_okay_error(self, err): + if sys.platform == "linux": + return # Expect the full test setup to always work on Linux. + if (isinstance(err, ConnectionResetError) or + (isinstance(err, OSError) and err.errno == errno.EINVAL) or + re.search('wrong.version.number', getattr(err, "reason", ""), re.I)): + # On Windows the TCP RST leads to a ConnectionResetError + # (ECONNRESET) which Linux doesn't appear to surface to userspace. + # If wrap_socket() winds up on the "if connected:" path and doing + # the actual wrapping... we get an SSLError from OpenSSL. Typically + # WRONG_VERSION_NUMBER. While appropriate, neither is the scenario + # we're specifically trying to test. The way this test is written + # is known to work on Linux. We'll skip it anywhere else that it + # does not present as doing so. + self.skipTest("Could not recreate conditions on {sys.platform}: {err}".format(sys.plattform,err)) + # If maintaining this conditional winds up being a problem. + # just turn this into an unconditional skip anything but Linux. + # The important thing is that our CI has the logic covered. + + def test_preauth_data_to_tls_server(self): + server_accept_called = threading.Event() + ready_for_server_wrap_socket = threading.Event() + + def call_after_accept(unused): + server_accept_called.set() + if not ready_for_server_wrap_socket.wait(2.0): + raise RuntimeError("wrap_socket event never set, test may fail.") + return False # Tell the server thread to continue. + + server = self.SingleConnectionTestServerThread( + call_after_accept=call_after_accept, + name="preauth_data_to_tls_server") + server.__enter__() # starts it + self.addCleanup(server.__exit__) # ... & unittest.TestCase stops it. + + with socket.socket() as client: + client.connect(server.listener.getsockname()) + # This forces an immediate connection close via RST on .close(). + set_socket_so_linger_on_with_zero_timeout(client) + client.setblocking(False) + + server_accept_called.wait() + client.send(b"DELETE /data HTTP/1.0\r\n\r\n") + client.close() # RST + + ready_for_server_wrap_socket.set() + server.join() + wrap_error = server.wrap_error + self.assertEqual(b"", server.received_data) + self.assertIsInstance(wrap_error, OSError) # All platforms. + self.non_linux_skip_if_other_okay_error(wrap_error) + self.assertIsInstance(wrap_error, ssl.SSLError) + self.assertIn("before TLS handshake with data", wrap_error.args[1]) + self.assertIn("before TLS handshake with data", wrap_error.reason) + self.assertNotEqual(0, wrap_error.args[0]) + self.assertIsNone(wrap_error.library, msg="attr must exist") + + def test_preauth_data_to_tls_client(self): + client_can_continue_with_wrap_socket = threading.Event() + + def call_after_accept(conn_to_client): + # This forces an immediate connection close via RST on .close(). + set_socket_so_linger_on_with_zero_timeout(conn_to_client) + conn_to_client.send( + b"HTTP/1.0 307 Temporary Redirect\r\n" + b"Location: https://example.com/someone-elses-server\r\n" + b"\r\n") + conn_to_client.close() # RST + client_can_continue_with_wrap_socket.set() + return True # Tell the server to stop. + + server = self.SingleConnectionTestServerThread( + call_after_accept=call_after_accept, + name="preauth_data_to_tls_client") + server.__enter__() # starts it + self.addCleanup(server.__exit__) # ... & unittest.TestCase stops it. + + # Redundant; call_after_accept sets SO_LINGER on the accepted conn. + set_socket_so_linger_on_with_zero_timeout(server.listener) + + with socket.socket() as client: + client.connect(server.listener.getsockname()) + if not client_can_continue_with_wrap_socket.wait(2.0): + self.fail("test server took too long.") + ssl_ctx = ssl.create_default_context() + try: + tls_client = ssl_ctx.wrap_socket( + client, server_hostname="localhost") + except OSError as err: # SSLError inherits from OSError + wrap_error = err + received_data = b"" + else: + wrap_error = None + received_data = tls_client.recv(400) + tls_client.close() + + server.join() + self.assertEqual(b"", received_data) + self.assertIsInstance(wrap_error, OSError) # All platforms. + self.non_linux_skip_if_other_okay_error(wrap_error) + self.assertIsInstance(wrap_error, ssl.SSLError) + self.assertIn("before TLS handshake with data", wrap_error.args[1]) + self.assertIn("before TLS handshake with data", wrap_error.reason) + self.assertNotEqual(0, wrap_error.args[0]) + self.assertIsNone(wrap_error.library, msg="attr must exist") + + def test_https_client_non_tls_response_ignored(self): + server_responding = threading.Event() + + class SynchronizedHTTPSConnection(http.client.HTTPSConnection): + def connect(self): + # Call clear text HTTP connect(), not the encrypted HTTPS (TLS) + # connect(): wrap_socket() is called manually below. + http.client.HTTPConnection.connect(self) + + # Wait for our fault injection server to have done its thing. + if not server_responding.wait(support.SHORT_TIMEOUT) and support.verbose: + sys.stdout.write("server_responding event never set.") + self.sock = self._context.wrap_socket( + self.sock, server_hostname=self.host) + + def call_after_accept(conn_to_client): + # This forces an immediate connection close via RST on .close(). + set_socket_so_linger_on_with_zero_timeout(conn_to_client) + conn_to_client.send( + b"HTTP/1.0 402 Payment Required\r\n" + b"\r\n") + conn_to_client.close() # RST + server_responding.set() + return True # Tell the server to stop. + + timeout = 2.0 + server = self.SingleConnectionTestServerThread( + call_after_accept=call_after_accept, + name="non_tls_http_RST_responder", + timeout=timeout) + server.__enter__() # starts it + self.addCleanup(server.__exit__) # ... & unittest.TestCase stops it. + # Redundant; call_after_accept sets SO_LINGER on the accepted conn. + set_socket_so_linger_on_with_zero_timeout(server.listener) + + connection = SynchronizedHTTPSConnection( + server.listener.getsockname()[0], + port=server.port, + context=ssl.create_default_context(), + timeout=timeout, + ) + + # There are lots of reasons this raises as desired, long before this + # test was added. Sending the request requires a successful TLS wrapped + # socket; that fails if the connection is broken. It may seem pointless + # to test this. It serves as an illustration of something that we never + # want to happen... properly not happening. + with self.assertRaises(OSError): + connection.request("HEAD", "/test", headers={"Host": "localhost"}) + response = connection.getresponse() + + server.join() + + def test_bpo37428_pha_cert_none(self): + # verify that post_handshake_auth does not implicitly enable cert + # validation. + hostname = SIGNED_CERTFILE_HOSTNAME + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.post_handshake_auth = True + client_context.load_cert_chain(SIGNED_CERTFILE) + # no cert validation and CA on client side + client_context.check_hostname = False + client_context.verify_mode = ssl.CERT_NONE + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(SIGNED_CERTFILE) + server_context.load_verify_locations(SIGNING_CA) + server_context.post_handshake_auth = True + server_context.verify_mode = ssl.CERT_REQUIRED + + server = ThreadedEchoServer(context=server_context, chatty=False) + with server: + with closing(client_context.wrap_socket(socket.socket(), server_hostname=hostname)) as s: + s.connect((HOST, server.port)) + s.write(b'HASCERT') + self.assertEqual(s.recv(1024), b'FALSE\n') + s.write(b'PHA') + self.assertEqual(s.recv(1024), b'OK\n') + s.write(b'HASCERT') + self.assertEqual(s.recv(1024), b'TRUE\n') + # server cert has not been validated + self.assertEqual(s.getpeercert(), {}) + def test_main(verbose=False): if support.verbose: diff --git a/Misc/NEWS.d/next/Library/2017-12-19-09-23-46.bpo-32373.8qAkoW.rst b/Misc/NEWS.d/next/Library/2017-12-19-09-23-46.bpo-32373.8qAkoW.rst new file mode 100644 index 00000000000000..9772dda35a3013 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-12-19-09-23-46.bpo-32373.8qAkoW.rst @@ -0,0 +1 @@ +Add socket.getblocking() method. diff --git a/Misc/NEWS.d/next/Library/2019-06-27-13-27-02.bpo-37428._wcwUd.rst b/Misc/NEWS.d/next/Library/2019-06-27-13-27-02.bpo-37428._wcwUd.rst new file mode 100644 index 00000000000000..2cdce6b24dc6b6 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-06-27-13-27-02.bpo-37428._wcwUd.rst @@ -0,0 +1,4 @@ +SSLContext.post_handshake_auth = True no longer sets +SSL_VERIFY_POST_HANDSHAKE verify flag for client connections. Although the +option is documented as ignored for clients, OpenSSL implicitly enables cert +chain validation when the flag is set. diff --git a/Misc/NEWS.d/next/Security/2023-08-22-17-39-12.gh-issue-108310.fVM3sg.rst b/Misc/NEWS.d/next/Security/2023-08-22-17-39-12.gh-issue-108310.fVM3sg.rst new file mode 100644 index 00000000000000..403c77a9d480ee --- /dev/null +++ b/Misc/NEWS.d/next/Security/2023-08-22-17-39-12.gh-issue-108310.fVM3sg.rst @@ -0,0 +1,7 @@ +Fixed an issue where instances of :class:`ssl.SSLSocket` were vulnerable to +a bypass of the TLS handshake and included protections (like certificate +verification) and treating sent unencrypted data as if it were +post-handshake TLS encrypted data. Security issue reported as +`CVE-2023-40217 +`_ by +Aapo Oksman. Patch by Gregory P. Smith. diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 98c8a5a4a95ca9..6df9f47792e4df 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -324,6 +324,9 @@ typedef struct { PyObject *set_hostname; #endif int check_hostname; +#ifdef TLS1_3_VERSION + int post_handshake_auth; +#endif } PySSLContext; typedef struct { @@ -606,6 +609,28 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, #endif SSL_set_mode(self->ssl, mode); + /* bpo-37428: newPySSLSocket() sets SSL_VERIFY_POST_HANDSHAKE flag for + * server sockets and SSL_set_post_handshake_auth() for client. */ +// #ifdef TLS1_3_VERSION +// if (sslctx->post_handshake_auth == 1) { +// if (socket_type == PY_SSL_SERVER) { +// /* bpo-37428: OpenSSL does not ignore SSL_VERIFY_POST_HANDSHAKE. +// * Set SSL_VERIFY_POST_HANDSHAKE flag only for server sockets and +// * only in combination with SSL_VERIFY_PEER flag. */ +// int mode = SSL_get_verify_mode(self->ssl); +// if (mode & SSL_VERIFY_PEER) { +// int (*verify_cb)(int, X509_STORE_CTX *) = NULL; +// verify_cb = SSL_get_verify_callback(self->ssl); +// mode |= SSL_VERIFY_POST_HANDSHAKE; +// SSL_set_verify(self->ssl, mode, verify_cb); +// } +// } else { +// /* client socket */ +// SSL_set_post_handshake_auth(self->ssl, 1); +// } +// } +// #endif + #if HAVE_SNI if (server_hostname != NULL) { /* Don't send SNI for IP addresses. We cannot simply use inet_aton() and @@ -2100,6 +2125,36 @@ If the TLS handshake is not yet complete, None is returned"); #endif /* HAVE_OPENSSL_FINISHED */ +/*[clinic input] +_ssl._SSLSocket.verify_client_post_handshake + +Initiate TLS 1.3 post-handshake authentication +[clinic start generated code]*/ + +PyDoc_STRVAR(_ssl__SSLSocket_verify_client_post_handshake__doc__, +"verify_client_post_handshake($self, /)\n" +"--\n" +"\n" +"Initiate TLS 1.3 post-handshake authentication"); + +static PyObject * +_ssl__SSLSocket_verify_client_post_handshake_impl(PySSLSocket *self) +/*[clinic end generated code: output=532147f3b1341425 input=6bfa874810a3d889]*/ +{ +#ifdef TLS1_3_VERSION + int err = SSL_verify_client_post_handshake(self->ssl); + if (err == 0) + return _setSSLError(NULL, 0, __FILE__, __LINE__); + else + Py_RETURN_NONE; +#else + PyErr_SetString(PyExc_NotImplementedError, + "Post-handshake auth is not supported by your " + "OpenSSL version."); + return NULL; +#endif +} + static PyGetSetDef ssl_getsetlist[] = { {"context", (getter) PySSL_get_context, (setter) PySSL_set_context, PySSL_set_context_doc}, @@ -2131,6 +2186,7 @@ static PyMethodDef PySSLMethods[] = { {"tls_unique_cb", (PyCFunction)PySSL_tls_unique_cb, METH_NOARGS, PySSL_tls_unique_cb_doc}, #endif + {"verify_client_post_handshake", (PyCFunction)_ssl__SSLSocket_verify_client_post_handshake_impl, METH_NOARGS, _ssl__SSLSocket_verify_client_post_handshake__doc__}, {NULL, NULL} }; @@ -2308,6 +2364,13 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) } #endif +#ifdef TLS1_3_VERSION + { + self->post_handshake_auth = 0; + SSL_CTX_set_post_handshake_auth(self->ctx, self->post_handshake_auth); + } +#endif + return (PyObject *)self; } @@ -2527,6 +2590,8 @@ static int set_verify_mode(PySSLContext *self, PyObject *arg, void *c) { int n, mode; + int (*verify_cb)(int, X509_STORE_CTX *) = NULL; + if (!PyArg_Parse(arg, "i", &n)) return -1; if (n == PY_SSL_CERT_NONE) @@ -2540,13 +2605,20 @@ set_verify_mode(PySSLContext *self, PyObject *arg, void *c) "invalid value for verify_mode"); return -1; } + if (mode == SSL_VERIFY_NONE && self->check_hostname) { PyErr_SetString(PyExc_ValueError, "Cannot set verify_mode to CERT_NONE when " "check_hostname is enabled."); return -1; } - SSL_CTX_set_verify(self->ctx, mode, NULL); + + /* bpo-37428: newPySSLSocket() sets SSL_VERIFY_POST_HANDSHAKE flag for + * server sockets and SSL_set_post_handshake_auth() for client. */ + + /* keep current verify cb */ + verify_cb = SSL_CTX_get_verify_callback(self->ctx); + SSL_CTX_set_verify(self->ctx, mode, verify_cb); return 0; } @@ -2651,6 +2723,36 @@ set_check_hostname(PySSLContext *self, PyObject *arg, void *c) return 0; } +static PyObject * +get_post_handshake_auth(PySSLContext *self, void *c) { +#if TLS1_3_VERSION + return PyBool_FromLong(self->post_handshake_auth); +#else + Py_RETURN_NONE; +#endif +} + +#if TLS1_3_VERSION +static int +set_post_handshake_auth(PySSLContext *self, PyObject *arg, void *c) { + if (arg == NULL) { + PyErr_SetString(PyExc_AttributeError, "cannot delete attribute"); + return -1; + } + int pha = PyObject_IsTrue(arg); + + if (pha == -1) { + return -1; + } + self->post_handshake_auth = pha; + + /* bpo-37428: newPySSLSocket() sets SSL_VERIFY_POST_HANDSHAKE flag for + * server sockets and SSL_set_post_handshake_auth() for client. */ + + return 0; +} +#endif + typedef struct { PyThreadState *thread_state; @@ -3506,6 +3608,13 @@ static PyGetSetDef context_getsetlist[] = { #endif {"verify_mode", (getter) get_verify_mode, (setter) set_verify_mode, NULL}, + {"post_handshake_auth", (getter) get_post_handshake_auth, +#ifdef TLS1_3_VERSION + (setter) set_post_handshake_auth, +#else + NULL, +#endif + }, {NULL}, /* sentinel */ }; @@ -4532,3 +4641,4 @@ init_ssl(void) if (r == NULL || PyModule_AddObject(m, "_OPENSSL_API_VERSION", r)) return; } + diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c index 4d5a8f6f01703d..b145daf5c88b6d 100644 --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -142,7 +142,8 @@ sendall(data[, flags]) -- send all data\n\ send(data[, flags]) -- send data, may not send all of it\n\ sendto(data[, flags], addr) -- send data to a given address\n\ setblocking(0 | 1) -- set or clear the blocking I/O flag\n\ -setsockopt(level, optname, value) -- set socket options\n\ +getblocking() -- return True if socket is blocking, False if non-blocking\n\ +setsockopt(level, optname, value[, optlen]) -- set socket options\n\ settimeout(None | float) -- set or clear the timeout\n\ shutdown(how) -- shut down traffic in one or both directions\n\ \n\ @@ -1814,6 +1815,68 @@ Set the socket to blocking (flag is true) or non-blocking (false).\n\ setblocking(True) is equivalent to settimeout(None);\n\ setblocking(False) is equivalent to settimeout(0.0)."); +/* s.getblocking() method. + Returns True if socket is in blocking mode, + False if it is in non-blocking mode. +*/ +static PyObject * +sock_getblocking(PySocketSockObject *s) +{ + if (s->sock_timeout) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +PyDoc_STRVAR(getblocking_doc, +"getblocking()\n\ +\n\ +Returns True if socket is in blocking mode, or False if it\n\ +is in non-blocking mode."); + +static int +socket_parse_timeout(_PyTime_t *timeout, PyObject *timeout_obj) +{ +#ifdef MS_WINDOWS + struct timeval tv; +#endif +#ifndef HAVE_POLL + _PyTime_t ms; +#endif + int overflow = 0; + + if (timeout_obj == Py_None) { + *timeout = _PyTime_FromSeconds(-1); + return 0; + } + + if (_PyTime_FromSecondsObject(timeout, + timeout_obj, _PyTime_ROUND_TIMEOUT) < 0) + return -1; + + if (*timeout < 0) { + PyErr_SetString(PyExc_ValueError, "Timeout value out of range"); + return -1; + } + +#ifdef MS_WINDOWS + overflow |= (_PyTime_AsTimeval(*timeout, &tv, _PyTime_ROUND_TIMEOUT) < 0); +#endif +#ifndef HAVE_POLL + ms = _PyTime_AsMilliseconds(*timeout, _PyTime_ROUND_TIMEOUT); + overflow |= (ms > INT_MAX); +#endif + if (overflow) { + PyErr_SetString(PyExc_OverflowError, + "timeout doesn't fit into C timeval"); + return -1; + } + + return 0; +} + /* s.settimeout(timeout) method. Argument: None -- no timeout, blocking mode; same as setblocking(True) 0.0 -- non-blocking mode; same as setblocking(False) @@ -1837,11 +1900,34 @@ sock_settimeout(PySocketSockObject *s, PyObject *arg) } } - s->sock_timeout = timeout; - internal_setblocking(s, timeout < 0.0); + s->sock_timeout = timeout; - Py_INCREF(Py_None); - return Py_None; + int block = timeout < 0; + /* Blocking mode for a Python socket object means that operations + like :meth:`recv` or :meth:`sendall` will block the execution of + the current thread until they are complete or aborted with a + `socket.timeout` or `socket.error` errors. When timeout is `None`, + the underlying FD is in a blocking mode. When timeout is a positive + number, the FD is in a non-blocking mode, and socket ops are + implemented with a `select()` call. + + When timeout is 0.0, the FD is in a non-blocking mode. + + This table summarizes all states in which the socket object and + its underlying FD can be: + + ==================== ===================== ============== + `gettimeout()` `getblocking()` FD + ==================== ===================== ============== + ``None`` ``True`` blocking + ``0.0`` ``False`` non-blocking + ``> 0`` ``True`` non-blocking + */ + + if (internal_setblocking(s, block) == -1) { + return NULL; + } + Py_RETURN_NONE; } PyDoc_STRVAR(settimeout_doc, @@ -3122,6 +3208,8 @@ static PyMethodDef sock_methods[] = { sendto_doc}, {"setblocking", (PyCFunction)sock_setblocking, METH_O, setblocking_doc}, + {"getblocking", (PyCFunction)sock_getblocking, METH_NOARGS, + getblocking_doc}, {"settimeout", (PyCFunction)sock_settimeout, METH_O, settimeout_doc}, {"gettimeout", (PyCFunction)sock_gettimeout, METH_NOARGS, diff --git a/Python/pytime.c b/Python/pytime.c new file mode 100644 index 00000000000000..109d52692ce486 --- /dev/null +++ b/Python/pytime.c @@ -0,0 +1,1114 @@ +#include "Python.h" +#ifdef MS_WINDOWS +#include /* struct timeval */ +#endif + +#if defined(__APPLE__) +#include /* mach_absolute_time(), mach_timebase_info() */ +#endif + +#define _PyTime_check_mul_overflow(a, b) \ + (assert(b > 0), \ + (_PyTime_t)(a) < _PyTime_MIN / (_PyTime_t)(b) \ + || _PyTime_MAX / (_PyTime_t)(b) < (_PyTime_t)(a)) + +/* To millisecond (10^-3) */ +#define SEC_TO_MS 1000 + +/* To microseconds (10^-6) */ +#define MS_TO_US 1000 +#define SEC_TO_US (SEC_TO_MS * MS_TO_US) + +/* To nanoseconds (10^-9) */ +#define US_TO_NS 1000 +#define MS_TO_NS (MS_TO_US * US_TO_NS) +#define SEC_TO_NS (SEC_TO_MS * MS_TO_NS) + +/* Conversion from nanoseconds */ +#define NS_TO_MS (1000 * 1000) +#define NS_TO_US (1000) + +static void +error_time_t_overflow(void) +{ + PyErr_SetString(PyExc_OverflowError, + "timestamp out of range for platform time_t"); +} + +static void +_PyTime_overflow(void) +{ + PyErr_SetString(PyExc_OverflowError, + "timestamp too large to convert to C _PyTime_t"); +} + + +_PyTime_t +_PyTime_MulDiv(_PyTime_t ticks, _PyTime_t mul, _PyTime_t div) +{ + _PyTime_t intpart, remaining; + /* Compute (ticks * mul / div) in two parts to prevent integer overflow: + compute integer part, and then the remaining part. + + (ticks * mul) / div == (ticks / div) * mul + (ticks % div) * mul / div + + The caller must ensure that "(div - 1) * mul" cannot overflow. */ + intpart = ticks / div; + ticks %= div; + remaining = ticks * mul; + remaining /= div; + return intpart * mul + remaining; +} + + +time_t +_PyLong_AsTime_t(PyObject *obj) +{ +#if SIZEOF_TIME_T == SIZEOF_LONG_LONG + long long val; + val = PyLong_AsLongLong(obj); +#else + long val; + Py_BUILD_ASSERT(sizeof(time_t) <= sizeof(long)); + val = PyLong_AsLong(obj); +#endif + if (val == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + error_time_t_overflow(); + } + return -1; + } + return (time_t)val; +} + +PyObject * +_PyLong_FromTime_t(time_t t) +{ +#if SIZEOF_TIME_T == SIZEOF_LONG_LONG + return PyLong_FromLongLong((long long)t); +#else + Py_BUILD_ASSERT(sizeof(time_t) <= sizeof(long)); + return PyLong_FromLong((long)t); +#endif +} + +/* Round to nearest with ties going to nearest even integer + (_PyTime_ROUND_HALF_EVEN) */ +static double +_PyTime_RoundHalfEven(double x) +{ + double rounded = round(x); + if (fabs(x-rounded) == 0.5) { + /* halfway case: round to even */ + rounded = 2.0*round(x/2.0); + } + return rounded; +} + +static double +_PyTime_Round(double x, _PyTime_round_t round) +{ + /* volatile avoids optimization changing how numbers are rounded */ + volatile double d; + + d = x; + if (round == _PyTime_ROUND_HALF_EVEN) { + d = _PyTime_RoundHalfEven(d); + } + else if (round == _PyTime_ROUND_CEILING) { + d = ceil(d); + } + else if (round == _PyTime_ROUND_FLOOR) { + d = floor(d); + } + else { + assert(round == _PyTime_ROUND_UP); + d = (d >= 0.0) ? ceil(d) : floor(d); + } + return d; +} + +static int +_PyTime_DoubleToDenominator(double d, time_t *sec, long *numerator, + long idenominator, _PyTime_round_t round) +{ + double denominator = (double)idenominator; + double intpart; + /* volatile avoids optimization changing how numbers are rounded */ + volatile double floatpart; + + floatpart = modf(d, &intpart); + + floatpart *= denominator; + floatpart = _PyTime_Round(floatpart, round); + if (floatpart >= denominator) { + floatpart -= denominator; + intpart += 1.0; + } + else if (floatpart < 0) { + floatpart += denominator; + intpart -= 1.0; + } + assert(0.0 <= floatpart && floatpart < denominator); + + if (!_Py_InIntegralTypeRange(time_t, intpart)) { + error_time_t_overflow(); + return -1; + } + *sec = (time_t)intpart; + *numerator = (long)floatpart; + assert(0 <= *numerator && *numerator < idenominator); + return 0; +} + +static int +_PyTime_ObjectToDenominator(PyObject *obj, time_t *sec, long *numerator, + long denominator, _PyTime_round_t round) +{ + assert(denominator >= 1); + + if (PyFloat_Check(obj)) { + double d = PyFloat_AsDouble(obj); + if (Py_IS_NAN(d)) { + *numerator = 0; + PyErr_SetString(PyExc_ValueError, "Invalid value NaN (not a number)"); + return -1; + } + return _PyTime_DoubleToDenominator(d, sec, numerator, + denominator, round); + } + else { + *sec = _PyLong_AsTime_t(obj); + *numerator = 0; + if (*sec == (time_t)-1 && PyErr_Occurred()) { + return -1; + } + return 0; + } +} + +int +_PyTime_ObjectToTime_t(PyObject *obj, time_t *sec, _PyTime_round_t round) +{ + if (PyFloat_Check(obj)) { + double intpart; + /* volatile avoids optimization changing how numbers are rounded */ + volatile double d; + + d = PyFloat_AsDouble(obj); + if (Py_IS_NAN(d)) { + PyErr_SetString(PyExc_ValueError, "Invalid value NaN (not a number)"); + return -1; + } + + d = _PyTime_Round(d, round); + (void)modf(d, &intpart); + + if (!_Py_InIntegralTypeRange(time_t, intpart)) { + error_time_t_overflow(); + return -1; + } + *sec = (time_t)intpart; + return 0; + } + else { + *sec = _PyLong_AsTime_t(obj); + if (*sec == (time_t)-1 && PyErr_Occurred()) { + return -1; + } + return 0; + } +} + +int +_PyTime_ObjectToTimespec(PyObject *obj, time_t *sec, long *nsec, + _PyTime_round_t round) +{ + return _PyTime_ObjectToDenominator(obj, sec, nsec, SEC_TO_NS, round); +} + +int +_PyTime_ObjectToTimeval(PyObject *obj, time_t *sec, long *usec, + _PyTime_round_t round) +{ + return _PyTime_ObjectToDenominator(obj, sec, usec, SEC_TO_US, round); +} + +_PyTime_t +_PyTime_FromSeconds(int seconds) +{ + _PyTime_t t; + /* ensure that integer overflow cannot happen, int type should have 32 + bits, whereas _PyTime_t type has at least 64 bits (SEC_TO_MS takes 30 + bits). */ + Py_BUILD_ASSERT(INT_MAX <= _PyTime_MAX / SEC_TO_NS); + Py_BUILD_ASSERT(INT_MIN >= _PyTime_MIN / SEC_TO_NS); + + t = (_PyTime_t)seconds; + assert((t >= 0 && t <= _PyTime_MAX / SEC_TO_NS) + || (t < 0 && t >= _PyTime_MIN / SEC_TO_NS)); + t *= SEC_TO_NS; + return t; +} + +_PyTime_t +_PyTime_FromNanoseconds(_PyTime_t ns) +{ + /* _PyTime_t already uses nanosecond resolution, no conversion needed */ + return ns; +} + +int +_PyTime_FromNanosecondsObject(_PyTime_t *tp, PyObject *obj) +{ + long long nsec; + _PyTime_t t; + + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expect int, got %s", + Py_TYPE(obj)->tp_name); + return -1; + } + + Py_BUILD_ASSERT(sizeof(long long) == sizeof(_PyTime_t)); + nsec = PyLong_AsLongLong(obj); + if (nsec == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + _PyTime_overflow(); + } + return -1; + } + + /* _PyTime_t already uses nanosecond resolution, no conversion needed */ + t = (_PyTime_t)nsec; + *tp = t; + return 0; +} + +#ifdef HAVE_CLOCK_GETTIME +static int +pytime_fromtimespec(_PyTime_t *tp, struct timespec *ts, int raise) +{ + _PyTime_t t, nsec; + int res = 0; + + Py_BUILD_ASSERT(sizeof(ts->tv_sec) <= sizeof(_PyTime_t)); + t = (_PyTime_t)ts->tv_sec; + + if (_PyTime_check_mul_overflow(t, SEC_TO_NS)) { + if (raise) { + _PyTime_overflow(); + } + res = -1; + t = (t > 0) ? _PyTime_MAX : _PyTime_MIN; + } + else { + t = t * SEC_TO_NS; + } + + nsec = ts->tv_nsec; + /* The following test is written for positive only nsec */ + assert(nsec >= 0); + if (t > _PyTime_MAX - nsec) { + if (raise) { + _PyTime_overflow(); + } + res = -1; + t = _PyTime_MAX; + } + else { + t += nsec; + } + + *tp = t; + return res; +} + +int +_PyTime_FromTimespec(_PyTime_t *tp, struct timespec *ts) +{ + return pytime_fromtimespec(tp, ts, 1); +} +#endif + +#if !defined(MS_WINDOWS) +static int +pytime_fromtimeval(_PyTime_t *tp, struct timeval *tv, int raise) +{ + _PyTime_t t, usec; + int res = 0; + + Py_BUILD_ASSERT(sizeof(tv->tv_sec) <= sizeof(_PyTime_t)); + t = (_PyTime_t)tv->tv_sec; + + if (_PyTime_check_mul_overflow(t, SEC_TO_NS)) { + if (raise) { + _PyTime_overflow(); + } + res = -1; + t = (t > 0) ? _PyTime_MAX : _PyTime_MIN; + } + else { + t = t * SEC_TO_NS; + } + + usec = (_PyTime_t)tv->tv_usec * US_TO_NS; + /* The following test is written for positive only usec */ + assert(usec >= 0); + if (t > _PyTime_MAX - usec) { + if (raise) { + _PyTime_overflow(); + } + res = -1; + t = _PyTime_MAX; + } + else { + t += usec; + } + + *tp = t; + return res; +} + +int +_PyTime_FromTimeval(_PyTime_t *tp, struct timeval *tv) +{ + return pytime_fromtimeval(tp, tv, 1); +} +#endif + +static int +_PyTime_FromDouble(_PyTime_t *t, double value, _PyTime_round_t round, + long unit_to_ns) +{ + /* volatile avoids optimization changing how numbers are rounded */ + volatile double d; + + /* convert to a number of nanoseconds */ + d = value; + d *= (double)unit_to_ns; + d = _PyTime_Round(d, round); + + if (!_Py_InIntegralTypeRange(_PyTime_t, d)) { + _PyTime_overflow(); + return -1; + } + *t = (_PyTime_t)d; + return 0; +} + +static int +_PyTime_FromObject(_PyTime_t *t, PyObject *obj, _PyTime_round_t round, + long unit_to_ns) +{ + if (PyFloat_Check(obj)) { + double d; + d = PyFloat_AsDouble(obj); + if (Py_IS_NAN(d)) { + PyErr_SetString(PyExc_ValueError, "Invalid value NaN (not a number)"); + return -1; + } + return _PyTime_FromDouble(t, d, round, unit_to_ns); + } + else { + long long sec; + Py_BUILD_ASSERT(sizeof(long long) <= sizeof(_PyTime_t)); + + sec = PyLong_AsLongLong(obj); + if (sec == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + _PyTime_overflow(); + } + return -1; + } + + if (_PyTime_check_mul_overflow(sec, unit_to_ns)) { + _PyTime_overflow(); + return -1; + } + *t = sec * unit_to_ns; + return 0; + } +} + +int +_PyTime_FromSecondsObject(_PyTime_t *t, PyObject *obj, _PyTime_round_t round) +{ + return _PyTime_FromObject(t, obj, round, SEC_TO_NS); +} + +int +_PyTime_FromMillisecondsObject(_PyTime_t *t, PyObject *obj, _PyTime_round_t round) +{ + return _PyTime_FromObject(t, obj, round, MS_TO_NS); +} + +double +_PyTime_AsSecondsDouble(_PyTime_t t) +{ + /* volatile avoids optimization changing how numbers are rounded */ + volatile double d; + + if (t % SEC_TO_NS == 0) { + _PyTime_t secs; + /* Divide using integers to avoid rounding issues on the integer part. + 1e-9 cannot be stored exactly in IEEE 64-bit. */ + secs = t / SEC_TO_NS; + d = (double)secs; + } + else { + d = (double)t; + d /= 1e9; + } + return d; +} + +PyObject * +_PyTime_AsNanosecondsObject(_PyTime_t t) +{ + Py_BUILD_ASSERT(sizeof(long long) >= sizeof(_PyTime_t)); + return PyLong_FromLongLong((long long)t); +} + +static _PyTime_t +_PyTime_Divide(const _PyTime_t t, const _PyTime_t k, + const _PyTime_round_t round) +{ + assert(k > 1); + if (round == _PyTime_ROUND_HALF_EVEN) { + _PyTime_t x, r, abs_r; + x = t / k; + r = t % k; + abs_r = Py_ABS(r); + if (abs_r > k / 2 || (abs_r == k / 2 && (Py_ABS(x) & 1))) { + if (t >= 0) { + x++; + } + else { + x--; + } + } + return x; + } + else if (round == _PyTime_ROUND_CEILING) { + if (t >= 0) { + return (t + k - 1) / k; + } + else { + return t / k; + } + } + else if (round == _PyTime_ROUND_FLOOR){ + if (t >= 0) { + return t / k; + } + else { + return (t - (k - 1)) / k; + } + } + else { + assert(round == _PyTime_ROUND_UP); + if (t >= 0) { + return (t + k - 1) / k; + } + else { + return (t - (k - 1)) / k; + } + } +} + +_PyTime_t +_PyTime_AsMilliseconds(_PyTime_t t, _PyTime_round_t round) +{ + return _PyTime_Divide(t, NS_TO_MS, round); +} + +_PyTime_t +_PyTime_AsMicroseconds(_PyTime_t t, _PyTime_round_t round) +{ + return _PyTime_Divide(t, NS_TO_US, round); +} + +static int +_PyTime_AsTimeval_impl(_PyTime_t t, _PyTime_t *p_secs, int *p_us, + _PyTime_round_t round) +{ + _PyTime_t secs, ns; + int usec; + int res = 0; + + secs = t / SEC_TO_NS; + ns = t % SEC_TO_NS; + + usec = (int)_PyTime_Divide(ns, US_TO_NS, round); + if (usec < 0) { + usec += SEC_TO_US; + if (secs != _PyTime_MIN) { + secs -= 1; + } + else { + res = -1; + } + } + else if (usec >= SEC_TO_US) { + usec -= SEC_TO_US; + if (secs != _PyTime_MAX) { + secs += 1; + } + else { + res = -1; + } + } + assert(0 <= usec && usec < SEC_TO_US); + + *p_secs = secs; + *p_us = usec; + + return res; +} + +static int +_PyTime_AsTimevalStruct_impl(_PyTime_t t, struct timeval *tv, + _PyTime_round_t round, int raise) +{ + _PyTime_t secs, secs2; + int us; + int res; + + res = _PyTime_AsTimeval_impl(t, &secs, &us, round); + +#ifdef MS_WINDOWS + tv->tv_sec = (long)secs; +#else + tv->tv_sec = secs; +#endif + tv->tv_usec = us; + + secs2 = (_PyTime_t)tv->tv_sec; + if (res < 0 || secs2 != secs) { + if (raise) { + error_time_t_overflow(); + } + return -1; + } + return 0; +} + +int +_PyTime_AsTimeval(_PyTime_t t, struct timeval *tv, _PyTime_round_t round) +{ + return _PyTime_AsTimevalStruct_impl(t, tv, round, 1); +} + +int +_PyTime_AsTimeval_noraise(_PyTime_t t, struct timeval *tv, _PyTime_round_t round) +{ + return _PyTime_AsTimevalStruct_impl(t, tv, round, 0); +} + +int +_PyTime_AsTimevalTime_t(_PyTime_t t, time_t *p_secs, int *us, + _PyTime_round_t round) +{ + _PyTime_t secs; + int res; + + res = _PyTime_AsTimeval_impl(t, &secs, us, round); + + *p_secs = secs; + + if (res < 0 || (_PyTime_t)*p_secs != secs) { + error_time_t_overflow(); + return -1; + } + return 0; +} + + +#if defined(HAVE_CLOCK_GETTIME) || defined(HAVE_KQUEUE) +int +_PyTime_AsTimespec(_PyTime_t t, struct timespec *ts) +{ + _PyTime_t secs, nsec; + + secs = t / SEC_TO_NS; + nsec = t % SEC_TO_NS; + if (nsec < 0) { + nsec += SEC_TO_NS; + secs -= 1; + } + ts->tv_sec = (time_t)secs; + assert(0 <= nsec && nsec < SEC_TO_NS); + ts->tv_nsec = nsec; + + if ((_PyTime_t)ts->tv_sec != secs) { + error_time_t_overflow(); + return -1; + } + return 0; +} +#endif + +static int +pygettimeofday(_PyTime_t *tp, _Py_clock_info_t *info, int raise) +{ +#ifdef MS_WINDOWS + FILETIME system_time; + ULARGE_INTEGER large; + + assert(info == NULL || raise); + + GetSystemTimeAsFileTime(&system_time); + large.u.LowPart = system_time.dwLowDateTime; + large.u.HighPart = system_time.dwHighDateTime; + /* 11,644,473,600,000,000,000: number of nanoseconds between + the 1st january 1601 and the 1st january 1970 (369 years + 89 leap + days). */ + *tp = large.QuadPart * 100 - 11644473600000000000; + if (info) { + DWORD timeAdjustment, timeIncrement; + BOOL isTimeAdjustmentDisabled, ok; + + info->implementation = "GetSystemTimeAsFileTime()"; + info->monotonic = 0; + ok = GetSystemTimeAdjustment(&timeAdjustment, &timeIncrement, + &isTimeAdjustmentDisabled); + if (!ok) { + PyErr_SetFromWindowsErr(0); + return -1; + } + info->resolution = timeIncrement * 1e-7; + info->adjustable = 1; + } + +#else /* MS_WINDOWS */ + int err; +#ifdef HAVE_CLOCK_GETTIME + struct timespec ts; +#else + struct timeval tv; +#endif + + assert(info == NULL || raise); + +#ifdef HAVE_CLOCK_GETTIME + err = clock_gettime(CLOCK_REALTIME, &ts); + if (err) { + if (raise) { + PyErr_SetFromErrno(PyExc_OSError); + } + return -1; + } + if (pytime_fromtimespec(tp, &ts, raise) < 0) { + return -1; + } + + if (info) { + struct timespec res; + info->implementation = "clock_gettime(CLOCK_REALTIME)"; + info->monotonic = 0; + info->adjustable = 1; + if (clock_getres(CLOCK_REALTIME, &res) == 0) { + info->resolution = res.tv_sec + res.tv_nsec * 1e-9; + } + else { + info->resolution = 1e-9; + } + } +#else /* HAVE_CLOCK_GETTIME */ + + /* test gettimeofday() */ +#ifdef GETTIMEOFDAY_NO_TZ + err = gettimeofday(&tv); +#else + err = gettimeofday(&tv, (struct timezone *)NULL); +#endif + if (err) { + if (raise) { + PyErr_SetFromErrno(PyExc_OSError); + } + return -1; + } + if (pytime_fromtimeval(tp, &tv, raise) < 0) { + return -1; + } + + if (info) { + info->implementation = "gettimeofday()"; + info->resolution = 1e-6; + info->monotonic = 0; + info->adjustable = 1; + } +#endif /* !HAVE_CLOCK_GETTIME */ +#endif /* !MS_WINDOWS */ + return 0; +} + +_PyTime_t +_PyTime_GetSystemClock(void) +{ + _PyTime_t t; + if (pygettimeofday(&t, NULL, 0) < 0) { + /* should not happen, _PyTime_Init() checked the clock at startup */ + Py_UNREACHABLE(); + } + return t; +} + +int +_PyTime_GetSystemClockWithInfo(_PyTime_t *t, _Py_clock_info_t *info) +{ + return pygettimeofday(t, info, 1); +} + +static int +pymonotonic(_PyTime_t *tp, _Py_clock_info_t *info, int raise) +{ +#if defined(MS_WINDOWS) + ULONGLONG ticks; + _PyTime_t t; + + assert(info == NULL || raise); + + ticks = GetTickCount64(); + Py_BUILD_ASSERT(sizeof(ticks) <= sizeof(_PyTime_t)); + t = (_PyTime_t)ticks; + + if (_PyTime_check_mul_overflow(t, MS_TO_NS)) { + if (raise) { + _PyTime_overflow(); + return -1; + } + /* Hello, time traveler! */ + Py_UNREACHABLE(); + } + *tp = t * MS_TO_NS; + + if (info) { + DWORD timeAdjustment, timeIncrement; + BOOL isTimeAdjustmentDisabled, ok; + info->implementation = "GetTickCount64()"; + info->monotonic = 1; + ok = GetSystemTimeAdjustment(&timeAdjustment, &timeIncrement, + &isTimeAdjustmentDisabled); + if (!ok) { + PyErr_SetFromWindowsErr(0); + return -1; + } + info->resolution = timeIncrement * 1e-7; + info->adjustable = 0; + } + +#elif defined(__APPLE__) + static mach_timebase_info_data_t timebase; + static uint64_t t0 = 0; + uint64_t ticks; + + if (timebase.denom == 0) { + /* According to the Technical Q&A QA1398, mach_timebase_info() cannot + fail: https://developer.apple.com/library/mac/#qa/qa1398/ */ + (void)mach_timebase_info(&timebase); + + /* Sanity check: should never occur in practice */ + if (timebase.numer < 1 || timebase.denom < 1) { + PyErr_SetString(PyExc_RuntimeError, + "invalid mach_timebase_info"); + return -1; + } + + /* Check that timebase.numer and timebase.denom can be casted to + _PyTime_t. In practice, timebase uses uint32_t, so casting cannot + overflow. At the end, only make sure that the type is uint32_t + (_PyTime_t is 64-bit long). */ + assert(sizeof(timebase.numer) < sizeof(_PyTime_t)); + assert(sizeof(timebase.denom) < sizeof(_PyTime_t)); + + /* Make sure that (ticks * timebase.numer) cannot overflow in + _PyTime_MulDiv(), with ticks < timebase.denom. + + Known time bases: + + * always (1, 1) on Intel + * (1000000000, 33333335) or (1000000000, 25000000) on PowerPC + + None of these time bases can overflow with 64-bit _PyTime_t, but + check for overflow, just in case. */ + if ((_PyTime_t)timebase.numer > _PyTime_MAX / (_PyTime_t)timebase.denom) { + PyErr_SetString(PyExc_OverflowError, + "mach_timebase_info is too large"); + return -1; + } + + t0 = mach_absolute_time(); + } + + if (info) { + info->implementation = "mach_absolute_time()"; + info->resolution = (double)timebase.numer / (double)timebase.denom * 1e-9; + info->monotonic = 1; + info->adjustable = 0; + } + + ticks = mach_absolute_time(); + /* Use a "time zero" to reduce precision loss when converting time + to floatting point number, as in time.monotonic(). */ + ticks -= t0; + *tp = _PyTime_MulDiv(ticks, + (_PyTime_t)timebase.numer, + (_PyTime_t)timebase.denom); + +#elif defined(__hpux) + hrtime_t time; + + time = gethrtime(); + if (time == -1) { + if (raise) { + PyErr_SetFromErrno(PyExc_OSError); + } + return -1; + } + + *tp = time; + + if (info) { + info->implementation = "gethrtime()"; + info->resolution = 1e-9; + info->monotonic = 1; + info->adjustable = 0; + } + +#else + struct timespec ts; +#ifdef CLOCK_HIGHRES + const clockid_t clk_id = CLOCK_HIGHRES; + const char *implementation = "clock_gettime(CLOCK_HIGHRES)"; +#else + const clockid_t clk_id = CLOCK_MONOTONIC; + const char *implementation = "clock_gettime(CLOCK_MONOTONIC)"; +#endif + + assert(info == NULL || raise); + + if (clock_gettime(clk_id, &ts) != 0) { + if (raise) { + PyErr_SetFromErrno(PyExc_OSError); + return -1; + } + return -1; + } + + if (info) { + struct timespec res; + info->monotonic = 1; + info->implementation = implementation; + info->adjustable = 0; + if (clock_getres(clk_id, &res) != 0) { + PyErr_SetFromErrno(PyExc_OSError); + return -1; + } + info->resolution = res.tv_sec + res.tv_nsec * 1e-9; + } + if (pytime_fromtimespec(tp, &ts, raise) < 0) { + return -1; + } +#endif + return 0; +} + +_PyTime_t +_PyTime_GetMonotonicClock(void) +{ + _PyTime_t t; + if (pymonotonic(&t, NULL, 0) < 0) { + /* should not happen, _PyTime_Init() checked that monotonic clock at + startup */ + Py_UNREACHABLE(); + } + return t; +} + +int +_PyTime_GetMonotonicClockWithInfo(_PyTime_t *tp, _Py_clock_info_t *info) +{ + return pymonotonic(tp, info, 1); +} + + +#ifdef MS_WINDOWS +static int +win_perf_counter(_PyTime_t *tp, _Py_clock_info_t *info) +{ + static LONGLONG frequency = 0; + static LONGLONG t0 = 0; + LARGE_INTEGER now; + LONGLONG ticksll; + _PyTime_t ticks; + + if (frequency == 0) { + LARGE_INTEGER freq; + if (!QueryPerformanceFrequency(&freq)) { + PyErr_SetFromWindowsErr(0); + return -1; + } + frequency = freq.QuadPart; + + /* Sanity check: should never occur in practice */ + if (frequency < 1) { + PyErr_SetString(PyExc_RuntimeError, + "invalid QueryPerformanceFrequency"); + return -1; + } + + /* Check that frequency can be casted to _PyTime_t. + + Make also sure that (ticks * SEC_TO_NS) cannot overflow in + _PyTime_MulDiv(), with ticks < frequency. + + Known QueryPerformanceFrequency() values: + + * 10,000,000 (10 MHz): 100 ns resolution + * 3,579,545 Hz (3.6 MHz): 279 ns resolution + + None of these frequencies can overflow with 64-bit _PyTime_t, but + check for overflow, just in case. */ + if (frequency > _PyTime_MAX + || frequency > (LONGLONG)_PyTime_MAX / (LONGLONG)SEC_TO_NS) { + PyErr_SetString(PyExc_OverflowError, + "QueryPerformanceFrequency is too large"); + return -1; + } + + QueryPerformanceCounter(&now); + t0 = now.QuadPart; + } + + if (info) { + info->implementation = "QueryPerformanceCounter()"; + info->resolution = 1.0 / (double)frequency; + info->monotonic = 1; + info->adjustable = 0; + } + + QueryPerformanceCounter(&now); + ticksll = now.QuadPart; + + /* Use a "time zero" to reduce precision loss when converting time + to floatting point number, as in time.perf_counter(). */ + ticksll -= t0; + + /* Make sure that casting LONGLONG to _PyTime_t cannot overflow, + both types are signed */ + Py_BUILD_ASSERT(sizeof(ticksll) <= sizeof(ticks)); + ticks = (_PyTime_t)ticksll; + + *tp = _PyTime_MulDiv(ticks, SEC_TO_NS, (_PyTime_t)frequency); + return 0; +} +#endif + + +int +_PyTime_GetPerfCounterWithInfo(_PyTime_t *t, _Py_clock_info_t *info) +{ +#ifdef MS_WINDOWS + return win_perf_counter(t, info); +#else + return _PyTime_GetMonotonicClockWithInfo(t, info); +#endif +} + + +_PyTime_t +_PyTime_GetPerfCounter(void) +{ + _PyTime_t t; + if (_PyTime_GetPerfCounterWithInfo(&t, NULL)) { + Py_UNREACHABLE(); + } + return t; +} + + +int +_PyTime_Init(void) +{ + /* check that time.time(), time.monotonic() and time.perf_counter() clocks + are working properly to not have to check for exceptions at runtime. If + a clock works once, it cannot fail in next calls. */ + _PyTime_t t; + if (_PyTime_GetSystemClockWithInfo(&t, NULL) < 0) { + return -1; + } + if (_PyTime_GetMonotonicClockWithInfo(&t, NULL) < 0) { + return -1; + } + if (_PyTime_GetPerfCounterWithInfo(&t, NULL) < 0) { + return -1; + } + return 0; +} + +int +_PyTime_localtime(time_t t, struct tm *tm) +{ +#ifdef MS_WINDOWS + int error; + + error = localtime_s(tm, &t); + if (error != 0) { + errno = error; + PyErr_SetFromErrno(PyExc_OSError); + return -1; + } + return 0; +#else /* !MS_WINDOWS */ + +#ifdef _AIX + /* bpo-34373: AIX does not return NULL if t is too small or too large */ + if (t < -2145916800 /* 1902-01-01 */ + || t > 2145916800 /* 2038-01-01 */) { + errno = EINVAL; + PyErr_SetString(PyExc_OverflowError, + "localtime argument out of range"); + return -1; + } +#endif + + errno = 0; + if (localtime_r(&t, tm) == NULL) { + if (errno == 0) { + errno = EINVAL; + } + PyErr_SetFromErrno(PyExc_OSError); + return -1; + } + return 0; +#endif /* MS_WINDOWS */ +} + +int +_PyTime_gmtime(time_t t, struct tm *tm) +{ +#ifdef MS_WINDOWS + int error; + + error = gmtime_s(tm, &t); + if (error != 0) { + errno = error; + PyErr_SetFromErrno(PyExc_OSError); + return -1; + } + return 0; +#else /* !MS_WINDOWS */ + if (gmtime_r(&t, tm) == NULL) { +#ifdef EINVAL + if (errno == 0) { + errno = EINVAL; + } +#endif + PyErr_SetFromErrno(PyExc_OSError); + return -1; + } + return 0; +#endif /* MS_WINDOWS */ +}