Skip to content
Merged
6 changes: 3 additions & 3 deletions internal/federation/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ func SendJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request

authEvents := room.AuthChainForEvents(stateEvents)

// get servers in room *before* the join event
serversInRoom := room.ServersInRoom()

// insert the join event into the room state
room.AddEvent(event)

// servers in room: just us. TODO(faster_joins): this may not be correct
serversInRoom := []string{s.serverName}

// return state and auth chain
b, err := json.Marshal(gomatrixserverlib.RespSendJoin{
Origin: gomatrixserverlib.ServerName(s.serverName),
Expand Down
59 changes: 50 additions & 9 deletions internal/federation/server_room.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,24 +97,38 @@ func (r *ServerRoom) AuthChain() (chain []*gomatrixserverlib.Event) {
return r.AuthChainForEvents(r.AllCurrentState())
}

// AuthChainForEvents returns all auth events for all events in the given state TODO: recursively
// AuthChainForEvents returns all auth events for all events in the given state
func (r *ServerRoom) AuthChainForEvents(events []*gomatrixserverlib.Event) (chain []*gomatrixserverlib.Event) {
chainMap := make(map[string]bool)
// get all the auth event IDs
for _, ev := range events {

// build a map of all events in the room
// Timeline and State contain different sets of events, so check them both.
eventsByID := map[string]*gomatrixserverlib.Event{}
for _, ev := range r.Timeline {
eventsByID[ev.EventID()] = ev
}
for _, ev := range r.State {
eventsByID[ev.EventID()] = ev
}

// a queue of events whose auth events are to be included in the auth chain
queue := []*gomatrixserverlib.Event{}
queue = append(queue, events...)

// get all the auth events recursively
// we extend the "queue" as we go along
for i := 0; i < len(queue); i++ {
ev := queue[i]
for _, evID := range ev.AuthEventIDs() {
if chainMap[evID] {
continue
}
chainMap[evID] = true
chain = append(chain, eventsByID[evID])
queue = append(queue, eventsByID[evID])
}
}
// find them in the timeline
for _, tev := range r.Timeline {
if chainMap[tev.EventID()] {
chain = append(chain, tev)
}
}

return
}
Comment thread
kegsay marked this conversation as resolved.

Expand All @@ -134,6 +148,33 @@ func (r *ServerRoom) MustHaveMembershipForUser(t *testing.T, userID, wantMembers
}
}

// ServersInRoom gets all servers currently joined to the room
func (r *ServerRoom) ServersInRoom() (servers []string) {
serverSet := make(map[string]struct{})

for _, ev := range r.State {
if ev.Type() != "m.room.member" {
continue
}
membership, err := ev.Membership()
if err != nil || membership != "join" {
continue
}
_, server, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
continue
}

serverSet[string(server)] = struct{}{}
}

for server := range serverSet {
servers = append(servers, server)
}

return
}

func initialPowerLevelsContent(roomCreator string) (c gomatrixserverlib.PowerLevelContent) {
c.Defaults()
c.Events = map[string]int64{
Expand Down
97 changes: 97 additions & 0 deletions tests/federation_room_join_partial_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,103 @@ func TestPartialStateJoin(t *testing.T) {
})
}
})

// test that a partial-state join can fall back to other homeservers when re-syncing
// partial state.
t.Run("PartialStateJoinSyncsUsingOtherHomeservers", func(t *testing.T) {
// set up 3 homeservers: hs1, hs2 and complement
deployment := Deploy(t, b.BlueprintFederationTwoLocalOneRemote)
defer deployment.Destroy(t)
alice := deployment.Client(t, "hs1", "@alice:hs1")
charlie := deployment.Client(t, "hs2", "@charlie:hs2")

// create a public room
roomID := alice.CreateRoom(t, map[string]interface{}{
"preset": "public_chat",
})

// create the complement homeserver
server := federation.NewServer(t, deployment,
federation.HandleKeyRequests(),
federation.HandlePartialStateMakeSendJoinRequests(),
federation.HandleEventRequests(),
federation.HandleTransactionRequests(
func(e *gomatrixserverlib.Event) {
t.Fatalf("Received unexpected PDU: %s", string(e.JSON()))
},
// hs1 may send us presence when alice syncs
nil,
),
)
cancelListener := server.Listen()
defer cancelListener()

// join complement to the public room
room := server.MustJoinRoom(t, deployment, "hs1", roomID, server.UserID("david"))

// we expect a /state_ids request from hs2 after it joins the room
// we will respond to the request with garbage
fedStateIdsRequestReceivedWaiter := NewWaiter()
fedStateIdsSendResponseWaiter := NewWaiter()
server.Mux().Handle(
fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", roomID),
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
queryParams := req.URL.Query()
t.Logf("Incoming state_ids request for event %s in room %s", queryParams["event_id"], roomID)
fedStateIdsRequestReceivedWaiter.Finish()
fedStateIdsSendResponseWaiter.Wait(t, 60*time.Second)
t.Logf("Replying to /state_ids request with invalid response")

w.WriteHeader(200)

if _, err := w.Write([]byte("{}")); err != nil {
t.Errorf("Error writing to request: %v", err)
}
}),
).Methods("GET")

// join charlie on hs2 to the room, via the complement homeserver
charlie.JoinRoom(t, roomID, []string{server.ServerName()})

// and let hs1 know that charlie has joined,
// otherwise hs1 will refuse /state_ids requests
member_event := room.CurrentState("m.room.member", charlie.UserID).JSON()
server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{member_event}, nil)
alice.MustSyncUntil(t, client.SyncReq{}, client.SyncJoinedTo(charlie.UserID, roomID))

// wait until hs2 starts syncing state
fedStateIdsRequestReceivedWaiter.Waitf(t, 5*time.Second, "Waiting for /state_ids request")

syncResponseChan := make(chan gjson.Result)
defer close(syncResponseChan)
go func() {
response, _ := charlie.MustSync(t, client.SyncReq{})
syncResponseChan <- response
}()

// the client-side requests should still be waiting
select {
case <-syncResponseChan:
t.Fatalf("hs2 sync completed before state resync complete")
default:
}

// reply to hs2 with a bogus /state_ids response
fedStateIdsSendResponseWaiter.Finish()

// charlie's /sync request should now complete, with the new room
var syncRes gjson.Result
select {
case <-time.After(1 * time.Second):
t.Fatalf("hs2 /sync request request did not complete")
case syncRes = <-syncResponseChan:
}

roomRes := syncRes.Get("rooms.join." + client.GjsonEscape(roomID))
if !roomRes.Exists() {
t.Fatalf("hs2 /sync completed without join to new room\n")
}
})
}

// buildLazyLoadingSyncFilter constructs a json-marshalled filter suitable the 'Filter' field of a client.SyncReq
Expand Down