diff --git a/internal/federation/handle.go b/internal/federation/handle.go index 4133e79e..e00affb1 100644 --- a/internal/federation/handle.go +++ b/internal/federation/handle.go @@ -128,12 +128,12 @@ func SendJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request authEvents := room.AuthChainForEvents(stateEvents) + // get servers in room *before* the join event + serversInRoom := room.ServersInRoom() + // insert the join event into the room state room.AddEvent(event) - // servers in room: just us. TODO(faster_joins): this may not be correct - serversInRoom := []string{s.serverName} - // return state and auth chain b, err := json.Marshal(gomatrixserverlib.RespSendJoin{ Origin: gomatrixserverlib.ServerName(s.serverName), diff --git a/internal/federation/server_room.go b/internal/federation/server_room.go index ae906278..779d85be 100644 --- a/internal/federation/server_room.go +++ b/internal/federation/server_room.go @@ -97,24 +97,38 @@ func (r *ServerRoom) AuthChain() (chain []*gomatrixserverlib.Event) { return r.AuthChainForEvents(r.AllCurrentState()) } -// AuthChainForEvents returns all auth events for all events in the given state TODO: recursively +// AuthChainForEvents returns all auth events for all events in the given state func (r *ServerRoom) AuthChainForEvents(events []*gomatrixserverlib.Event) (chain []*gomatrixserverlib.Event) { chainMap := make(map[string]bool) - // get all the auth event IDs - for _, ev := range events { + + // build a map of all events in the room + // Timeline and State contain different sets of events, so check them both. + eventsByID := map[string]*gomatrixserverlib.Event{} + for _, ev := range r.Timeline { + eventsByID[ev.EventID()] = ev + } + for _, ev := range r.State { + eventsByID[ev.EventID()] = ev + } + + // a queue of events whose auth events are to be included in the auth chain + queue := []*gomatrixserverlib.Event{} + queue = append(queue, events...) + + // get all the auth events recursively + // we extend the "queue" as we go along + for i := 0; i < len(queue); i++ { + ev := queue[i] for _, evID := range ev.AuthEventIDs() { if chainMap[evID] { continue } chainMap[evID] = true + chain = append(chain, eventsByID[evID]) + queue = append(queue, eventsByID[evID]) } } - // find them in the timeline - for _, tev := range r.Timeline { - if chainMap[tev.EventID()] { - chain = append(chain, tev) - } - } + return } @@ -134,6 +148,33 @@ func (r *ServerRoom) MustHaveMembershipForUser(t *testing.T, userID, wantMembers } } +// ServersInRoom gets all servers currently joined to the room +func (r *ServerRoom) ServersInRoom() (servers []string) { + serverSet := make(map[string]struct{}) + + for _, ev := range r.State { + if ev.Type() != "m.room.member" { + continue + } + membership, err := ev.Membership() + if err != nil || membership != "join" { + continue + } + _, server, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + continue + } + + serverSet[string(server)] = struct{}{} + } + + for server := range serverSet { + servers = append(servers, server) + } + + return +} + func initialPowerLevelsContent(roomCreator string) (c gomatrixserverlib.PowerLevelContent) { c.Defaults() c.Events = map[string]int64{ diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index 45ec6a74..3a4d6909 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -182,6 +182,103 @@ func TestPartialStateJoin(t *testing.T) { }) } }) + + // test that a partial-state join can fall back to other homeservers when re-syncing + // partial state. + t.Run("PartialStateJoinSyncsUsingOtherHomeservers", func(t *testing.T) { + // set up 3 homeservers: hs1, hs2 and complement + deployment := Deploy(t, b.BlueprintFederationTwoLocalOneRemote) + defer deployment.Destroy(t) + alice := deployment.Client(t, "hs1", "@alice:hs1") + charlie := deployment.Client(t, "hs2", "@charlie:hs2") + + // create a public room + roomID := alice.CreateRoom(t, map[string]interface{}{ + "preset": "public_chat", + }) + + // create the complement homeserver + server := federation.NewServer(t, deployment, + federation.HandleKeyRequests(), + federation.HandlePartialStateMakeSendJoinRequests(), + federation.HandleEventRequests(), + federation.HandleTransactionRequests( + func(e *gomatrixserverlib.Event) { + t.Fatalf("Received unexpected PDU: %s", string(e.JSON())) + }, + // hs1 may send us presence when alice syncs + nil, + ), + ) + cancelListener := server.Listen() + defer cancelListener() + + // join complement to the public room + room := server.MustJoinRoom(t, deployment, "hs1", roomID, server.UserID("david")) + + // we expect a /state_ids request from hs2 after it joins the room + // we will respond to the request with garbage + fedStateIdsRequestReceivedWaiter := NewWaiter() + fedStateIdsSendResponseWaiter := NewWaiter() + server.Mux().Handle( + fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", roomID), + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + queryParams := req.URL.Query() + t.Logf("Incoming state_ids request for event %s in room %s", queryParams["event_id"], roomID) + fedStateIdsRequestReceivedWaiter.Finish() + fedStateIdsSendResponseWaiter.Wait(t, 60*time.Second) + t.Logf("Replying to /state_ids request with invalid response") + + w.WriteHeader(200) + + if _, err := w.Write([]byte("{}")); err != nil { + t.Errorf("Error writing to request: %v", err) + } + }), + ).Methods("GET") + + // join charlie on hs2 to the room, via the complement homeserver + charlie.JoinRoom(t, roomID, []string{server.ServerName()}) + + // and let hs1 know that charlie has joined, + // otherwise hs1 will refuse /state_ids requests + member_event := room.CurrentState("m.room.member", charlie.UserID).JSON() + server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{member_event}, nil) + alice.MustSyncUntil(t, client.SyncReq{}, client.SyncJoinedTo(charlie.UserID, roomID)) + + // wait until hs2 starts syncing state + fedStateIdsRequestReceivedWaiter.Waitf(t, 5*time.Second, "Waiting for /state_ids request") + + syncResponseChan := make(chan gjson.Result) + defer close(syncResponseChan) + go func() { + response, _ := charlie.MustSync(t, client.SyncReq{}) + syncResponseChan <- response + }() + + // the client-side requests should still be waiting + select { + case <-syncResponseChan: + t.Fatalf("hs2 sync completed before state resync complete") + default: + } + + // reply to hs2 with a bogus /state_ids response + fedStateIdsSendResponseWaiter.Finish() + + // charlie's /sync request should now complete, with the new room + var syncRes gjson.Result + select { + case <-time.After(1 * time.Second): + t.Fatalf("hs2 /sync request request did not complete") + case syncRes = <-syncResponseChan: + } + + roomRes := syncRes.Get("rooms.join." + client.GjsonEscape(roomID)) + if !roomRes.Exists() { + t.Fatalf("hs2 /sync completed without join to new room\n") + } + }) } // buildLazyLoadingSyncFilter constructs a json-marshalled filter suitable the 'Filter' field of a client.SyncReq