diff --git a/pytest_socket.py b/pytest_socket.py index 296f730..0c8dd63 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -74,10 +74,12 @@ def disable_socket(): """ disable socket.socket to disable the Internet. useful in testing. """ - def guarded(*args, **kwargs): - raise SocketBlockedError() + class GuardedSocket(socket.socket): + """ socket guard to disable socket creation (from pytest-socket) """ + def __new__(cls, *args, **kwargs): + raise SocketBlockedError() - socket.socket = guarded + socket.socket = GuardedSocket def enable_socket(): diff --git a/tests/test_socket.py b/tests/test_socket.py index 9c13d4d..acccfce 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -210,3 +210,21 @@ def test_socket2(): """) result = testdir.runpytest("--verbose", "--disable-socket") result.assert_outcomes(1, 0, 2) + + +def test_socket_subclass_is_still_blocked(testdir): + testdir.makepyfile(""" + import pytest + import pytest_socket + import socket + + @pytest.mark.disable_socket + def test_subclass_is_still_blocked(): + + class MySocket(socket.socket): + pass + + MySocket(socket.AF_INET, socket.SOCK_STREAM) + """) + result = testdir.runpytest("--verbose") + assert_socket_blocked(result)