diff --git a/awscrt/http.py b/awscrt/http.py index d6b5f3a3b..35093cd3f 100644 --- a/awscrt/http.py +++ b/awscrt/http.py @@ -281,7 +281,8 @@ def new(cls, def request(self, request: 'HttpRequest', on_response: Optional[Callable[..., None]] = None, - on_body: Optional[Callable[..., None]] = None) -> 'HttpClientStream': + on_body: Optional[Callable[..., None]] = None, + manual_write: bool = False) -> 'HttpClientStream': """Create :class:`HttpClientStream` to carry out the request/response exchange. NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()` @@ -320,10 +321,33 @@ def request(self, An exception raise by this function will cause the HTTP stream to end in error. This callback is always invoked on the connection's event-loop thread. + manual_write (bool): If True, enables manual data writing on the stream. + This allows calling :meth:`HttpClientStream.write_data()` to stream + the request body in chunks. Works for both HTTP/1.1 and HTTP/2. + + By design, CRT does not support setting both a body stream and + enabling manual writes for HTTP/1.1. Body streams are intended + for requests whose payload is available at the time of sending. + Manual writes let the caller control when data is sent. The two + cannot coexist on HTTP/1.1. + + If ``manual_write`` is True and ``request`` has a ``body_stream``, + this method raises :class:`ValueError`. + + HTTP/2 does not have this restriction. + Returns: HttpClientStream: + + Raises: + ValueError: If ``manual_write`` is True and the request has a + ``body_stream`` set (HTTP/1.1 only). """ - return HttpClientStream(self, request, on_response, on_body) + if manual_write and request.body_stream is not None: + raise ValueError( + "Cannot use manual data writes with a body_stream on an HTTP/1.1 request. " + "Either remove the body_stream or set manual_write=False.") + return HttpClientStream(self, request, on_response, on_body, manual_write) def close(self) -> "concurrent.futures.Future": """Close the connection. @@ -526,7 +550,7 @@ def _init_common(self, request: 'HttpRequest', on_response: Optional[Callable[..., None]] = None, on_body: Optional[Callable[..., None]] = None, - http2_manual_write: bool = False) -> None: + manual_write: bool = False) -> None: assert isinstance(connection, HttpClientConnectionBase) assert isinstance(request, HttpRequest) assert callable(on_response) or on_response is None @@ -540,7 +564,7 @@ def _init_common(self, # keep HttpRequest alive until stream completes self._request = request self._version = connection.version - self._binding = _awscrt.http_client_stream_new(self, connection, request, http2_manual_write) + self._binding = _awscrt.http_client_stream_new(self, connection, request, manual_write) @property def version(self) -> HttpVersion: @@ -586,6 +610,49 @@ def update_window(self, increment_size: int) -> None: """ _awscrt.http_stream_update_window(self, increment_size) + def write_data(self, + data_stream: Union[InputStream, Any], + end_stream: bool = False) -> "concurrent.futures.Future": + '''Write data to the request body. + + Works for both HTTP/1.1 and HTTP/2 streams. + The stream must have been created with ``manual_write=True``. + You must call :meth:`HttpClientStream.activate()` before using this method. + + .. note:: + This is the unified API for manual body writes, superseding the + C-level ``aws_http1_stream_write_chunk``. Use ``write_data()`` + for all new code — the old chunked-write API is deprecated. + + See :meth:`HttpClientStream.write_data` and + :meth:`Http2ClientStream.write_data` for protocol-specific + behaviour and constraints. + + Args: + data_stream (InputStream): Data to write. If not an InputStream, + it will be wrapped in one. Can be None to write zero bytes. + + end_stream (bool): True to indicate this is the last chunk and no more data + will be sent. False if more chunks will follow. + + Returns: + concurrent.futures.Future: Future that completes when the write operation + is done. The future will contain None on success, or an exception on failure. + ''' + future = Future() + body_stream = InputStream.wrap(data_stream, allow_none=True) + + def on_write_complete(error_code: int) -> None: + if future.cancelled(): + return + if error_code: + future.set_exception(awscrt.exceptions.from_code(error_code)) + else: + future.set_result(None) + + _awscrt.http_stream_write_data(self, body_stream, end_stream, on_write_complete) + return future + class HttpClientStream(HttpClientStreamBase): """HTTP stream that sends a request and receives a response. @@ -608,8 +675,9 @@ def __init__(self, connection: HttpClientConnection, request: 'HttpRequest', on_response: Optional[Callable[..., None]] = None, - on_body: Optional[Callable[..., None]] = None) -> None: - self._init_common(connection, request, on_response, on_body) + on_body: Optional[Callable[..., None]] = None, + manual_write: bool = False) -> None: + self._init_common(connection, request, on_response, on_body, manual_write) def activate(self) -> None: """Begin sending the request. @@ -683,39 +751,26 @@ def activate(self) -> None: def write_data(self, data_stream: Union[InputStream, Any], end_stream: bool = False) -> "concurrent.futures.Future": - """Write a chunk of data to the request body stream. + '''Write data to the HTTP/2 request body. - This method is only available when the stream was created with - manual_write=True. This allows incremental writing of request data. + The stream must have been created with ``manual_write=True`` and + :meth:`activate()` must have been called before using this method. - Note: In the asyncio version, this is replaced by the request_body_generator parameter - which accepts an async generator. + When both a body stream and manual writes are enabled, the body + stream is sent first and the connection then waits asynchronously + for subsequent ``write_data()`` calls. However, if the body stream + has not signalled end-of-stream, the event loop will keep getting + scheduled for requesting more data until it completes. Args: - data_stream (Union[InputStream, Any]): Data to write. If not an InputStream, - it will be wrapped in one. Can be None to send an empty chunk. - - end_stream (bool): True to indicate this is the last chunk and no more data - will be sent. False if more chunks will follow. + data_stream: Data to write. Wrapped in :class:`~awscrt.io.InputStream` if + needed. ``None`` sends zero bytes. + end_stream (bool): ``True`` if this is the last write. Returns: - concurrent.futures.Future: Future that completes when the write operation - is done. The future will contain None on success, or an exception on failure. - """ - future = Future() - body_stream = InputStream.wrap(data_stream, allow_none=True) - - def on_write_complete(error_code: int) -> None: - if future.cancelled(): - # the future was cancelled, so we don't need to set the result or exception - return - if error_code: - future.set_exception(awscrt.exceptions.from_code(error_code)) - else: - future.set_result(None) - - _awscrt.http2_client_stream_write_data(self, body_stream, end_stream, on_write_complete) - return future + concurrent.futures.Future: Completes with ``None`` on success. + ''' + return super().write_data(data_stream, end_stream) class HttpMessageBase(NativeResource): diff --git a/crt/aws-c-http b/crt/aws-c-http index 8bf9e53dd..da535b1bf 160000 --- a/crt/aws-c-http +++ b/crt/aws-c-http @@ -1 +1 @@ -Subproject commit 8bf9e53ddc1057d8581f407c609e372370fd1e40 +Subproject commit da535b1bf9c9334730eb78a26a1bbb3c069b38c9 diff --git a/source/http.h b/source/http.h index dbdab9af1..bb330a957 100644 --- a/source/http.h +++ b/source/http.h @@ -50,6 +50,7 @@ PyObject *aws_py_http_client_stream_new(PyObject *self, PyObject *args); PyObject *aws_py_http_client_stream_activate(PyObject *self, PyObject *args); PyObject *aws_py_http2_client_stream_write_data(PyObject *self, PyObject *args); +PyObject *aws_py_http_stream_write_data(PyObject *self, PyObject *args); /* Create capsule around new request-style aws_http_message struct */ PyObject *aws_py_http_message_new_request(PyObject *self, PyObject *args); diff --git a/source/http_stream.c b/source/http_stream.c index 503a9fcad..57b7a1495 100644 --- a/source/http_stream.c +++ b/source/http_stream.c @@ -303,6 +303,7 @@ PyObject *aws_py_http_client_stream_new(PyObject *self, PyObject *args) { .on_complete = s_on_stream_complete, .on_h2_remote_end_stream = s_on_h2_remote_end_stream, .user_data = stream, + .use_manual_data_writes = http2_manual_write, .http2_use_manual_data_writes = http2_manual_write, }; @@ -410,3 +411,63 @@ PyObject *aws_py_http2_client_stream_write_data(PyObject *self, PyObject *args) } Py_RETURN_NONE; } + +static void s_on_http_stream_write_data_complete(struct aws_http_stream *stream, int error_code, void *user_data) { + (void)stream; + PyObject *py_on_write_complete = (PyObject *)user_data; + AWS_FATAL_ASSERT(py_on_write_complete); + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + + PyObject *result = PyObject_CallFunction(py_on_write_complete, "(i)", error_code); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + Py_DECREF(py_on_write_complete); + PyGILState_Release(state); +} + +PyObject *aws_py_http_stream_write_data(PyObject *self, PyObject *args) { + (void)self; + + PyObject *py_stream = NULL; + PyObject *py_body_stream = NULL; + int end_stream = false; + PyObject *py_on_write_complete = NULL; + if (!PyArg_ParseTuple(args, "OOpO", &py_stream, &py_body_stream, &end_stream, &py_on_write_complete)) { + return NULL; + } + + struct aws_http_stream *http_stream = aws_py_get_http_stream(py_stream); + if (!http_stream) { + return NULL; + } + + struct aws_input_stream *body_stream = NULL; + if (py_body_stream != Py_None) { + body_stream = aws_py_get_input_stream(py_body_stream); + if (!body_stream) { + return PyErr_AwsLastError(); + } + } + + Py_INCREF(py_on_write_complete); + + struct aws_http_stream_write_data_options write_options = { + .data = body_stream, + .end_stream = end_stream, + .on_complete = s_on_http_stream_write_data_complete, + .user_data = py_on_write_complete, + }; + + int error = aws_http_stream_write_data(http_stream, &write_options); + if (error) { + Py_DECREF(py_on_write_complete); + return PyErr_AwsLastError(); + } + Py_RETURN_NONE; +} diff --git a/source/module.c b/source/module.c index 45f3322ea..0d1c52fed 100644 --- a/source/module.c +++ b/source/module.c @@ -1052,6 +1052,7 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(http_client_stream_new, METH_VARARGS), AWS_PY_METHOD_DEF(http_client_stream_activate, METH_VARARGS), AWS_PY_METHOD_DEF(http2_client_stream_write_data, METH_VARARGS), + AWS_PY_METHOD_DEF(http_stream_write_data, METH_VARARGS), AWS_PY_METHOD_DEF(http_message_new_request, METH_VARARGS), AWS_PY_METHOD_DEF(http_message_get_request_method, METH_VARARGS), AWS_PY_METHOD_DEF(http_message_set_request_method, METH_VARARGS), diff --git a/test/test_http_client.py b/test/test_http_client.py index 5364d7e3b..30476dfcb 100644 --- a/test/test_http_client.py +++ b/test/test_http_client.py @@ -961,5 +961,129 @@ def test_h2_remote_end_stream_ordering(self): connection.close().result(self.timeout) +@unittest.skipUnless(os.environ.get('AWS_TEST_LOCALHOST'), 'set env var to run test: AWS_TEST_LOCALHOST') +class TestH1WriteData(LocalServerTestBase): + """HTTP/1.1 write_data() tests — mirrors Java WriteDataTest scenarios. + Uses the existing LocalServerTestBase and PUT echo pattern.""" + timeout = 10 + + def _new_h1_tls_connection(self): + """Create HTTP/1.1 TLS connection to local server""" + tls_ctx_opt = TlsContextOptions() + tls_ctx_opt.verify_peer = False + tls_ctx = ClientTlsContext(tls_ctx_opt) + tls_conn_opt = tls_ctx.new_connection_options() + tls_conn_opt.set_server_name(self.hostname) + + event_loop_group = EventLoopGroup() + host_resolver = DefaultHostResolver(event_loop_group) + bootstrap = ClientBootstrap(event_loop_group, host_resolver) + + connection = HttpClientConnection.new( + host_name=self.hostname, + port=self.port, + bootstrap=bootstrap, + tls_connection_options=tls_conn_opt, + ).result(self.timeout) + self.assertEqual(connection.version, HttpVersion.Http1_1) + return connection + + def test_h1_write_data(self): + """H1 PUT with manual write — mirrors Java testHttp1WriteData""" + self._start_server(secure=True) + try: + connection = self._new_h1_tls_connection() + payload = b'hello from writeData h1' + + request = HttpRequest('PUT', '/write_data_test') + request.headers.add('host', self.hostname) + request.headers.add('Content-Length', str(len(payload))) + + response = Response() + stream = connection.request(request, response.on_response, response.on_body, manual_write=True) + stream.activate() + + stream.write_data(BytesIO(payload), end_stream=True).result(self.timeout) + status = stream.completion_future.result(self.timeout) + + self.assertEqual(200, status) + self.assertEqual(payload, self.server.put_requests.get('/write_data_test')) + + connection.close().result(self.timeout) + finally: + self._stop_server() + + def test_h1_write_data_end_stream_only(self): + """H1 PUT with zero-byte body — mirrors Java testHttp1WriteDataEndStreamOnly""" + self._start_server(secure=True) + try: + connection = self._new_h1_tls_connection() + + request = HttpRequest('PUT', '/write_data_empty') + request.headers.add('host', self.hostname) + request.headers.add('Content-Length', '0') + + response = Response() + stream = connection.request(request, response.on_response, response.on_body, manual_write=True) + stream.activate() + + stream.write_data(None, end_stream=True).result(self.timeout) + status = stream.completion_future.result(self.timeout) + + self.assertEqual(200, status) + self.assertEqual(b'', self.server.put_requests.get('/write_data_empty')) + + connection.close().result(self.timeout) + finally: + self._stop_server() + + def test_h1_write_data_multi_chunk(self): + """H1 PUT with multiple write_data calls""" + self._start_server(secure=True) + try: + connection = self._new_h1_tls_connection() + chunks = [b'chunk1', b'chunk2', b'chunk3'] + total = b''.join(chunks) + + request = HttpRequest('PUT', '/write_data_multi') + request.headers.add('host', self.hostname) + request.headers.add('Content-Length', str(len(total))) + + response = Response() + stream = connection.request(request, response.on_response, response.on_body, manual_write=True) + stream.activate() + + for i, chunk in enumerate(chunks): + stream.write_data(BytesIO(chunk), end_stream=(i == len(chunks) - 1)).result(self.timeout) + + status = stream.completion_future.result(self.timeout) + + self.assertEqual(200, status) + self.assertEqual(total, self.server.put_requests.get('/write_data_multi')) + + connection.close().result(self.timeout) + finally: + self._stop_server() + + def test_h1_write_data_with_body_stream_raises(self): + """H1 request() must raise ValueError if manual_write=True and body_stream is set.""" + self._start_server(secure=True) + try: + connection = self._new_h1_tls_connection() + + request = HttpRequest('PUT', '/write_data_guard', body_stream=BytesIO(b'body')) + request.headers.add('host', self.hostname) + + try: + connection.request(request, manual_write=True) + self.fail("Expected ValueError from request()") + except ValueError as e: + self.assertIn('manual data writes', str(e)) + + connection.close().result(self.timeout) + finally: + self._stop_server() + + if __name__ == '__main__': unittest.main()