From a0d57e4b49f5fc5916a9505ec358b7c7e47f24a4 Mon Sep 17 00:00:00 2001 From: Evan Anderson Date: Fri, 22 Apr 2022 12:19:15 -0700 Subject: [PATCH] Use listen to force a connection timeout --- network/transports_test.go | 47 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/network/transports_test.go b/network/transports_test.go index 156fe3aaf2..4ee37fd67f 100644 --- a/network/transports_test.go +++ b/network/transports_test.go @@ -111,8 +111,20 @@ func testDialWithBackoffConnectionRefused(tlsConf *tls.Config) func(t *testing.T func testDialWithBackoffTimeout(tlsConf *tls.Config) func(t *testing.T) { return func(t *testing.T) { - // Timeout. Use non-routable IP. See: https://stackoverflow.com/a/31581323/844449 - c, err := dialer(context.TODO(), tlsConf)("10.0.0.0:81") + // Create a listening socket with backlog 1, then occupy the backlog to force a timeout. + closer, addr, err := listenOne() + if err != nil { + t.Fatal("Unable to create listener:", err) + } + defer closer() + c1, err := net.Dial("tcp4", addr.String()) + if err != nil { + t.Fatalf("Unable to connect to server on %s: %s", addr, err) + } + defer c1.Close() + + // Since the backlog is full, the next request must time out. + c, err := dialer(context.TODO(), tlsConf)(addr.String()) if err == nil { closeOrFail(t, c) t.Error("Unexpected success dialing") @@ -193,3 +205,34 @@ func findUnusedPortOrFail(tb testing.TB) int { defer closeOrFail(tb, l) return l.Addr().(*net.TCPAddr).Port } + +// Golang doesn't allow us to set the backlog argument on syscall.Listen from +// net.ListenTCP, so we need to get directly into syscall land. +func listenOne() (func(), *net.TCPAddr, error) { + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) + if err != nil { + return nil, nil, fmt.Errorf("Couldn't get socket: %w", err) + } + sa := &syscall.SockaddrInet4{ + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + if err = syscall.Bind(fd, sa); err != nil { + return nil, nil, fmt.Errorf("Unable to bind: %w", err) + } + if err = syscall.Listen(fd, 0); err != nil { + return nil, nil, fmt.Errorf("Unable to Listen: %w", err) + } + closer := func() { syscall.Close(fd) } + listenaddr, err := syscall.Getsockname(fd) + if err != nil { + closer() + return nil, nil, fmt.Errorf("Could not get sockname: %w", err) + } + sa = listenaddr.(*syscall.SockaddrInet4) + addr := &net.TCPAddr{ + IP: sa.Addr[:], + Port: sa.Port, + } + return closer, addr, nil +}