@@ -201,17 +201,25 @@ func isSafeInHTTPQuotedString(text string) bool { // nolint: gocyclo
201201// the query parameters, and the JSON content. In particular the version of
202202// HTTP and the headers aren't protected by the signature.
203203func VerifyHTTPRequest (
204- req * http.Request , now time.Time , destination ServerName , keys JSONVerifier ,
204+ req * http.Request , now time.Time ,
205+ destination ServerName , // the default server name, if none other is given
206+ isLocalServerName func (ServerName ) bool , // optional, verify secondary server names
207+ keys JSONVerifier ,
205208) (* FederationRequest , util.JSONResponse ) {
206209 request , err := readHTTPRequest (req )
207210 if err != nil {
208211 util .GetLogger (req .Context ()).WithError (err ).Print ("Error parsing HTTP headers" )
209212 return nil , util .MessageResponse (400 , "Bad Request" )
210213 }
211- if request .fields .Destination != "" && request .fields .Destination != destination {
212- message := "Unrecognised server name for Destination"
213- util .GetLogger (req .Context ()).WithError (err ).Print (message )
214- return nil , util .MessageResponse (400 , message )
214+ if request .fields .Destination != "" {
215+ switch {
216+ case isLocalServerName != nil && ! isLocalServerName (request .fields .Destination ):
217+ fallthrough
218+ case isLocalServerName == nil && destination != request .fields .Destination :
219+ message := fmt .Sprintf ("Unrecognised server name %q for Destination" , request .fields .Destination )
220+ util .GetLogger (req .Context ()).Warn (message )
221+ return nil , util .MessageResponse (400 , message )
222+ }
215223 } else if request .fields .Destination == "" {
216224 request .fields .Destination = destination
217225 }
0 commit comments