diff --git a/internal/docker/deployer.go b/internal/docker/deployer.go
index 7950ec6a..40ce645c 100644
--- a/internal/docker/deployer.go
+++ b/internal/docker/deployer.go
@@ -20,10 +20,12 @@ import (
 	"crypto/tls"
 	"fmt"
 	"log"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
 	"runtime"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -217,6 +219,11 @@ func deployImage(
 		log.Printf("Sharing %v host environment variables with container", env)
 	}
 
+	port1, port2, err := allocateHostPorts()
+	if err != nil {
+		return nil, err
+	}
+
 	body, err := docker.ContainerCreate(ctx, &container.Config{
 		Image: imageID,
 		Env:   env,
@@ -232,12 +239,14 @@ func deployImage(
 			PortBindings: nat.PortMap{
 				nat.Port("8008/tcp"): []nat.PortBinding{
 					{
-						HostIP: "127.0.0.1",
+						HostIP:   "127.0.0.1",
+						HostPort: strconv.Itoa(port1),
 					},
 				},
 				nat.Port("8448/tcp"): []nat.PortBinding{
 					{
-						HostIP: "127.0.0.1",
+						HostIP:   "127.0.0.1",
+						HostPort: strconv.Itoa(port2),
 					},
 				},
 			},
@@ -328,61 +337,6 @@ func deployImage(
 		)
 	}
 
-	var lastErr error
-
-	// Inspect health status of container to check it is up
-	stopTime := time.Now().Add(cfg.SpawnHSTimeout)
-	iterCount := 0
-	if inspect.State.Health != nil {
-		// If the container has a healthcheck, wait for it first
-		for {
-			iterCount += 1
-			if time.Now().After(stopTime) {
-				lastErr = fmt.Errorf("timed out checking for homeserver to be up: %s", lastErr)
-				break
-			}
-			inspect, err = docker.ContainerInspect(ctx, containerID)
-			if err != nil {
-				lastErr = fmt.Errorf("inspect container %s => error: %s", containerID, err)
-				time.Sleep(50 * time.Millisecond)
-				continue
-			}
-			if inspect.State.Health.Status != "healthy" {
-				lastErr = fmt.Errorf("inspect container %s => health: %s", containerID, inspect.State.Health.Status)
-				time.Sleep(50 * time.Millisecond)
-				continue
-			}
-			lastErr = nil
-			break
-
-		}
-	}
-
-	// Having optionally waited for container to self-report healthy
-	// hit /versions to check it is actually responding
-	versionsURL := fmt.Sprintf("%s/_matrix/client/versions", baseURL)
-
-	for {
-		iterCount += 1
-		if time.Now().After(stopTime) {
-			lastErr = fmt.Errorf("timed out checking for homeserver to be up: %s", lastErr)
-			break
-		}
-		res, err := http.Get(versionsURL)
-		if err != nil {
-			lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err)
-			time.Sleep(50 * time.Millisecond)
-			continue
-		}
-		if res.StatusCode != 200 {
-			lastErr = fmt.Errorf("GET %s => HTTP %s", versionsURL, res.Status)
-			time.Sleep(50 * time.Millisecond)
-			continue
-		}
-		lastErr = nil
-		break
-	}
-
 	d := &HomeserverDeployment{
 		BaseURL:             baseURL,
 		FedBaseURL:          fedBaseURL,
@@ -391,8 +345,11 @@ func deployImage(
 		ContainerID:         containerID,
 		ApplicationServices: asIDToRegistrationFromLabels(inspect.Config.Labels),
 		DeviceIDs:           deviceIDsFromLabels(inspect.Config.Labels),
 	}
-	if lastErr != nil {
-		return d, fmt.Errorf("%s: failed to check server is up. %w", contextStr, lastErr)
+
+	stopTime := time.Now().Add(cfg.SpawnHSTimeout)
+	iterCount, err := waitForContainer(ctx, docker, d, stopTime)
+	if err != nil {
+		return d, fmt.Errorf("%s: failed to check server is up. %w", contextStr, err)
 	} else {
 		if cfg.DebugLoggingEnabled {
 			log.Printf("%s: Server is responding after %d iterations", contextStr, iterCount)
@@ -401,6 +358,39 @@ func deployImage(
 	return d, nil
 }
 
+// Picks two free ports on localhost. Does not reserve them in any way.
+// The returned ports must be used before the next call to `allocateHostPorts`,
+// otherwise the same pair of ports may be returned.
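+//
+// Note that this is inherently racy: nothing prevents another process from
+// claiming a returned port between the listener being closed here and the
+// port being bound again, so a clash is unlikely but possible.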
+func allocateHostPorts() (int, int, error) {
+	localhostAnyPort := net.TCPAddr{
+		IP:   net.ParseIP("127.0.0.1"),
+		Port: 0,
+	}
+
+	listener1, err := net.ListenTCP("tcp", &localhostAnyPort)
+	if err != nil {
+		return 0, 0, err
+	}
+	listener2, err := net.ListenTCP("tcp", &localhostAnyPort)
+	if err != nil {
+		return 0, 0, err
+	}
+
+	port1 := listener1.Addr().(*net.TCPAddr).Port
+	port2 := listener2.Addr().(*net.TCPAddr).Port
+
+	err = listener1.Close()
+	if err != nil {
+		return 0, 0, err
+	}
+	err = listener2.Close()
+	if err != nil {
+		return 0, 0, err
+	}
+
+	return port1, port2, nil
+}
+
 func copyToContainer(docker *client.Client, containerID, path string, data []byte) error {
 	// Create a fake/virtual file in memory that we can copy to the container
 	// via https://stackoverflow.com/a/52131297/796832
@@ -427,6 +417,99 @@ func copyToContainer(docker *client.Client, containerID, path string, data []byt
 	return nil
 }
 
+// Waits until a homeserver deployment is ready to serve requests.
+func waitForContainer(ctx context.Context, docker *client.Client, hsDep *HomeserverDeployment, stopTime time.Time) (iterCount int, err error) {
+	var lastErr error
+
+	iterCount = 0
+
+	// If the container has a healthcheck, wait for it first
+	for {
+		iterCount += 1
+		if time.Now().After(stopTime) {
+			lastErr = fmt.Errorf("timed out checking for homeserver to be up: %s", lastErr)
+			break
+		}
+		inspect, err := docker.ContainerInspect(ctx, hsDep.ContainerID)
+		if err != nil {
+			lastErr = fmt.Errorf("inspect container %s => error: %s", hsDep.ContainerID, err)
+			time.Sleep(50 * time.Millisecond)
+			continue
+		}
+		if inspect.State.Health != nil &&
+			inspect.State.Health.Status != "healthy" {
+			lastErr = fmt.Errorf("inspect container %s => health: %s", hsDep.ContainerID, inspect.State.Health.Status)
+			time.Sleep(50 * time.Millisecond)
+			continue
+		}
+
+		// The container is healthy or has no health check.
+		lastErr = nil
+		break
+	}
+
+	// Having optionally waited for the container to self-report healthy,
+	// hit /versions to check it is actually responding
+	versionsURL := fmt.Sprintf("%s/_matrix/client/versions", hsDep.BaseURL)
+
+	for {
+		iterCount += 1
+		if time.Now().After(stopTime) {
+			lastErr = fmt.Errorf("timed out checking for homeserver to be up: %s", lastErr)
+			break
+		}
+		res, err := http.Get(versionsURL)
+		if err != nil {
+			lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err)
+			time.Sleep(50 * time.Millisecond)
+			continue
+		}
+		if res.StatusCode != 200 {
+			lastErr = fmt.Errorf("GET %s => HTTP %s", versionsURL, res.Status)
+			time.Sleep(50 * time.Millisecond)
+			continue
+		}
+		lastErr = nil
+		break
+	}
+
+	return iterCount, lastErr
+}
+
+// Restart a deployment.
+func (dep *Deployment) Restart() error {
+	ctx := context.Background()
+
+	for _, hsDep := range dep.HS {
+		err := dep.Deployer.Docker.ContainerStop(ctx, hsDep.ContainerID, &dep.Config.SpawnHSTimeout)
+		if err != nil {
+			return fmt.Errorf("failed to restart container %s: %s", hsDep.ContainerID, err)
+		}
+
+		// Remove the container from the network. If we don't do this,
+		// (re)starting the container fails with an error like
+		// "Error response from daemon: endpoint with name complement_fed_1_fed.alice.hs1_1 already exists in network complement_fed_alice".
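+		// (The last argument to NetworkDisconnect is Docker's `force` flag;
+		// the container has already been stopped, so a forced disconnect
+		// should not be needed here.)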
+		err = dep.Deployer.Docker.NetworkDisconnect(ctx, dep.Deployer.networkID, hsDep.ContainerID, false)
+		if err != nil {
+			return fmt.Errorf("failed to restart container %s: %s", hsDep.ContainerID, err)
+		}
+
+		err = dep.Deployer.Docker.ContainerStart(ctx, hsDep.ContainerID, types.ContainerStartOptions{})
+		if err != nil {
+			return fmt.Errorf("failed to restart container %s: %s", hsDep.ContainerID, err)
+		}
+
+		// Wait for the container to be ready.
+		stopTime := time.Now().Add(dep.Config.SpawnHSTimeout)
+		_, err = waitForContainer(ctx, dep.Deployer.Docker, &hsDep, stopTime)
+		if err != nil {
+			return fmt.Errorf("failed to restart container %s: %s", hsDep.ContainerID, err)
+		}
+	}
+
+	return nil
+}
+
 // RoundTripper is a round tripper that maps https://hs1 to the federation port of the container
 // e.g https://localhost:35352
 type RoundTripper struct {
diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go
index 3ca760e4..41784261 100644
--- a/tests/federation_room_join_partial_state_test.go
+++ b/tests/federation_room_join_partial_state_test.go
@@ -1,3 +1,4 @@
+//go:build faster_joins
 // +build faster_joins
 
 // This file contains tests for joining rooms over federation, with the
@@ -281,6 +282,62 @@ func TestPartialStateJoin(t *testing.T) {
 		}
 	})
 
+	// test that a partial-state join continues syncing state after a restart;
+	// the same as SyncBlocksDuringPartialStateJoin, but with a restart in the middle
+	t.Run("PartialStateJoinContinuesAfterRestart", func(t *testing.T) {
+		deployment := Deploy(t, b.BlueprintAlice)
+		defer deployment.Destroy(t)
+		alice := deployment.Client(t, "hs1", "@alice:hs1")
+
+		psjResult := beginPartialStateJoin(t, deployment, alice)
+		defer psjResult.Destroy()
+
+		// Alice has now joined the room, and the server is syncing the state in the background.
+
+		// wait for the state_ids request to arrive
+		psjResult.AwaitStateIdsRequest(t)
+
+		// restart the homeserver
+		err := deployment.Restart()
+		if err != nil {
+			t.Errorf("Failed to restart homeserver: %s", err)
+		}
+
+		// attempts to sync should block. Fire off a goroutine to try it.
+		syncResponseChan := make(chan gjson.Result)
+		defer close(syncResponseChan)
+		go func() {
+			response, _ := alice.MustSync(t, client.SyncReq{})
+			syncResponseChan <- response
+		}()
+
+		// we expect another state_ids request to arrive.
+		// we'd do another AwaitStateIdsRequest, except it's single-use.
+
+		// the client-side request should still be waiting
+		select {
+		case <-syncResponseChan:
+			t.Fatalf("Sync completed before state resync complete")
+		default:
+		}
+
+		// release the federation /state response
+		psjResult.FinishStateRequest()
+
+		// the /sync request should now complete, with the new room
+		var syncRes gjson.Result
+		select {
+		case <-time.After(1 * time.Second):
+			t.Fatalf("/sync request did not complete")
+		case syncRes = <-syncResponseChan:
+		}
+
+		roomRes := syncRes.Get("rooms.join." + client.GjsonEscape(psjResult.ServerRoom.RoomID))
+		if !roomRes.Exists() {
+			t.Fatalf("/sync completed without join to new room")
+		}
+	})
+
 	// test a lazy-load-members sync while re-syncing partial state, followed by completion of state syncing,
 	// followed by a gappy sync. the gappy sync should include the correct member state,
 	// since it was not sent on the previous sync.