Skip to content

Commit b1026d2

Browse files
committed
gh-135056: Use response_headers only in SimpleHTTPRequestHandler
1 parent a3243fe commit b1026d2

File tree

4 files changed

+54
-51
lines changed

4 files changed

+54
-51
lines changed

Doc/library/http.server.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,9 @@ instantiation, of which this module provides three different variants:
375375
The *directory* parameter accepts a :term:`path-like object`.
376376

377377
.. versionchanged:: next
378-
The *response_headers* parameter accepts an optional dictionary of
379-
additional HTTP headers to add to each response.
378+
Added *response_headers*, which accepts an optional dictionary of
379+
additional HTTP headers to add to each successful HTTP status 200
380+
response. All other status code responses will not include these headers.
380381

381382
A lot of the work, such as parsing the request, is done by the base class
382383
:class:`BaseHTTPRequestHandler`. This class implements the :func:`do_GET`

Lib/http/server.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -117,24 +117,13 @@ class HTTPServer(socketserver.TCPServer):
117117
allow_reuse_address = True # Seems to make sense in testing environment
118118
allow_reuse_port = True
119119

120-
def __init__(self, *args, response_headers=None, **kwargs):
121-
self.response_headers = response_headers
122-
super().__init__(*args, **kwargs)
123-
124120
def server_bind(self):
125121
"""Override server_bind to store the server name."""
126122
socketserver.TCPServer.server_bind(self)
127123
host, port = self.server_address[:2]
128124
self.server_name = socket.getfqdn(host)
129125
self.server_port = port
130126

131-
def finish_request(self, request, client_address):
132-
"""Finish one request by instantiating RequestHandlerClass."""
133-
args = (request, client_address, self)
134-
kwargs = {}
135-
if hasattr(self, 'response_headers'):
136-
kwargs['response_headers'] = self.response_headers
137-
self.RequestHandlerClass(request, client_address, self, **kwargs)
138127

139128
class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
140129
daemon_threads = True
@@ -143,7 +132,7 @@ class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
143132
class HTTPSServer(HTTPServer):
144133
def __init__(self, server_address, RequestHandlerClass,
145134
bind_and_activate=True, *, certfile, keyfile=None,
146-
password=None, alpn_protocols=None, **http_server_kwargs):
135+
password=None, alpn_protocols=None):
147136
try:
148137
import ssl
149138
except ImportError:
@@ -161,8 +150,7 @@ def __init__(self, server_address, RequestHandlerClass,
161150

162151
super().__init__(server_address,
163152
RequestHandlerClass,
164-
bind_and_activate,
165-
**http_server_kwargs)
153+
bind_and_activate)
166154

167155
def server_activate(self):
168156
"""Wrap the socket in SSLSocket."""
@@ -726,6 +714,13 @@ def do_HEAD(self):
726714
if f:
727715
f.close()
728716

717+
def send_custom_response_headers(self):
718+
"""Send the headers stored in self.response_headers"""
719+
# User specified response_headers
720+
if self.response_headers is not None:
721+
for header, value in self.response_headers.items():
722+
self.send_header(header, value)
723+
729724
def send_head(self):
730725
"""Common code for GET and HEAD commands.
731726
@@ -749,10 +744,6 @@ def send_head(self):
749744
new_url = urllib.parse.urlunsplit(new_parts)
750745
self.send_header("Location", new_url)
751746
self.send_header("Content-Length", "0")
752-
# User specified response_headers
753-
if self.response_headers is not None:
754-
for header, value in self.response_headers.items():
755-
self.send_header(header, value)
756747
self.end_headers()
757748
return None
758749
for index in self.index_pages:
@@ -812,9 +803,7 @@ def send_head(self):
812803
self.send_header("Content-Length", str(fs[6]))
813804
self.send_header("Last-Modified",
814805
self.date_time_string(fs.st_mtime))
815-
if self.response_headers is not None:
816-
for header, value in self.response_headers.items():
817-
self.send_header(header, value)
806+
self.send_custom_response_headers()
818807
self.end_headers()
819808
return f
820809
except:
@@ -879,6 +868,7 @@ def list_directory(self, path):
879868
self.send_response(HTTPStatus.OK)
880869
self.send_header("Content-type", "text/html; charset=%s" % enc)
881870
self.send_header("Content-Length", str(len(encoded)))
871+
self.send_custom_response_headers()
882872
self.end_headers()
883873
return f
884874

@@ -990,8 +980,7 @@ def _get_best_family(*address):
990980
def test(HandlerClass=BaseHTTPRequestHandler,
991981
ServerClass=ThreadingHTTPServer,
992982
protocol="HTTP/1.0", port=8000, bind=None,
993-
tls_cert=None, tls_key=None, tls_password=None,
994-
response_headers=None):
983+
tls_cert=None, tls_key=None, tls_password=None):
995984
"""Test the HTTP request handler class.
996985
997986
This runs an HTTP server on port 8000 (or the port argument).
@@ -1002,10 +991,9 @@ def test(HandlerClass=BaseHTTPRequestHandler,
1002991

1003992
if tls_cert:
1004993
server = ServerClass(addr, HandlerClass, certfile=tls_cert,
1005-
keyfile=tls_key, password=tls_password,
1006-
response_headers=response_headers)
994+
keyfile=tls_key, password=tls_password)
1007995
else:
1008-
server = ServerClass(addr, HandlerClass, response_headers=response_headers)
996+
server = ServerClass(addr, HandlerClass)
1009997

1010998
with server as httpd:
1011999
host, port = httpd.socket.getsockname()[:2]
@@ -1067,6 +1055,10 @@ def _main(args=None):
10671055
except OSError as e:
10681056
parser.error(f"Failed to read TLS password file: {e}")
10691057

1058+
response_headers = {}
1059+
for header, value in args.header or []:
1060+
response_headers[header] = value
1061+
10701062
# ensure dual-stack is not disabled; ref #38907
10711063
class DualStackServerMixin:
10721064

@@ -1080,18 +1072,14 @@ def server_bind(self):
10801072
def finish_request(self, request, client_address):
10811073
self.RequestHandlerClass(request, client_address, self,
10821074
directory=args.directory,
1083-
response_headers=self.response_headers)
1075+
response_headers=response_headers)
10841076

10851077
class HTTPDualStackServer(DualStackServerMixin, ThreadingHTTPServer):
10861078
pass
10871079
class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
10881080
pass
10891081

10901082
ServerClass = HTTPSDualStackServer if args.tls_cert else HTTPDualStackServer
1091-
response_headers = {}
1092-
for header, value in args.header or []:
1093-
response_headers[header] = value
1094-
10951083

10961084
test(
10971085
HandlerClass=SimpleHTTPRequestHandler,
@@ -1102,7 +1090,6 @@ class HTTPSDualStackServer(DualStackServerMixin, ThreadingHTTPSServer):
11021090
tls_cert=args.tls_cert,
11031091
tls_key=args.tls_key,
11041092
tls_password=tls_key_password,
1105-
response_headers=response_headers or None
11061093
)
11071094

11081095

Lib/socketserver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ class BaseRequestHandler:
757757
758758
"""
759759

760-
def __init__(self, request, client_address, server, **kwargs):
760+
def __init__(self, request, client_address, server):
761761
self.request = request
762762
self.client_address = client_address
763763
self.server = server

Lib/test/test_httpservers.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,24 +81,21 @@ def test_https_server_raises_runtime_error(self):
8181

8282

8383
class TestServerThread(threading.Thread):
84-
def __init__(self, test_object, request_handler, tls=None, server_kwargs=None):
84+
def __init__(self, test_object, request_handler, tls=None):
8585
threading.Thread.__init__(self)
8686
self.request_handler = request_handler
8787
self.test_object = test_object
8888
self.tls = tls
89-
self.server_kwargs = server_kwargs or {}
9089

9190
def run(self):
9291
if self.tls:
9392
certfile, keyfile, password = self.tls
9493
self.server = create_https_server(
9594
certfile, keyfile, password,
9695
request_handler=self.request_handler,
97-
**self.server_kwargs
9896
)
9997
else:
100-
self.server = HTTPServer(('localhost', 0), self.request_handler,
101-
**self.server_kwargs)
98+
self.server = HTTPServer(('localhost', 0), self.request_handler)
10299
self.test_object.HOST, self.test_object.PORT = self.server.socket.getsockname()
103100
self.test_object.server_started.set()
104101
self.test_object = None
@@ -116,14 +113,12 @@ class BaseTestCase(unittest.TestCase):
116113

117114
# Optional tuple (certfile, keyfile, password) to use for HTTPS servers.
118115
tls = None
119-
server_kwargs = None
120116

121117
def setUp(self):
122118
self._threads = threading_helper.threading_setup()
123119
os.environ = os_helper.EnvironmentVarGuard()
124120
self.server_started = threading.Event()
125-
self.thread = TestServerThread(self, self.request_handler, self.tls,
126-
self.server_kwargs)
121+
self.thread = TestServerThread(self, self.request_handler, self.tls)
127122
self.thread.start()
128123
self.server_started.wait()
129124

@@ -470,8 +465,14 @@ def test_err(self):
470465
self.assertEndsWith(lines[1], '"ERROR / HTTP/1.1" 404 -')
471466

472467

468+
class CustomHeaderSimpleHTTPRequestHandler(SimpleHTTPRequestHandler):
469+
custom_headers = None
470+
def __init__(self, *args, directory=None, response_headers=None, **kwargs):
471+
super().__init__(*args, response_headers=self.custom_headers, **kwargs)
472+
473+
473474
class SimpleHTTPServerTestCase(BaseTestCase):
474-
class request_handler(NoLogRequestHandler, SimpleHTTPRequestHandler):
475+
class request_handler(NoLogRequestHandler, CustomHeaderSimpleHTTPRequestHandler):
475476
pass
476477

477478
def setUp(self):
@@ -828,6 +829,26 @@ def test_path_without_leading_slash(self):
828829
self.assertEqual(response.getheader("Location"),
829830
self.tempdir_name + "/?hi=1")
830831

832+
def test_custom_headers_list_dir(self):
833+
with mock.patch.object(self.request_handler, 'custom_headers', new={
834+
'X-Test1': 'test1',
835+
'X-Test2': 'test2',
836+
}):
837+
response = self.request(self.base_url + '/')
838+
self.assertEqual(response.getheader("X-Test1"), 'test1')
839+
self.assertEqual(response.getheader("X-Test2"), 'test2')
840+
841+
def test_custom_headers_get_file(self):
842+
with mock.patch.object(self.request_handler, 'custom_headers', new={
843+
'X-Test1': 'test1',
844+
'X-Test2': 'test2',
845+
}):
846+
data = b"Dummy index file\r\n"
847+
with open(os.path.join(self.tempdir_name, 'index.html'), 'wb') as f:
848+
f.write(data)
849+
response = self.request(self.base_url + '/')
850+
self.assertEqual(response.getheader("X-Test1"), 'test1')
851+
self.assertEqual(response.getheader("X-Test2"), 'test2')
831852

832853
class SocketlessRequestHandler(SimpleHTTPRequestHandler):
833854
def __init__(self, directory=None):
@@ -1311,7 +1332,6 @@ class CommandLineTestCase(unittest.TestCase):
13111332
'tls_cert': None,
13121333
'tls_key': None,
13131334
'tls_password': None,
1314-
'response_headers': None,
13151335
}
13161336

13171337
def setUp(self):
@@ -1379,13 +1399,8 @@ def test_protocol_flag(self, mock_func):
13791399

13801400
@mock.patch('http.server.test')
13811401
def test_header_flag(self, mock_func):
1402+
call_args = self.args
13821403
self.invoke_httpd('--header', 'h1', 'v1', '-H', 'h2', 'v2')
1383-
call_args = self.args | dict(
1384-
response_headers={
1385-
'h1': 'v1',
1386-
'h2': 'v2'
1387-
}
1388-
)
13891404
mock_func.assert_called_once_with(**call_args)
13901405
mock_func.reset_mock()
13911406

0 commit comments

Comments
 (0)