diff --git a/internal/federation/server.go b/internal/federation/server.go index 2bf7555b..510068e9 100644 --- a/internal/federation/server.go +++ b/internal/federation/server.go @@ -155,6 +155,23 @@ func (s *Server) FederationClient(deployment *docker.Deployment) *gomatrixserver return f } +// SendFederationRequest signs and sends an arbitrary federation request from this server. +// +// The requests will be routed according to the deployment map in `deployment`. +func (s *Server) SendFederationRequest(deployment *docker.Deployment, req gomatrixserverlib.FederationRequest, resBody interface{}) error { + if err := req.Sign(gomatrixserverlib.ServerName(s.ServerName), s.KeyID, s.Priv); err != nil { + return err + } + + httpReq, err := req.HTTPRequest() + if err != nil { + return err + } + + httpClient := gomatrixserverlib.NewClient(gomatrixserverlib.WithTransport(&docker.RoundTripper{Deployment: deployment})) + return httpClient.DoRequestAndParseResponse(context.Background(), httpReq, resBody) +} + // MustCreateEvent will create and sign a new latest event for the given room. // It does not insert this event into the room however. See ServerRoom.AddEvent for that. func (s *Server) MustCreateEvent(t *testing.T, room *ServerRoom, ev b.Event) *gomatrixserverlib.Event { diff --git a/tests/federation_room_join_test.go b/tests/federation_room_join_test.go index ddaf3555..037d0d0d 100644 --- a/tests/federation_room_join_test.go +++ b/tests/federation_room_join_test.go @@ -3,6 +3,7 @@ package tests import ( "context" "encoding/json" + "fmt" "net/http" "net/url" "testing" @@ -264,3 +265,126 @@ func TestBannedUserCannotSendJoin(t *testing.T) { membership := must.GetJSONFieldStr(t, stateResp, "membership") must.EqualStr(t, membership, "ban", "membership of charlie") } + +// This test checks that we cannot submit anything via /v1/send_join except a join. +func TestCannotSendNonJoinViaSendJoinV1(t *testing.T) { + testValidationForSendMembershipEndpoint(t, "/_matrix/federation/v1/send_join", "join", nil) +} + +// This test checks that we cannot submit anything via /v2/send_join except a join. +func TestCannotSendNonJoinViaSendJoinV2(t *testing.T) { + testValidationForSendMembershipEndpoint(t, "/_matrix/federation/v2/send_join", "join", nil) +} + +// This test checks that we cannot submit anything via /v1/send_leave except a leave. +func TestCannotSendNonLeaveViaSendLeaveV1(t *testing.T) { + testValidationForSendMembershipEndpoint(t, "/_matrix/federation/v1/send_leave", "leave", nil) +} + +// This test checks that we cannot submit anything via /v2/send_leave except a leave. +func TestCannotSendNonLeaveViaSendLeaveV2(t *testing.T) { + testValidationForSendMembershipEndpoint(t, "/_matrix/federation/v2/send_leave", "leave", nil) +} + +// testValidationForSendMembershipEndpoint attempts to submit a range of events via the given endpoint +// and checks that they are all rejected. +func testValidationForSendMembershipEndpoint(t *testing.T, baseApiPath, expectedMembership string, createRoomOpts map[string]interface{}) { + if createRoomOpts == nil { + createRoomOpts = make(map[string]interface{}) + } + + deployment := Deploy(t, b.BlueprintAlice) + defer deployment.Destroy(t) + + srv := federation.NewServer(t, deployment, + federation.HandleKeyRequests(), + federation.HandleTransactionRequests(nil, nil), + ) + cancel := srv.Listen() + defer cancel() + + // alice creates a room, and charlie joins it + alice := deployment.Client(t, "hs1", "@alice:hs1") + roomId := alice.CreateRoom(t, createRoomOpts) + charlie := srv.UserID("charlie") + room := srv.MustJoinRoom(t, deployment, "hs1", roomId, charlie) + + // a helper function which makes a send_* request to the given path and checks + // that it fails with a 400 error + assertRequestFails := func(t *testing.T, event *gomatrixserverlib.Event) { + path := fmt.Sprintf("%s/%s/%s", + baseApiPath, + url.PathEscape(event.RoomID()), + url.PathEscape(event.EventID()), + ) + t.Logf("PUT %s", path) + req := gomatrixserverlib.NewFederationRequest("PUT", "hs1", path) + if err := req.SetContent(event); err != nil { + t.Errorf("req.SetContent: %v", err) + return + } + + var res map[string]interface{} + err := srv.SendFederationRequest(deployment, req, &res) + if err == nil { + t.Errorf("send request returned 200") + return + } + + httpError, ok := err.(gomatrix.HTTPError) + if !ok { + t.Errorf("not an HTTPError: %v", err) + return + } + + t.Logf("%s returned %d/%s", baseApiPath, httpError.Code, string(httpError.Contents)) + if httpError.Code != 400 { + t.Errorf("expected 400, got %d", httpError.Code) + } + } + + t.Run("regular event", func(t *testing.T) { + event := srv.MustCreateEvent(t, room, b.Event{ + Type: "m.room.message", + Sender: charlie, + Content: map[string]interface{}{"body": "bzz"}, + }) + assertRequestFails(t, event) + }) + t.Run("non-state membership event", func(t *testing.T) { + event := srv.MustCreateEvent(t, room, b.Event{ + Type: "m.room.member", + Sender: charlie, + Content: map[string]interface{}{"body": "bzz"}, + }) + assertRequestFails(t, event) + }) + + // try membership events of various types, other than that expected by + // the endpoint + for _, membershipType := range []string{"join", "leave", "knock", "invite"} { + if membershipType == expectedMembership { + continue + } + event := srv.MustCreateEvent(t, room, b.Event{ + Type: "m.room.member", + Sender: charlie, + StateKey: &charlie, + Content: map[string]interface{}{"membership": membershipType}, + }) + t.Run(membershipType+" event", func(t *testing.T) { + assertRequestFails(t, event) + }) + } + + // right sort of membership, but mismatched state_key + t.Run("event with mismatched state key", func(t *testing.T) { + event := srv.MustCreateEvent(t, room, b.Event{ + Type: "m.room.member", + Sender: charlie, + StateKey: b.Ptr(srv.UserID("doris")), + Content: map[string]interface{}{"membership": expectedMembership}, + }) + assertRequestFails(t, event) + }) +} diff --git a/tests/msc2403_test.go b/tests/msc2403_test.go index 23751cdf..4e0cde4f 100644 --- a/tests/msc2403_test.go +++ b/tests/msc2403_test.go @@ -465,5 +465,13 @@ func publishAndCheckRoomJoinRule(t *testing.T, c *client.CSAPI, roomID, expected if !roomFound { t.Fatalf("Room was not present in public room directory response") } +} +// TestCannotSendNonKnockViaSendKnock checks that we cannot submit anything via /send_knock except a knock +func TestCannotSendNonKnockViaSendKnock(t *testing.T) { + testValidationForSendMembershipEndpoint(t, "/_matrix/federation/v1/send_knock", "knock", + map[string]interface{}{ + "room_version": "7", + }, + ) }