Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions tests/federation_room_join_partial_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,12 +589,23 @@ func beginPartialStateJoin(t *testing.T, deployment *docker.Deployment, joiningU
},
}))

// register a handler for /state_ids requests, which finishes fedStateIdsRequestReceivedWaiter, then
// register a handler for /state_ids requests for the most recent event,
// which finishes fedStateIdsRequestReceivedWaiter, then
// waits for fedStateIdsSendResponseWaiter and sends a reply
handleStateIdsRequests(t, result.Server, result.ServerRoom, result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter)
lastEvent := result.ServerRoom.Timeline[len(result.ServerRoom.Timeline)-1]
currentState := result.ServerRoom.AllCurrentState()
handleStateIdsRequests(
t, result.Server, result.ServerRoom,
lastEvent.EventID(), currentState,
result.fedStateIdsRequestReceivedWaiter, result.fedStateIdsSendResponseWaiter,
)

// a handler for /state requests, which sends a sensible response
handleStateRequests(t, result.Server, result.ServerRoom, nil, nil)
handleStateRequests(
t, result.Server, result.ServerRoom,
lastEvent.EventID(), currentState,
nil, nil,
)

// have joiningUser join the room by room ID.
joiningUser.JoinRoom(t, result.ServerRoom.RoomID, []string{result.Server.ServerName()})
Expand Down Expand Up @@ -630,16 +641,20 @@ func (psj *partialStateJoinResult) FinishStateRequest() {
psj.fedStateIdsSendResponseWaiter.Finish()
}

// handleStateIdsRequests registers a handler for /state_ids requests for serverRoom.
// handleStateIdsRequests registers a handler for /state_ids requests for 'eventID'
//
// the returned state is as passed in 'roomState'
//
// if requestReceivedWaiter is not nil, it will be Finish()ed when the request arrives.
// if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response.
func handleStateIdsRequests(
t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom,
eventID string, roomState []*gomatrixserverlib.Event,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter,
) {
srv.Mux().Handle(
srv.Mux().NewRoute().Methods("GET").Path(
fmt.Sprintf("/_matrix/federation/v1/state_ids/%s", serverRoom.RoomID),
).Queries("event_id", eventID).Handler(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
queryParams := req.URL.Query()
t.Logf("Incoming state_ids request for event %s in room %s", queryParams["event_id"], serverRoom.RoomID)
Expand All @@ -652,8 +667,8 @@ func handleStateIdsRequests(
t.Logf("Replying to /state_ids request")

res := gomatrixserverlib.RespStateIDs{
AuthEventIDs: eventIDsFromEvents(serverRoom.AuthChain()),
StateEventIDs: eventIDsFromEvents(serverRoom.AllCurrentState()),
AuthEventIDs: eventIDsFromEvents(serverRoom.AuthChainForEvents(roomState)),
StateEventIDs: eventIDsFromEvents(roomState),
}
w.WriteHeader(200)
jsonb, _ := json.Marshal(res)
Expand All @@ -662,19 +677,24 @@ func handleStateIdsRequests(
t.Errorf("Error writing to request: %v", err)
}
}),
).Methods("GET")
)
t.Logf("Registered state_ids handler for event %s", eventID)
}

// makeStateHandler returns a handler for /state requests for serverRoom.
// makeStateHandler returns a handler for /state requests for 'eventID'
//
// the returned state is as passed in 'roomState'
//
// if requestReceivedWaiter is not nil, it will be Finish()ed when the request arrives.
// if sendResponseWaiter is not nil, we will Wait() for it to finish before sending the response.
func handleStateRequests(
t *testing.T, srv *federation.Server, serverRoom *federation.ServerRoom,
eventID string, roomState []*gomatrixserverlib.Event,
requestReceivedWaiter *Waiter, sendResponseWaiter *Waiter,
) {
srv.Mux().Handle(
srv.Mux().NewRoute().Methods("GET").Path(
fmt.Sprintf("/_matrix/federation/v1/state/%s", serverRoom.RoomID),
).Queries("event_id", eventID).Handler(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
queryParams := req.URL.Query()
t.Logf("Incoming state request for event %s in room %s", queryParams["event_id"], serverRoom.RoomID)
Expand All @@ -685,8 +705,8 @@ func handleStateRequests(
sendResponseWaiter.Waitf(t, 60*time.Second, "Waiting for /state request")
}
res := gomatrixserverlib.RespState{
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AuthChain()),
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AllCurrentState()),
AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(serverRoom.AuthChainForEvents(roomState)),
StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(roomState),
}
w.WriteHeader(200)
jsonb, _ := json.Marshal(res)
Expand All @@ -695,7 +715,7 @@ func handleStateRequests(
t.Errorf("Error writing to request: %v", err)
}
}),
).Methods("GET")
)
}

func eventIDsFromEvents(he []*gomatrixserverlib.Event) []string {
Expand Down