Skip to content

Commit 17b0be8

Browse files
Update VerifyHTTPRequest to accept a server name function (#344)
This is needed for matrix-org/dendrite#2829.
1 parent 7c772f1 commit 17b0be8

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

request.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,17 +201,25 @@ func isSafeInHTTPQuotedString(text string) bool { // nolint: gocyclo
201201
// the query parameters, and the JSON content. In particular the version of
202202
// HTTP and the headers aren't protected by the signature.
203203
func VerifyHTTPRequest(
204-
req *http.Request, now time.Time, destination ServerName, keys JSONVerifier,
204+
req *http.Request, now time.Time,
205+
destination ServerName, // the default server name, if none other is given
206+
isLocalServerName func(ServerName) bool, // optional, verify secondary server names
207+
keys JSONVerifier,
205208
) (*FederationRequest, util.JSONResponse) {
206209
request, err := readHTTPRequest(req)
207210
if err != nil {
208211
util.GetLogger(req.Context()).WithError(err).Print("Error parsing HTTP headers")
209212
return nil, util.MessageResponse(400, "Bad Request")
210213
}
211-
if request.fields.Destination != "" && request.fields.Destination != destination {
212-
message := "Unrecognised server name for Destination"
213-
util.GetLogger(req.Context()).WithError(err).Print(message)
214-
return nil, util.MessageResponse(400, message)
214+
if request.fields.Destination != "" {
215+
switch {
216+
case isLocalServerName != nil && !isLocalServerName(request.fields.Destination):
217+
fallthrough
218+
case isLocalServerName == nil && destination != request.fields.Destination:
219+
message := fmt.Sprintf("Unrecognised server name %q for Destination", request.fields.Destination)
220+
util.GetLogger(req.Context()).Warn(message)
221+
return nil, util.MessageResponse(400, message)
222+
}
215223
} else if request.fields.Destination == "" {
216224
request.fields.Destination = destination
217225
}

request_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func TestVerifyGetRequest(t *testing.T) {
7878
t.Fatal(err)
7979
}
8080
request, jsonResp := VerifyHTTPRequest(
81-
hr, time.Unix(1493142432, 96400), "localhost:44033", KeyRing{nil, &testKeyDatabase{}},
81+
hr, time.Unix(1493142432, 96400), "localhost:44033", nil, KeyRing{nil, &testKeyDatabase{}},
8282
)
8383
if request == nil {
8484
t.Fatalf("Wanted non-nil request got nil. (request was %#v, response was %#v)", hr, jsonResp)
@@ -137,7 +137,7 @@ func TestVerifyPutRequest(t *testing.T) {
137137
t.Fatal(err)
138138
}
139139
request, jsonResp := VerifyHTTPRequest(
140-
hr, time.Unix(1493142432, 96400), "localhost:44033", KeyRing{nil, &testKeyDatabase{}},
140+
hr, time.Unix(1493142432, 96400), "localhost:44033", nil, KeyRing{nil, &testKeyDatabase{}},
141141
)
142142
if request == nil {
143143
t.Fatalf("Wanted non-nil request got nil. (request was %#v, response was %#v)", hr, jsonResp)

0 commit comments

Comments
 (0)