From 7f08b30eab0e109b69a8bf22fc7e4a115b7412ac Mon Sep 17 00:00:00 2001 From: Arun Philip Date: Thu, 29 Jan 2026 14:42:11 -0500 Subject: [PATCH] upgrade quic-go to v0.59.0, fix Go 1.25 compatibility - Remove deprecated ConnectionTracingID/ConnectionTracingKey - Replace hijacker callbacks with RawClientConn (upstream API) - Fix SupportsDatagrams.Remote check - Add transport tests Fixes #482 --- go.mod | 2 +- go.sum | 6 +- internal/http3/client.go | 66 +++++++------- internal/http3/conn.go | 119 ++++++++++--------------- internal/http3/transport.go | 48 +++++++++-- internal/http3/transport_test.go | 143 +++++++++++++++++++++++++++++++ 6 files changed, 260 insertions(+), 124 deletions(-) create mode 100644 internal/http3/transport_test.go diff --git a/go.mod b/go.mod index e08be99..ff4b779 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/icholy/digest v1.1.0 github.com/klauspost/compress v1.18.2 github.com/quic-go/qpack v0.6.0 - github.com/quic-go/quic-go v0.57.1 + github.com/quic-go/quic-go v0.59.0 github.com/refraction-networking/utls v1.8.1 golang.org/x/net v0.48.0 golang.org/x/text v0.32.0 diff --git a/go.sum b/go.sum index abc0b4e..ed6aa9e 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= -github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI10= -github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -33,8 +33,6 @@ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/http3/client.go b/internal/http3/client.go index f878612..704c0a5 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -84,8 +84,6 @@ func newClientConn( conn *quic.Conn, enableDatagrams bool, additionalSettings map[uint64]uint64, - streamHijacker func(FrameType, quic.ConnectionTracingID, *quic.Stream, error) (hijacked bool, err error), - uniStreamHijacker func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool), maxResponseHeaderBytes int, disableCompression bool, logger *slog.Logger, @@ -122,13 +120,14 @@ func newClientConn( c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() - if streamHijacker != nil { - go c.handleBidirectionalStreams(streamHijacker) - } - go c.conn.handleUnidirectionalStreams(uniStreamHijacker) return c } +// handleUnidirectionalStream handles an incoming unidirectional stream. +func (c *ClientConn) handleUnidirectionalStream(str *quic.ReceiveStream) { + c.conn.handleUnidirectionalStream(str) +} + // OpenRequestStream opens a new request stream on the HTTP/3 connection. func (c *ClientConn) OpenRequestStream(ctx context.Context) (*RequestStream, error) { return c.conn.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes) @@ -166,37 +165,6 @@ func (c *ClientConn) setupConn() error { return err } -func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, *quic.Stream, error) (hijacked bool, err error)) { - for { - str, err := c.conn.conn.AcceptStream(context.Background()) - if err != nil { - if c.logger != nil { - c.logger.Debug("accepting bidirectional stream failed", "error", err) - } - return - } - fp := &frameParser{ - r: str, - closeConn: c.conn.CloseWithError, - unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) { - id := c.conn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) - return streamHijacker(ft, id, str, e) - }, - } - go func() { - if _, err := fp.ParseNext(c.conn.qlogger); err == errHijacked { - return - } - if err != nil { - if c.logger != nil { - c.logger.Debug("error handling stream", "error", err) - } - } - c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") - }() - } -} - // RoundTrip executes a request and returns a response func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { rsp, err := c.roundTrip(req) @@ -435,3 +403,27 @@ func (c *ClientConn) doRequest(req *http.Request, str *RequestStream) (*http.Res func (c *ClientConn) Conn() *Conn { return c.conn } + +// HandleBidirectionalStream handles an incoming bidirectional stream. +// According to RFC 9114, the server is not allowed to open bidirectional streams, +// so this method closes the connection with an error. +func (c *ClientConn) HandleBidirectionalStream(str *quic.Stream) { + c.conn.CloseWithError( + quic.ApplicationErrorCode(ErrCodeStreamCreationError), + fmt.Sprintf("server opened bidirectional stream %d", str.StreamID()), + ) +} + +// RawClientConn is a low-level HTTP/3 client connection. +// It allows the application to take control of the stream accept loops, +// giving the application the ability to handle streams originating from the server. +// This is useful for implementing WebTransport or other advanced protocols. +type RawClientConn struct { + *ClientConn +} + +// HandleUnidirectionalStream handles an incoming unidirectional stream. +// This should be called for each unidirectional stream accepted from the QUIC connection. +func (c *RawClientConn) HandleUnidirectionalStream(str *quic.ReceiveStream) { + c.handleUnidirectionalStream(str) +} diff --git a/internal/http3/conn.go b/internal/http3/conn.go index 5e20ce6..b1f09c2 100644 --- a/internal/http3/conn.go +++ b/internal/http3/conn.go @@ -57,6 +57,11 @@ type Conn struct { idleTimer *time.Timer qlogger qlogwriter.Recorder + + // Track received unidirectional streams (only one of each type allowed) + rcvdControlStr atomic.Bool + rcvdQPACKEncoderStr atomic.Bool + rcvdQPACKDecoderStr atomic.Bool } func newConnection( @@ -232,80 +237,48 @@ func (c *Conn) CloseWithError(code quic.ApplicationErrorCode, msg string) error return c.conn.CloseWithError(code, msg) } -func (c *Conn) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool)) { - var ( - rcvdControlStr atomic.Bool - rcvdQPACKEncoderStr atomic.Bool - rcvdQPACKDecoderStr atomic.Bool - ) - - for { - str, err := c.conn.AcceptUniStream(context.Background()) - if err != nil { - if c.logger != nil { - c.logger.Debug("accepting unidirectional stream failed", "error", err) - } - return +func (c *Conn) handleUnidirectionalStream(str *quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + if c.logger != nil { + c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err) } - - go func(str *quic.ReceiveStream) { - streamType, err := quicvarint.Read(quicvarint.NewReader(str)) - if err != nil { - id := c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) - if hijack != nil && hijack(StreamType(streamType), id, str, err) { - return - } - if c.logger != nil { - c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err) - } - return - } - // We're only interested in the control stream here. - switch streamType { - case streamTypeControlStream: - case streamTypeQPACKEncoderStream: - if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { - c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") - } - // Our QPACK implementation doesn't use the dynamic table yet. - return - case streamTypeQPACKDecoderStream: - if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { - c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") - } - // Our QPACK implementation doesn't use the dynamic table yet. - return - case streamTypePushStream: - if c.isServer { - // only the server can push - c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") - } else { - // we never increased the Push ID, so we don't expect any push streams - c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") - } - return - default: - if hijack != nil { - if hijack( - StreamType(streamType), - c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID), - str, - nil, - ) { - return - } - } - str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) - return - } - // Only a single control stream is allowed. - if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { - c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") - return - } - c.handleControlStream(str) - }(str) + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream: + if isFirst := c.rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypeQPACKDecoderStream: + if isFirst := c.rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypePushStream: + if c.isServer { + // only the server can push + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") + } else { + // we never increased the Push ID, so we don't expect any push streams + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") + } + return + default: + str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) + return + } + // Only a single control stream is allowed. + if isFirstControlStr := c.rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { + c.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + return } + c.handleControlStream(str) } func (c *Conn) handleControlStream(str *quic.ReceiveStream) { @@ -335,7 +308,7 @@ func (c *Conn) handleControlStream(str *quic.ReceiveStream) { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams { + if c.enableDatagrams && !c.ConnectionState().SupportsDatagrams.Remote { c.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") return } diff --git a/internal/http3/transport.go b/internal/http3/transport.go index ebe4b5b..acde40d 100644 --- a/internal/http3/transport.go +++ b/internal/http3/transport.go @@ -41,6 +41,7 @@ type RoundTripOpt struct { type clientConn interface { OpenRequestStream(context.Context) (*RequestStream, error) RoundTrip(*http.Request) (*http.Response, error) + handleUnidirectionalStream(*quic.ReceiveStream) } type roundTripperWithCount struct { @@ -99,9 +100,6 @@ type Transport struct { // However, if the user explicitly requested gzip it is not automatically uncompressed. DisableCompression bool - StreamHijacker func(FrameType, quic.ConnectionTracingID, *quic.Stream, error) (hijacked bool, err error) - UniStreamHijacker func(StreamType, quic.ConnectionTracingID, *quic.ReceiveStream, error) (hijacked bool) - Logger *slog.Logger mutex sync.Mutex @@ -136,8 +134,6 @@ func (t *Transport) init() error { conn, t.EnableDatagrams, t.AdditionalSettings, - t.StreamHijacker, - t.UniStreamHijacker, t.MaxResponseHeaderBytes, t.DisableCompression, t.Logger, @@ -410,7 +406,23 @@ func (t *Transport) dial(ctx context.Context, hostname string) (*quic.Conn, clie if err != nil { return nil, nil, err } - return conn, t.newClientConn(conn), nil + cc := t.newClientConn(conn) + startUnidirectionalStreamAcceptLoop(conn, cc) + return conn, cc, nil +} + +// startUnidirectionalStreamAcceptLoop starts a goroutine that accepts incoming +// unidirectional streams and passes them to the clientConn for handling. +func startUnidirectionalStreamAcceptLoop(conn *quic.Conn, cc clientConn) { + go func() { + for { + str, err := conn.AcceptUniStream(context.Background()) + if err != nil { + return + } + go cc.handleUnidirectionalStream(str) + } + }() } func (t *Transport) resolveUDPAddr(ctx context.Context, network, addr string) (*net.UDPAddr, error) { @@ -448,17 +460,35 @@ func (t *Transport) removeClient(hostname string) { // Obtaining a ClientConn is only needed for more advanced use cases, such as // using Extended CONNECT for WebTransport or the various MASQUE protocols. func (t *Transport) NewClientConn(conn *quic.Conn) *ClientConn { - return newClientConn( + c := newClientConn( t.Options, conn, t.EnableDatagrams, t.AdditionalSettings, - t.StreamHijacker, - t.UniStreamHijacker, t.MaxResponseHeaderBytes, t.DisableCompression, t.Logger, ) + startUnidirectionalStreamAcceptLoop(conn, c) + return c +} + +// NewRawClientConn creates a new low-level HTTP/3 client connection on top of a QUIC connection. +// Unlike NewClientConn, the returned RawClientConn allows the application to take control +// of the stream accept loops, by calling HandleUnidirectionalStream for incoming unidirectional +// streams and HandleBidirectionalStream for incoming bidirectional streams. +func (t *Transport) NewRawClientConn(conn *quic.Conn) *RawClientConn { + return &RawClientConn{ + ClientConn: newClientConn( + t.Options, + conn, + t.EnableDatagrams, + t.AdditionalSettings, + t.MaxResponseHeaderBytes, + t.DisableCompression, + t.Logger, + ), + } } // Close closes the QUIC connections that this Transport has used. diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go new file mode 100644 index 0000000..4c8713c --- /dev/null +++ b/internal/http3/transport_test.go @@ -0,0 +1,143 @@ +package http3 + +import ( + "context" + "crypto/tls" + "net/http" + "testing" + "time" + + "github.com/imroc/req/v3/internal/testcert" + "github.com/quic-go/quic-go" +) + +func TestTransportInit(t *testing.T) { + tr := &Transport{} + // Trigger init by calling RoundTrip with invalid request + // This tests that init() doesn't panic + _, err := tr.RoundTrip(&http.Request{}) + if err == nil { + t.Fatal("expected error for nil URL") + } +} + +func TestTransportInitWithDatagrams(t *testing.T) { + tr := &Transport{ + EnableDatagrams: true, + QUICConfig: &quic.Config{ + EnableDatagrams: true, + }, + } + _, err := tr.RoundTrip(&http.Request{}) + if err == nil { + t.Fatal("expected error for nil URL") + } +} + +func TestTransportInitDatagramMismatch(t *testing.T) { + tr := &Transport{ + EnableDatagrams: true, + QUICConfig: &quic.Config{ + EnableDatagrams: false, + }, + } + _, err := tr.RoundTrip(&http.Request{}) + if err == nil || err.Error() != "HTTP Datagrams enabled, but QUIC Datagrams disabled" { + t.Fatalf("expected datagram mismatch error, got: %v", err) + } +} + +func TestRawClientConnType(t *testing.T) { + // Test that RawClientConn embeds ClientConn correctly + // This is a compile-time type check + type hasClientConn interface { + RoundTrip(*http.Request) (*http.Response, error) + } + // Verify RawClientConn satisfies the same interface as ClientConn + var _ hasClientConn = (*RawClientConn)(nil) + var _ hasClientConn = (*ClientConn)(nil) +} + +// TestNewClientConnAndRawClientConn tests the creation of client connections +func TestNewClientConnAndRawClientConn(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // Create TLS config from test certificates + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + t.Fatalf("failed to load test cert: %v", err) + } + serverTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"h3"}, + } + + // Start QUIC listener + listener, err := quic.ListenAddr("127.0.0.1:0", serverTLSConfig, &quic.Config{}) + if err != nil { + t.Fatalf("failed to start QUIC listener: %v", err) + } + defer listener.Close() + + serverAddr := listener.Addr().String() + + // Accept connections in background + go func() { + for { + conn, err := listener.Accept(context.Background()) + if err != nil { + return + } + conn.CloseWithError(0, "test done") + } + }() + + // Create client transport + clientTLSConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h3"}, + } + + tr := &Transport{ + TLSClientConfig: clientTLSConfig, + QUICConfig: &quic.Config{ + MaxIdleTimeout: 5 * time.Second, + }, + } + defer tr.Close() + + // Test dialing a QUIC connection + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + quicConn, err := quic.DialAddr(ctx, serverAddr, clientTLSConfig, tr.QUICConfig) + if err != nil { + t.Fatalf("failed to dial QUIC: %v", err) + } + defer quicConn.CloseWithError(0, "") + + // Test NewClientConn + clientConn := tr.NewClientConn(quicConn) + if clientConn == nil { + t.Fatal("NewClientConn returned nil") + } + + // Test NewRawClientConn with a new connection + quicConn2, err := quic.DialAddr(ctx, serverAddr, clientTLSConfig, tr.QUICConfig) + if err != nil { + t.Fatalf("failed to dial QUIC for raw conn: %v", err) + } + defer quicConn2.CloseWithError(0, "") + + rawConn := tr.NewRawClientConn(quicConn2) + if rawConn == nil { + t.Fatal("NewRawClientConn returned nil") + } + if rawConn.ClientConn == nil { + t.Fatal("RawClientConn.ClientConn is nil") + } + + t.Log("Successfully created ClientConn and RawClientConn") +}