diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f806..1952c270ab 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,4 +11,5 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" + LoginTypeTerms = "m.login.terms" ) diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index 1fc1c0c016..1c2278392f 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -29,6 +29,13 @@ type MatrixError struct { Err string `json:"error"` } +// ConsentError is an error returned to users, who didn't accept the +// TOS of this server yet. +type ConsentError struct { + MatrixError + ConsentURI string `json:"consent_uri"` +} + func (e MatrixError) Error() string { return fmt.Sprintf("%s: %s", e.ErrCode, e.Err) } @@ -207,3 +214,15 @@ func NotTrusted(serverName string) *MatrixError { Err: fmt.Sprintf("Untrusted server '%s'", serverName), } } + +// ConsentNotGiven is an error returned to users, who didn't accept the +// TOS of this server yet. +func ConsentNotGiven(consentURI string, msg string) *ConsentError { + return &ConsentError{ + MatrixError: MatrixError{ + ErrCode: "M_CONSENT_NOT_GIVEN", + Err: msg, + }, + ConsentURI: consentURI, + } +} diff --git a/clientapi/routing/consent_tracking.go b/clientapi/routing/consent_tracking.go new file mode 100644 index 0000000000..3136756470 --- /dev/null +++ b/clientapi/routing/consent_tracking.go @@ -0,0 +1,219 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// The data used to populate the /consent request +type constentTemplateData struct { + UserID string + Version string + UserHMAC string + HasConsented bool + ReadOnly bool +} + +func writeHeaderAndText(w http.ResponseWriter, statusCode int) { + w.WriteHeader(statusCode) + _, _ = w.Write([]byte(http.StatusText(statusCode))) +} + +func consent(writer http.ResponseWriter, req *http.Request, userAPI userapi.UserConsentPolicyAPI, cfg *config.ClientAPI) { + consentCfg := cfg.Matrix.UserConsentOptions + + // The data used to populate the /consent request + data := constentTemplateData{ + UserID: req.FormValue("u"), + Version: req.FormValue("v"), + UserHMAC: req.FormValue("h"), + } + + switch req.Method { + case http.MethodGet: + // display the privacy policy without a form + data.ReadOnly = data.UserID == "" || data.UserHMAC == "" || data.Version == "" + + // let's see if the user already consented to the current version + if !data.ReadOnly { + if ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret); err != nil || !ok { + writeHeaderAndText(writer, http.StatusForbidden) + return + } + + res := &userapi.QueryPolicyVersionResponse{} + localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID) + if err != nil { + logrus.WithError(err).Error("unable to split username") + writeHeaderAndText(writer, http.StatusInternalServerError) + return + } + if err = userAPI.QueryPolicyVersion(req.Context(), &userapi.QueryPolicyVersionRequest{ + Localpart: localpart, + }, res); err != nil { + logrus.WithError(err).Error("unable query policy version") + writeHeaderAndText(writer, http.StatusInternalServerError) + return + } + data.HasConsented = res.PolicyVersion == consentCfg.Version + } + + err := consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data) + if err != nil { + logrus.WithError(err).Error("unable to execute consent template") + writeHeaderAndText(writer, http.StatusInternalServerError) + return + } + case http.MethodPost: + ok, err := validHMAC(data.UserID, data.UserHMAC, consentCfg.FormSecret) + if err != nil || !ok { + if !ok { + writeHeaderAndText(writer, http.StatusForbidden) + return + } + writeHeaderAndText(writer, http.StatusInternalServerError) + return + } + localpart, _, err := gomatrixserverlib.SplitID('@', data.UserID) + if err != nil { + logrus.WithError(err).Error("unable to split username") + writeHeaderAndText(writer, http.StatusInternalServerError) + return + } + if err = userAPI.PerformUpdatePolicyVersion( + req.Context(), + &userapi.UpdatePolicyVersionRequest{ + PolicyVersion: data.Version, + Localpart: localpart, + }, + &userapi.UpdatePolicyVersionResponse{}, + ); err != nil { + writeHeaderAndText(writer, http.StatusInternalServerError) + return + } + // display the privacy policy without a form + data.ReadOnly = false + data.HasConsented = true + + err = consentCfg.Templates.ExecuteTemplate(writer, consentCfg.Version+".gohtml", data) + if err != nil { + logrus.WithError(err).Error("unable to print consent template") + writeHeaderAndText(writer, http.StatusInternalServerError) + } + } +} + +func sendServerNoticeForConsent(userAPI userapi.ClientUserAPI, rsAPI api.ClientRoomserverAPI, + cfgNotices *config.ServerNotices, + cfgClient *config.ClientAPI, + senderDevice *userapi.Device, + asAPI appserviceAPI.AppServiceInternalAPI, +) { + res := &userapi.QueryOutdatedPolicyResponse{} + if err := userAPI.QueryOutdatedPolicy(context.Background(), &userapi.QueryOutdatedPolicyRequest{ + PolicyVersion: cfgClient.Matrix.UserConsentOptions.Version, + }, res); err != nil { + logrus.WithError(err).Error("unable to fetch users with outdated consent policy") + return + } + + var ( + consentOpts = cfgClient.Matrix.UserConsentOptions + data = make(map[string]string) + err error + sentMessages int + ) + + if len(res.UserLocalparts) == 0 { + return + } + + logrus.WithField("count", len(res.UserLocalparts)).Infof("Sending server notice to users who have not yet accepted the policy") + + for _, localpart := range res.UserLocalparts { + if localpart == cfgClient.Matrix.ServerNotices.LocalPart { + continue + } + userID := fmt.Sprintf("@%s:%s", localpart, cfgClient.Matrix.ServerName) + data["ConsentURL"], err = consentOpts.ConsentURL(userID) + if err != nil { + logrus.WithError(err).WithField("userID", userID).Error("unable to construct consentURI") + continue + } + msgBody := &bytes.Buffer{} + + if err = consentOpts.TextTemplates.ExecuteTemplate(msgBody, "serverNoticeTemplate", data); err != nil { + logrus.WithError(err).WithField("userID", userID).Error("unable to execute serverNoticeTemplate") + continue + } + + req := sendServerNoticeRequest{ + UserID: userID, + Content: struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + MsgType: consentOpts.ServerNoticeContent.MsgType, + Body: msgBody.String(), + }, + } + _, err = sendServerNotice(context.Background(), req, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, nil, nil, nil) + if err != nil { + logrus.WithError(err).WithField("userID", userID).Error("failed to send server notice for consent to user") + continue + } + sentMessages++ + res := &userapi.UpdatePolicyVersionResponse{} + if err = userAPI.PerformUpdatePolicyVersion(context.Background(), &userapi.UpdatePolicyVersionRequest{ + PolicyVersion: consentOpts.Version, + Localpart: userID, + ServerNoticeUpdate: true, + }, res); err != nil { + logrus.WithError(err).WithField("userID", userID).Error("failed to update policy version") + continue + } + } + if sentMessages > 0 { + logrus.Infof("Sent messages to %d users", sentMessages) + } +} + +func validHMAC(username, userHMAC, secret string) (bool, error) { + mac := hmac.New(sha256.New, []byte(secret)) + _, err := mac.Write([]byte(username)) + if err != nil { + return false, err + } + expectedMAC := mac.Sum(nil) + decoded, err := hex.DecodeString(userHMAC) + if err != nil { + return false, err + } + return hmac.Equal(decoded, expectedMAC), nil +} diff --git a/clientapi/routing/consent_tracking_test.go b/clientapi/routing/consent_tracking_test.go new file mode 100644 index 0000000000..75a684eba4 --- /dev/null +++ b/clientapi/routing/consent_tracking_test.go @@ -0,0 +1,236 @@ +package routing + +import ( + "context" + "fmt" + "html/template" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matrix-org/dendrite/setup/config" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +func Test_validHMAC(t *testing.T) { + type args struct { + username string + userHMAC string + secret string + } + tests := []struct { + name string + args args + want bool + wantErr bool + }{ + { + name: "invalid hmac", + args: args{}, + wantErr: false, + want: false, + }, + // $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld' + //(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e + // + { + name: "valid hmac", + args: args{ + username: "@alice:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + secret: "helloWorld", + }, + want: true, + }, + { + name: "invalid hmac", + args: args{ + username: "@bob:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + secret: "helloWorld", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := validHMAC(tt.args.username, tt.args.userHMAC, tt.args.secret) + if (err != nil) != tt.wantErr { + t.Errorf("validHMAC() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("validHMAC() got = %v, want %v", got, tt.want) + } + }) + } +} + +type dummyAPI struct { + usersConsent map[string]string +} + +func (d dummyAPI) QueryOutdatedPolicy(ctx context.Context, req *userapi.QueryOutdatedPolicyRequest, res *userapi.QueryOutdatedPolicyResponse) error { + return nil +} + +func (d dummyAPI) PerformUpdatePolicyVersion(ctx context.Context, req *userapi.UpdatePolicyVersionRequest, res *userapi.UpdatePolicyVersionResponse) error { + d.usersConsent[req.Localpart] = req.PolicyVersion + return nil +} + +func (d dummyAPI) QueryPolicyVersion(ctx context.Context, req *userapi.QueryPolicyVersionRequest, res *userapi.QueryPolicyVersionResponse) error { + res.PolicyVersion = "v2.0" + return nil +} + +const dummyTemplate = ` +{{ if .HasConsented }} +Consent given. +{{ else }} +WithoutForm + {{ if not .ReadOnly }} + With Form. + {{ end }} +{{ end }}` + +func Test_consent(t *testing.T) { + type args struct { + username string + userHMAC string + version string + method string + } + tests := []struct { + name string + args args + wantRespCode int + wantBodyContains string + }{ + { + name: "not a userID, valid hmac", + args: args{ + username: "notAuserID", + userHMAC: "7578bbface5ebb250a63935cebc05ca12060f58ebdbd271ecbc25e25a3da154d", + version: "v1.0", + method: http.MethodGet, + }, + wantRespCode: http.StatusInternalServerError, + }, + + // $ echo -n '@alice:localhost' | openssl sha256 -hmac 'helloWorld' + //(stdin)= 121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e + // + { + name: "valid hmac for alice GET, not consented", + args: args{ + username: "@alice:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + version: "v1.0", + method: http.MethodGet, + }, + wantRespCode: http.StatusOK, + wantBodyContains: "With form", + }, + { + name: "alice consents successfully", + args: args{ + username: "@alice:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + version: "v1.0", + method: http.MethodPost, + }, + wantRespCode: http.StatusOK, + wantBodyContains: "Consent given", + }, + { + name: "valid hmac for alice GET, new version", + args: args{ + username: "@alice:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + version: "v2.0", + method: http.MethodGet, + }, + wantRespCode: http.StatusOK, + wantBodyContains: "With form", + }, + { + name: "no hmac provided for alice, read only should be displayed", + args: args{ + username: "@alice:localhost", + userHMAC: "", + version: "v1.0", + method: http.MethodGet, + }, + wantRespCode: http.StatusOK, + wantBodyContains: "WithoutForm", + }, + { + name: "alice trying to get bobs status is forbidden", + args: args{ + username: "@bob:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + version: "v1.0", + method: http.MethodGet, + }, + wantRespCode: http.StatusForbidden, + wantBodyContains: "forbidden", + }, + { + name: "alice trying to consent for bob is forbidden", + args: args{ + username: "@bob:localhost", + userHMAC: "121c9bab767ed87a3136db0c3002144dfe414720aa328d235199082e4757541e", + version: "v1.0", + method: http.MethodPost, + }, + wantRespCode: http.StatusForbidden, + wantBodyContains: "forbidden", + }, + } + + userAPI := dummyAPI{ + usersConsent: map[string]string{}, + } + consentTemplates := template.Must(template.New("v1.0.gohtml").Parse(dummyTemplate)) + consentTemplates = template.Must(consentTemplates.New("v2.0.gohtml").Parse(dummyTemplate)) + userconsentOpts := config.UserConsentOptions{ + FormSecret: "helloWorld", + Version: "v1.0", + Templates: consentTemplates, + BaseURL: "http://localhost", + } + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + UserConsentOptions: userconsentOpts, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := fmt.Sprintf("%s/consent?u=%s&v=%s&h=%s", + userconsentOpts.BaseURL, tt.args.username, tt.args.version, tt.args.userHMAC, + ) + + req := httptest.NewRequest(tt.args.method, url, nil) + w := httptest.NewRecorder() + + consent(w, req, userAPI, cfg) + + resp := w.Result() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read response body: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != tt.wantRespCode { + t.Fatalf("expected http %d, got %d", tt.wantRespCode, resp.StatusCode) + } + + if !strings.Contains(strings.ToLower(string(body)), strings.ToLower(tt.wantBodyContains)) { + t.Fatalf("expected body to contain %s, but got %s", tt.wantBodyContains, string(body)) + } + }) + } +} diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index eba4920c69..53f833e356 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -31,13 +31,12 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/tidwall/gjson" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -721,6 +720,8 @@ func handleRegistrationFlow( } switch r.Auth.Type { + case authtypes.LoginTypeTerms: + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeTerms) case authtypes.LoginTypeRecaptcha: // Check given captcha response resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) @@ -788,11 +789,16 @@ func handleApplicationServiceRegistration( return *err } + policyVersion := "" + if cfg.Matrix.UserConsentOptions.Enabled { + policyVersion = cfg.Matrix.UserConsentOptions.Version + } + // If no error, application service was successfully validated. // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, + req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, policyVersion, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -809,9 +815,13 @@ func checkAndCompleteFlow( userAPI userapi.ClientUserAPI, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { + policyVersion := "" + if cfg.Matrix.UserConsentOptions.Enabled { + policyVersion = cfg.Matrix.UserConsentOptions.Version + } // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, + req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, policyVersion, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } @@ -834,7 +844,7 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username, password, appserviceID, ipAddr, userAgent, sessionID string, + username, password, appserviceID, ipAddr, userAgent, sessionID, policyVersion string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, @@ -861,11 +871,12 @@ func completeRegistration( } var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ - AppServiceID: appserviceID, - Localpart: username, - Password: password, - AccountType: accType, - OnConflict: userapi.ConflictAbort, + AppServiceID: appserviceID, + Localpart: username, + Password: password, + AccountType: accType, + OnConflict: userapi.ConflictAbort, + PolicyVersion: policyVersion, }, &accRes) if err != nil { if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists @@ -1073,5 +1084,5 @@ func handleSharedSecretRegistration(userAPI userapi.ClientUserAPI, sr *SharedSec if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 0392895698..3810c9f1b3 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "net/http" @@ -93,7 +94,7 @@ func PutTag( } tagContent.Tags[tag] = properties - if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { + if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } @@ -145,7 +146,7 @@ func DeleteTag( } } - if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { + if err = saveTagData(req.Context(), userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } @@ -191,7 +192,7 @@ func obtainSavedTags( // saveTagData saves the provided tag data into the database func saveTagData( - req *http.Request, + context context.Context, userID string, roomID string, userAPI api.ClientUserAPI, @@ -208,5 +209,5 @@ func saveTagData( AccountData: json.RawMessage(newTagData), } dataRes := api.InputAccountDataResponse{} - return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) + return userAPI.InputAccountData(context, &dataReq, &dataRes) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 94becf465a..5216800f85 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -127,9 +127,13 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) // server notifications + var ( + serverNotificationSender *userapi.Device + err error + ) if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg) + serverNotificationSender, err = getSenderDevice(context.Background(), userAPI, cfg) if err != nil { logrus.WithError(err).Fatal("unable to get account for sending sending server notices") } @@ -177,13 +181,27 @@ func Setup( // using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching. // Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing! v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() - unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() + // NOTSPEC: consent tracking + if cfg.Matrix.UserConsentOptions.Enabled { + if !cfg.Matrix.ServerNotices.Enabled { + logrus.Warnf("Consent tracking is enabled, but server notes are not. No server notice will be sent to users") + } else { + // start a new go routine to send messages about consent + go sendServerNoticeForConsent(userAPI, rsAPI, &cfg.Matrix.ServerNotices, cfg, serverNotificationSender, asAPI) + } + publicAPIMux.HandleFunc("/consent", func(writer http.ResponseWriter, request *http.Request) { + consent(writer, request, userAPI, cfg) + }).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + } + + consentRequiredCheck := httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI) + v3mux.Handle("/createRoom", httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateRoom(req, device, cfg, userAPI, rsAPI, asAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -213,7 +231,7 @@ func Setup( return PeekRoomByIDOrAlias( req, device, rsAPI, vars["roomIDOrAlias"], ) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) } v3mux.Handle("/joined_rooms", @@ -258,7 +276,7 @@ func Setup( return UnpeekRoomByID( req, device, rsAPI, vars["roomID"], ) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/ban", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -267,7 +285,7 @@ func Setup( return util.ErrorResponse(err) } return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -279,7 +297,7 @@ func Setup( return util.ErrorResponse(err) } return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/kick", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -288,7 +306,7 @@ func Setup( return util.ErrorResponse(err) } return SendKick(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unban", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -297,7 +315,7 @@ func Setup( return util.ErrorResponse(err) } return SendUnban(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -317,7 +335,7 @@ func Setup( txnID := vars["txnID"] return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, nil, cfg, rsAPI, transactionsCache) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -326,7 +344,7 @@ func Setup( return util.ErrorResponse(err) } return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -335,7 +353,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) - })).Methods(http.MethodGet, http.MethodOptions) + }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -343,7 +361,7 @@ func Setup( return util.ErrorResponse(err) } return GetAliases(req, rsAPI, device, vars["roomID"]) - })).Methods(http.MethodGet, http.MethodOptions) + }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -354,7 +372,7 @@ func Setup( eventType := strings.TrimSuffix(vars["type"], "/") eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -363,7 +381,7 @@ func Setup( } eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, consentRequiredCheck)).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -374,7 +392,7 @@ func Setup( emptyString := "" eventType := strings.TrimSuffix(vars["eventType"], "/") return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", @@ -385,7 +403,7 @@ func Setup( } stateKey := vars["stateKey"] return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { @@ -487,7 +505,7 @@ func Setup( return util.ErrorResponse(err) } return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -497,7 +515,7 @@ func Setup( } txnID := vars["txnId"] return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/sendToDevice/{eventType}/{txnID}", @@ -508,7 +526,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) // This is only here because sytest refers to /unstable for this endpoint @@ -522,7 +540,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/account/whoami", @@ -531,7 +549,7 @@ func Setup( return *r } return Whoami(req, device) - }), + }, consentRequiredCheck), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", @@ -738,7 +756,7 @@ func Setup( return util.ErrorResponse(err) } return SetAvatarURL(req, userAPI, device, vars["userID"], cfg, rsAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method @@ -763,7 +781,7 @@ func Setup( return util.ErrorResponse(err) } return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method @@ -771,19 +789,19 @@ func Setup( v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, userAPI, device) - }), + }, consentRequiredCheck), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, userAPI, device, cfg) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/account/3pid/delete", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Forget3PID(req, userAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", @@ -798,7 +816,7 @@ func Setup( return *r } return RequestTurnServer(req, device, cfg) - }), + }, consentRequiredCheck), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocols", @@ -868,7 +886,7 @@ func Setup( return util.ErrorResponse(err) } return GetAdminWhois(req, userAPI, device, vars["userID"]) - }), + }, consentRequiredCheck), ).Methods(http.MethodGet) v3mux.Handle("/user/{userID}/openid/request_token", @@ -881,7 +899,7 @@ func Setup( return util.ErrorResponse(err) } return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/user_directory/search", @@ -907,7 +925,7 @@ func Setup( postContent.SearchString, postContent.Limit, ) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/members", @@ -953,7 +971,7 @@ func Setup( return util.ErrorResponse(err) } return SendForget(req, device, vars["roomID"], rsAPI) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/upgrade", @@ -1065,7 +1083,7 @@ func Setup( return util.ErrorResponse(err) } return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) - }), + }, consentRequiredCheck), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", @@ -1075,7 +1093,7 @@ func Setup( return util.ErrorResponse(err) } return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) - }), + }, consentRequiredCheck), ).Methods(http.MethodDelete, http.MethodOptions) v3mux.Handle("/capabilities", @@ -1095,11 +1113,11 @@ func Setup( return util.ErrorResponse(err) } return KeyBackupVersion(req, userAPI, device, vars["version"]) - }) + }, consentRequiredCheck) getLatestBackupKeysVersion := httputil.MakeAuthAPI("get_latest_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return KeyBackupVersion(req, userAPI, device, "") - }) + }, consentRequiredCheck) putBackupKeysVersion := httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -1107,7 +1125,7 @@ func Setup( return util.ErrorResponse(err) } return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"]) - }) + }, consentRequiredCheck) deleteBackupKeysVersion := httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -1119,7 +1137,7 @@ func Setup( postNewBackupKeysVersion := httputil.MakeAuthAPI("post_new_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateKeyBackupVersion(req, userAPI, device) - }) + }, consentRequiredCheck) v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) @@ -1150,7 +1168,7 @@ func Setup( return *resErr } return UploadBackupKeys(req, userAPI, device, version, &reqBody) - }) + }, consentRequiredCheck) // Single room bulk session putBackupKeysRoom := httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -1182,7 +1200,7 @@ func Setup( } reqBody.Rooms[roomID] = body return UploadBackupKeys(req, userAPI, device, version, &reqBody) - }) + }, consentRequiredCheck) // Single room, single session putBackupKeysRoomSession := httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -1215,7 +1233,7 @@ func Setup( } keyReq.Rooms[roomID].Sessions[sessionID] = reqBody return UploadBackupKeys(req, userAPI, device, version, &keyReq) - }) + }, consentRequiredCheck) v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) @@ -1229,7 +1247,7 @@ func Setup( getBackupKeys := httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "") - }) + }, consentRequiredCheck) getBackupKeysRoom := httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -1245,7 +1263,7 @@ func Setup( return util.ErrorResponse(err) } return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) - }) + }, consentRequiredCheck) v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) @@ -1261,11 +1279,11 @@ func Setup( postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg) - }) + }, consentRequiredCheck) postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceSignatures(req, keyAPI, device) - }) + }, consentRequiredCheck) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) @@ -1277,12 +1295,12 @@ func Setup( v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -1305,7 +1323,7 @@ func Setup( } return SetReceipt(req, syncProducer, device, vars["roomId"], vars["receiptType"], vars["eventId"]) - }), + }, consentRequiredCheck), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/presence/{userId}/status", httputil.MakeAuthAPI("set_presence", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 9edeed2f7d..40efe03a25 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -84,74 +84,66 @@ func SendServerNotice( if resErr != nil { return *resErr } + res, _ := sendServerNotice(ctx, r, rsAPI, cfgNotices, cfgClient, senderDevice, asAPI, userAPI, txnID, device, txnCache) + return res +} + +func sendServerNotice( + ctx context.Context, + serverNoticeRequest sendServerNoticeRequest, + rsAPI api.ClientRoomserverAPI, + cfgNotices *config.ServerNotices, + cfgClient *config.ClientAPI, + senderDevice *userapi.Device, + asAPI appserviceAPI.AppServiceInternalAPI, + userAPI userapi.ClientUserAPI, + txnID *string, + device *userapi.Device, + txnCache *transactions.Cache, +) (util.JSONResponse, error) { // check that all required fields are set - if !r.valid() { + if !serverNoticeRequest.valid() { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("Invalid request"), - } + }, fmt.Errorf("Invalid JSON") } - // get rooms for specified user - allUserRooms := []string{} - userRooms := api.QueryRoomsForUserResponse{} - // Get rooms the user is either joined, invited or has left. - for _, membership := range []string{"join", "invite", "leave"} { - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: r.UserID, - WantMembership: membership, - }, &userRooms); err != nil { - return util.ErrorResponse(err) - } - allUserRooms = append(allUserRooms, userRooms.RoomIDs...) - } - - // get rooms of the sender - senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) - senderRooms := api.QueryRoomsForUserResponse{} - if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ - UserID: senderUserID, - WantMembership: "join", - }, &senderRooms); err != nil { - return util.ErrorResponse(err) - } - - // check if we have rooms in common - commonRooms := []string{} - for _, userRoomID := range allUserRooms { - for _, senderRoomID := range senderRooms.RoomIDs { - if userRoomID == senderRoomID { - commonRooms = append(commonRooms, senderRoomID) - } - } + qryServerNoticeRoom := &userapi.QueryServerNoticeRoomResponse{} + localpart, _, err := gomatrixserverlib.SplitID('@', serverNoticeRequest.UserID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid request"), + }, err } - - if len(commonRooms) > 1 { - return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms))) + err = userAPI.SelectServerNoticeRoomID(ctx, &userapi.QueryServerNoticeRoomRequest{Localpart: localpart}, qryServerNoticeRoom) + if err != nil { + return util.ErrorResponse(err), err } - var ( - roomID string - roomVersion = version.DefaultRoomVersion() - ) + senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) + roomID := qryServerNoticeRoom.RoomID + roomVersion := version.DefaultRoomVersion() // create a new room for the user - if len(commonRooms) == 0 { + if qryServerNoticeRoom.RoomID == "" { + var pl, cc []byte powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) - powerLevelContent.Users[r.UserID] = -10 // taken from Synapse - pl, err := json.Marshal(powerLevelContent) + powerLevelContent.Users[serverNoticeRequest.UserID] = -10 // taken from Synapse + pl, err = json.Marshal(powerLevelContent) if err != nil { - return util.ErrorResponse(err) + return util.ErrorResponse(err), err } createContent := map[string]interface{}{} createContent["m.federate"] = false - cc, err := json.Marshal(createContent) + cc, err = json.Marshal(createContent) if err != nil { - return util.ErrorResponse(err) + return util.ErrorResponse(err), err } crReq := createRoomRequest{ - Invite: []string{r.UserID}, + Invite: []string{serverNoticeRequest.UserID}, Name: cfgNotices.RoomName, Visibility: "private", Preset: presetPrivateChat, @@ -166,36 +158,40 @@ func SendServerNotice( switch data := roomRes.JSON.(type) { case createRoomResponse: roomID = data.RoomID - + res := &userapi.UpdateServerNoticeRoomResponse{} + err = userAPI.UpdateServerNoticeRoomID(ctx, &userapi.UpdateServerNoticeRoomRequest{RoomID: roomID, Localpart: localpart}, res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("UpdateServerNoticeRoomID failed") + return jsonerror.InternalServerError(), err + } // tag the room, so we can later check if the user tries to reject an invite serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{ "m.server_notice": { Order: 1.0, }, }} - if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { + if err = saveTagData(ctx, serverNoticeRequest.UserID, roomID, userAPI, serverAlertTag); err != nil { util.GetLogger(ctx).WithError(err).Error("saveTagData failed") - return jsonerror.InternalServerError() + return jsonerror.InternalServerError(), err } default: // if we didn't get a createRoomResponse, we probably received an error, so return that. - return roomRes + return roomRes, fmt.Errorf("Unable to create room") } + } else { - // we've found a room in common, check the membership - roomID = commonRooms[0] - membershipRes := api.QueryMembershipForUserResponse{} - err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes) + res := &api.QueryMembershipForUserResponse{} + err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: serverNoticeRequest.UserID, RoomID: roomID}, res) if err != nil { - util.GetLogger(ctx).WithError(err).Error("unable to query membership for user") - return jsonerror.InternalServerError() + return util.ErrorResponse(err), err } - if !membershipRes.IsInRoom { - // re-invite the user - res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + // re-invite the user + if res.Membership != gomatrixserverlib.Join { + var inviteRes util.JSONResponse + inviteRes, err = sendInvite(ctx, userAPI, senderDevice, roomID, serverNoticeRequest.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) if err != nil { - return res + return inviteRes, err } } } @@ -203,13 +199,13 @@ func SendServerNotice( startedGeneratingEvent := time.Now() request := map[string]interface{}{ - "body": r.Content.Body, - "msgtype": r.Content.MsgType, + "body": serverNoticeRequest.Content.Body, + "msgtype": serverNoticeRequest.Content.MsgType, } e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) if resErr != nil { logrus.Errorf("failed to send message: %+v", resErr) - return *resErr + return *resErr, fmt.Errorf("Unable to send event") } timeToGenerateEvent := time.Since(startedGeneratingEvent) @@ -224,7 +220,7 @@ func SendServerNotice( // pass the new event to the roomserver and receive the correct event ID // event ID in case of duplicate transaction is discarded startedSubmittingEvent := time.Now() - if err := api.SendEvents( + if err = api.SendEvents( ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{ @@ -236,7 +232,7 @@ func SendServerNotice( false, ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + return jsonerror.InternalServerError(), err } util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": e.EventID(), @@ -259,7 +255,7 @@ func SendServerNotice( sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds())) sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds())) - return res + return res, nil } func (r sendServerNoticeRequest) valid() (ok bool) { diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 7f6d5105e0..37103db163 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -146,7 +146,12 @@ func main() { logrus.Fatalln("Username is already in use.") } - _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) + policyVersion := "" + if cfg.Global.UserConsentOptions.Enabled { + policyVersion = cfg.Global.UserConsentOptions.Version + } + + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", policyVersion, accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index 4b67aaa94f..dfbdab5196 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -71,6 +71,35 @@ global: # appear in user clients. room_name: "Server Alerts" + # Consent tracking configuration + user_consent: + # If the user consent tracking is enabled or not + enabled: false + # The base URL this homeserver will serve clients on, e.g. https://matrix.org + base_url: http://localhost + # Randomly generated string (e.g. by using "pwgen -sy 32") to be used to calculate the HMAC + form_secret: "superSecretRandomlyGeneratedSecret" + # Require consent when user registers for the first time + require_at_registration: false + # The name to be shown to the user + policy_name: "Privacy policy" + # The directory to search for templates + template_dir: "./templates/privacy" + # The version of the policy. When loading templates, ".gohtml" template is added as a suffix + # e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to "1.0" + version: "1.0" + # Send a consent message to guest users + send_server_notice_to_guest: false + # Default message to send to users + server_notice_content: + msg_type: "m.text" + body: >- + Please give your consent to the privacy policy at {{ .ConsentURL }}. + # The error message to display if the user hasn't given their consent yet + block_events_error: >- + You can't send any messages until you consent to the privacy policy at + {{ .ConsentURL }}. + # Configuration for NATS JetStream jetstream: # A list of NATS Server addresses to connect to. If none are specified, an diff --git a/docs/templates/privacy/1.0.gohtml b/docs/templates/privacy/1.0.gohtml new file mode 100644 index 0000000000..6e866cff2e --- /dev/null +++ b/docs/templates/privacy/1.0.gohtml @@ -0,0 +1,26 @@ + + + + Privacy policy + + +{{ if .HasConsented }} +

+ You have already given your consent. +

+{{ else }} +

+ Please give your consent to keep using this homeserver. +

+ {{ if not .ReadOnly }} + +
+ + + + +
+ {{ end }} +{{ end }} + + \ No newline at end of file diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index aba50ae4d2..55cd92620a 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -15,6 +15,8 @@ package httputil import ( + "bytes" + "context" "fmt" "io" "net/http" @@ -25,9 +27,12 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -41,10 +46,24 @@ type BasicAuth struct { Password string `yaml:"password"` } +// AuthAPICheck is an option to MakeAuthAPI to add additional checks (e.g. WithConsentCheck) to verify +// the user is allowed to do specific things. +type AuthAPICheck func(ctx context.Context, device *userapi.Device) *util.JSONResponse + +// WithConsentCheck checks that a user has given his consent. +func WithConsentCheck(options config.UserConsentOptions, api userapi.QueryPolicyVersionAPI) AuthAPICheck { + return func(ctx context.Context, device *userapi.Device) *util.JSONResponse { + if !options.Enabled { + return nil + } + return checkConsent(ctx, device.UserID, api, options) + } +} + // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( metricsName string, userAPI userapi.QueryAcccessTokenAPI, - f func(*http.Request, *userapi.Device) util.JSONResponse, + f func(*http.Request, *userapi.Device) util.JSONResponse, checks ...AuthAPICheck, ) http.Handler { h := func(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) @@ -72,6 +91,14 @@ func MakeAuthAPI( } }() + // apply additional checks, if any + for _, opt := range checks { + resp := opt(req.Context(), device) + if resp != nil { + return *resp + } + } + jsonRes := f(req, device) // do not log 4xx as errors as they are client fails, not server fails if hub != nil && jsonRes.Code >= 500 { @@ -83,6 +110,53 @@ func MakeAuthAPI( return MakeExternalAPI(metricsName, h) } +func checkConsent(ctx context.Context, userID string, userAPI userapi.QueryPolicyVersionAPI, userConsentCfg config.UserConsentOptions) *util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil + } + // check which version of the policy the user accepted + res := &userapi.QueryPolicyVersionResponse{} + err = userAPI.QueryPolicyVersion(ctx, &userapi.QueryPolicyVersionRequest{ + Localpart: localpart, + }, res) + if err != nil { + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("unable to get policy version"), + } + } + + // user hasn't accepted any policy, block access. + if userConsentCfg.Version != res.PolicyVersion { + uri, err := userConsentCfg.ConsentURL(userID) + if err != nil { + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("unable to get consent URL"), + } + } + msg := &bytes.Buffer{} + c := struct { + ConsentURL string + }{ + ConsentURL: uri, + } + if err = userConsentCfg.TextTemplates.ExecuteTemplate(msg, "blockEventsError", c); err != nil { + logrus.Infof("error consent message: %+v", err) + return &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("unable to execute template"), + } + } + return &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.ConsentNotGiven(uri, msg.String()), + } + } + return nil +} + // MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler. // This is used for APIs that are called from the internet. func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 76f07415b4..8822f0e5b5 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -59,15 +59,12 @@ func Setup( PathToResult: map[string]*types.ThumbnailGenerationResult{}, } - uploadHandler := httputil.MakeAuthAPI( - "upload", userAPI, - func(req *http.Request, dev *userapi.Device) util.JSONResponse { - if r := rateLimits.Limit(req); r != nil { - return *r - } - return Upload(req, cfg, dev, db, activeThumbnailGeneration) - }, - ) + uploadHandler := httputil.MakeAuthAPI("upload", userAPI, func(req *http.Request, dev *userapi.Device) util.JSONResponse { + if r := rateLimits.Limit(req); r != nil { + return *r + } + return Upload(req, cfg, dev, db, activeThumbnailGeneration) + }) configHandler := httputil.MakeAuthAPI("config", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { diff --git a/setup/base/base.go b/setup/base/base.go index 5cbd7da9c3..740fa7c9bc 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -268,6 +268,7 @@ func (b *BaseDendrite) Close() error { func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) { if dbProperties.ConnectionString != "" || b == nil { // Open a new database connection using the supplied config. + logrus.Infof("Open a new database connection using the supplied config.: %+v", dbProperties.ConnectionString) db, err := sqlutil.Open(dbProperties, writer) return db, writer, err } diff --git a/setup/config/config.go b/setup/config/config.go index 9b9000a623..ff96e756d5 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -265,6 +265,21 @@ func loadConfig( return &c, nil } +type Terms struct { + Policies Policies `json:"policies"` +} +type En struct { + Name string `json:"name"` + URL string `json:"url"` +} +type PrivacyPolicy struct { + En En `json:"en"` + Version string `json:"version"` +} +type Policies struct { + PrivacyPolicy PrivacyPolicy `json:"privacy_policy"` +} + // Derive generates data that is derived from various values provided in // the config file. func (config *Dendrite) Derive() error { @@ -275,13 +290,39 @@ func (config *Dendrite) Derive() error { // TODO: Add email auth type // TODO: Add MSISDN auth type + if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration { + uri := config.Global.UserConsentOptions.BaseURL + "/_matrix/client/consent?v=" + config.Global.UserConsentOptions.Version + config.Derived.Registration.Params[authtypes.LoginTypeTerms] = Terms{ + Policies: Policies{ + PrivacyPolicy: PrivacyPolicy{ + En: En{ + Name: config.Global.UserConsentOptions.PolicyName, + URL: uri, + }, + Version: config.Global.UserConsentOptions.Version, + }, + }, + } + } if config.ClientAPI.RecaptchaEnabled { config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}) - } else { - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) + } + + if config.Derived.Registration.Flows == nil { + config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, authtypes.Flow{ + Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}, + }) + } + + // prepend each flow with LoginTypeTerms or LoginTypeRecaptcha + for i, flow := range config.Derived.Registration.Flows { + if config.Global.UserConsentOptions.Enabled && config.Global.UserConsentOptions.RequireAtRegistration { + flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeTerms}, flow.Stages...) + } + if config.ClientAPI.RecaptchaEnabled { + flow.Stages = append([]authtypes.LoginType{authtypes.LoginTypeRecaptcha}, flow.Stages...) + } + config.Derived.Registration.Flows[i] = flow } // Load application service configuration files diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 9d4c1485ed..df58f9e959 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -1,7 +1,15 @@ package config import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "html/template" "math/rand" + "net/url" + "path/filepath" + textTemplate "text/template" "time" "github.com/matrix-org/gomatrixserverlib" @@ -71,6 +79,9 @@ type Global struct { // ServerNotices configuration used for sending server notices ServerNotices ServerNotices `yaml:"server_notices"` + // Consent tracking options + UserConsentOptions UserConsentOptions `yaml:"user_consent"` + // ReportStats configures opt-in anonymous stats reporting. ReportStats ReportStats `yaml:"report_stats"` } @@ -88,6 +99,7 @@ func (c *Global) Defaults(generate bool) { c.Metrics.Defaults(generate) c.DNSCache.Defaults() c.Sentry.Defaults() + c.UserConsentOptions.Defaults() c.ServerNotices.Defaults(generate) c.ReportStats.Defaults() } @@ -100,6 +112,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Metrics.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith) + c.UserConsentOptions.Verify(configErrs, isMonolith) c.ServerNotices.Verify(configErrs, isMonolith) c.ReportStats.Verify(configErrs, isMonolith) } @@ -261,6 +274,102 @@ func (c *DNSCacheOptions) Verify(configErrs *ConfigErrors, isMonolith bool) { checkPositive(configErrs, "cache_lifetime", int64(c.CacheLifetime)) } +// Consent tracking configuration +// If either require_at_registration or send_server_notice_to_guest are true, consent +// messages will be sent to the users. +type UserConsentOptions struct { + // If consent tracking is enabled or not + Enabled bool `yaml:"enabled"` + // Randomly generated string to be used to calculate the HMAC + FormSecret string `yaml:"form_secret"` + // Require consent when user registers for the first time + RequireAtRegistration bool `yaml:"require_at_registration"` + // The name to be shown to the user + PolicyName string `yaml:"policy_name"` + // The directory to search for *.gohtml templates + TemplateDir string `yaml:"template_dir"` + // The version of the policy. When loading templates, ".gohtml" template is added as a suffix + // e.g: ${template_dir}/1.0.gohtml needs to exist, if this is set to 1.0 + Version string `yaml:"version"` + // Send a consent message to guest users + SendServerNoticeToGuest bool `yaml:"send_server_notice_to_guest"` + // Default message to send to users + ServerNoticeContent struct { + MsgType string `yaml:"msg_type"` + Body string `yaml:"body"` + } `yaml:"server_notice_content"` + // The error message to display if the user hasn't given their consent yet + BlockEventsError string `yaml:"block_events_error"` + // All loaded templates + Templates *template.Template `yaml:"-"` + TextTemplates *textTemplate.Template `yaml:"-"` + // The base URL this homeserver will serve clients on, e.g. https://matrix.org + BaseURL string `yaml:"base_url"` +} + +func (c *UserConsentOptions) Defaults() { + c.Enabled = false + c.RequireAtRegistration = false + c.SendServerNoticeToGuest = false + c.PolicyName = "Privacy Policy" + c.Version = "1.0" + c.TemplateDir = "./templates/privacy" +} + +func (c *UserConsentOptions) Verify(configErrors *ConfigErrors, isMonolith bool) { + if !c.Enabled { + return + } + + checkNotEmpty(configErrors, "template_dir", c.TemplateDir) + checkNotEmpty(configErrors, "version", c.Version) + checkNotEmpty(configErrors, "policy_name", c.PolicyName) + checkNotEmpty(configErrors, "form_secret", c.FormSecret) + checkNotEmpty(configErrors, "base_url", c.BaseURL) + if len(*configErrors) > 0 { + return + } + + p, err := filepath.Abs(c.TemplateDir) + if err != nil { + configErrors.Add("unable to get template directory") + return + } + + c.TextTemplates = textTemplate.Must(textTemplate.New("blockEventsError").Parse(c.BlockEventsError)) + c.TextTemplates = textTemplate.Must(c.TextTemplates.New("serverNoticeTemplate").Parse(c.ServerNoticeContent.Body)) + + // Read all defined *.gohtml templates + t, err := template.ParseGlob(filepath.Join(p, "*.gohtml")) + if err != nil || t == nil { + configErrors.Add(fmt.Sprintf("unable to read consent templates: %+v", err)) + return + } + c.Templates = t + // Verify we've got a template for the defined version + versionTemplate := c.Templates.Lookup(c.Version + ".gohtml") + if versionTemplate == nil { + configErrors.Add(fmt.Sprintf("unable to load defined '%s' policy template", c.Version)) + } +} + +// ConsentURL constructs the URL shown to users to accept the TOS +func (c *UserConsentOptions) ConsentURL(userID string) (string, error) { + mac := hmac.New(sha256.New, []byte(c.FormSecret)) + _, err := mac.Write([]byte(userID)) + if err != nil { + return "", err + } + userMAC := hex.EncodeToString(mac.Sum(nil)) + + params := url.Values{} + params.Add("u", userID) + params.Add("h", userMAC) + params.Add("v", c.Version) + + return fmt.Sprintf("%s/_matrix/client/consent?%s", c.BaseURL, params.Encode()), nil +} + // PresenceOptions defines possible configurations for presence events. type PresenceOptions struct { // Whether inbound presence events are allowed diff --git a/setup/config/config_global_test.go b/setup/config/config_global_test.go new file mode 100644 index 0000000000..ab61514ade --- /dev/null +++ b/setup/config/config_global_test.go @@ -0,0 +1,115 @@ +package config + +import ( + "testing" +) + +func TestUserConsentOptions_Verify(t *testing.T) { + type args struct { + configErrors *ConfigErrors + isMonolith bool + } + tests := []struct { + name string + fields UserConsentOptions + args args + wantErr bool + }{ + { + name: "template dir not set", + fields: UserConsentOptions{ + RequireAtRegistration: true, + }, + args: struct { + configErrors *ConfigErrors + isMonolith bool + }{configErrors: &ConfigErrors{}, isMonolith: true}, + wantErr: true, + }, + { + name: "template dir set", + fields: UserConsentOptions{ + RequireAtRegistration: true, + TemplateDir: "testdata/privacy", + }, + args: struct { + configErrors *ConfigErrors + isMonolith bool + }{configErrors: &ConfigErrors{}, isMonolith: true}, + wantErr: true, + }, + { + name: "policy name not set", + fields: UserConsentOptions{ + RequireAtRegistration: true, + TemplateDir: "testdata/privacy", + }, + args: struct { + configErrors *ConfigErrors + isMonolith bool + }{configErrors: &ConfigErrors{}, isMonolith: true}, + wantErr: true, + }, + { + name: "policy name set", + fields: UserConsentOptions{ + RequireAtRegistration: true, + TemplateDir: "testdata/privacy", + PolicyName: "Privacy policy", + }, + args: struct { + configErrors *ConfigErrors + isMonolith bool + }{configErrors: &ConfigErrors{}, isMonolith: true}, + wantErr: true, + }, + { + name: "version not set", + fields: UserConsentOptions{ + RequireAtRegistration: true, + TemplateDir: "testdata/privacy", + }, + args: struct { + configErrors *ConfigErrors + isMonolith bool + }{configErrors: &ConfigErrors{}, isMonolith: true}, + wantErr: true, + }, + { + name: "everyhing required set", + fields: UserConsentOptions{ + RequireAtRegistration: true, + TemplateDir: "./testdata/privacy", + Version: "1.0", + PolicyName: "Privacy policy", + FormSecret: "helloWorld", + BaseURL: "http://localhost", + }, + args: struct { + configErrors *ConfigErrors + isMonolith bool + }{configErrors: &ConfigErrors{}, isMonolith: true}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &UserConsentOptions{ + Enabled: true, + BaseURL: tt.fields.BaseURL, + FormSecret: tt.fields.FormSecret, + RequireAtRegistration: tt.fields.RequireAtRegistration, + PolicyName: tt.fields.PolicyName, + Version: tt.fields.Version, + TemplateDir: tt.fields.TemplateDir, + SendServerNoticeToGuest: tt.fields.SendServerNoticeToGuest, + ServerNoticeContent: tt.fields.ServerNoticeContent, + BlockEventsError: tt.fields.BlockEventsError, + } + c.Verify(tt.args.configErrors, tt.args.isMonolith) + if !tt.wantErr && len(*tt.args.configErrors) > 0 { + t.Errorf("expected no errors, got '%+v'", tt.args.configErrors) + } + }) + } +} diff --git a/setup/config/testdata/privacy/1.0.gohtml b/setup/config/testdata/privacy/1.0.gohtml new file mode 100644 index 0000000000..6e866cff2e --- /dev/null +++ b/setup/config/testdata/privacy/1.0.gohtml @@ -0,0 +1,26 @@ + + + + Privacy policy + + +{{ if .HasConsented }} +

+ You have already given your consent. +

+{{ else }} +

+ Please give your consent to keep using this homeserver. +

+ {{ if not .ReadOnly }} + +
+ + + + +
+ {{ end }} +{{ end }} + + \ No newline at end of file diff --git a/setup/flags.go b/setup/flags.go index a9dac61a17..1cd4f567eb 100644 --- a/setup/flags.go +++ b/setup/flags.go @@ -44,7 +44,6 @@ func ParseFlags(monolith bool) *config.Dendrite { } cfg, err := config.Load(*configPath, monolith) - if err != nil { logrus.Fatalf("Invalid config file: %s", err) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 6bc495d8df..1960107029 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -93,6 +93,6 @@ func Setup( vars["roomId"], vars["eventId"], lazyLoadCache, ) - }), + }, httputil.WithConsentCheck(cfg.Matrix.UserConsentOptions, userAPI)), ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/userapi/api/api.go b/userapi/api/api.go index df9408acbf..0fd06b41d0 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -66,6 +66,7 @@ type FederationUserAPI interface { // api functions required by the sync api type SyncUserAPI interface { QueryAcccessTokenAPI + QueryPolicyVersionAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -78,6 +79,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI + UserConsentPolicyAPI QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -106,6 +108,18 @@ type ClientUserAPI interface { QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error + SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) (err error) + UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) (err error) +} + +type UserConsentPolicyAPI interface { + QueryPolicyVersionAPI + QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error + PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error +} + +type QueryPolicyVersionAPI interface { + QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error } // custom api functions required by pinecone / p2p demos @@ -318,12 +332,12 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. - - AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. - Password string // optional: if missing then this account will be a passwordless account - OnConflict Conflict + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + PolicyVersion string // optional: the privacy policy this account has accepted + AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. + Password string // optional: if missing then this account will be a passwordless account + OnConflict Conflict } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -412,6 +426,53 @@ type QueryOpenIDTokenResponse struct { ExpiresAtMS int64 } +// QueryPolicyVersionRequest is the request for QueryPolicyVersionRequest +type QueryPolicyVersionRequest struct { + Localpart string +} + +// QueryPolicyVersionResponse is the response for QueryPolicyVersionRequest +type QueryPolicyVersionResponse struct { + PolicyVersion string +} + +// QueryOutdatedPolicyRequest is the request for QueryOutdatedPolicyRequest +type QueryOutdatedPolicyRequest struct { + PolicyVersion string +} + +// QueryOutdatedPolicyResponse is the response for QueryOutdatedPolicyRequest +type QueryOutdatedPolicyResponse struct { + UserLocalparts []string +} + +// UpdatePolicyVersionRequest is the request for UpdatePolicyVersionRequest +type UpdatePolicyVersionRequest struct { + PolicyVersion, Localpart string + ServerNoticeUpdate bool +} + +// UpdatePolicyVersionResponse is the response for UpdatePolicyVersionRequest +type UpdatePolicyVersionResponse struct{} + +// QueryServerNoticeRoomRequest is the request for QueryServerNoticeRoomRequest +type QueryServerNoticeRoomRequest struct { + Localpart string +} + +// QueryServerNoticeRoomResponse is the response for QueryServerNoticeRoomRequest +type QueryServerNoticeRoomResponse struct { + RoomID string +} + +// UpdateServerNoticeRoomRequest is the request for UpdateServerNoticeRoomRequest +type UpdateServerNoticeRoomRequest struct { + Localpart, RoomID string +} + +// UpdateServerNoticeRoomResponse is the response for UpdateServerNoticeRoomRequest +type UpdateServerNoticeRoomResponse struct{} + // Device represents a client's device (mobile, web, etc) type Device struct { ID string diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 6d8d280072..03b8700497 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -203,6 +203,36 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryPolicyVersion(ctx context.Context, req *QueryPolicyVersionRequest, res *QueryPolicyVersionResponse) error { + err := t.Impl.QueryPolicyVersion(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPolicyVersion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) QueryOutdatedPolicy(ctx context.Context, req *QueryOutdatedPolicyRequest, res *QueryOutdatedPolicyResponse) error { + err := t.Impl.QueryOutdatedPolicy(ctx, req, res) + util.GetLogger(ctx).Infof("QueryOutdatedPolicy req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformUpdatePolicyVersion(ctx context.Context, req *UpdatePolicyVersionRequest, res *UpdatePolicyVersionResponse) error { + err := t.Impl.PerformUpdatePolicyVersion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformUpdatePolicyVersion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) SelectServerNoticeRoomID(ctx context.Context, req *QueryServerNoticeRoomRequest, res *QueryServerNoticeRoomResponse) error { + err := t.Impl.SelectServerNoticeRoomID(ctx, req, res) + util.GetLogger(ctx).Infof("SelectServerNoticeRoomID req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) UpdateServerNoticeRoomID(ctx context.Context, req *UpdateServerNoticeRoomRequest, res *UpdateServerNoticeRoomResponse) error { + err := t.Impl.UpdateServerNoticeRoomID(ctx, req, res) + util.GetLogger(ctx).Infof("UpdateServerNoticeRoomID req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 9d2f63c72b..2896fd345f 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -66,7 +66,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) + acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -833,3 +833,60 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re } const pushRulesAccountDataType = "m.push_rules" + +func (a *UserInternalAPI) QueryPolicyVersion( + ctx context.Context, + req *api.QueryPolicyVersionRequest, + res *api.QueryPolicyVersionResponse, +) error { + var err error + res.PolicyVersion, err = a.DB.GetPrivacyPolicy(ctx, req.Localpart) + if err != nil { + return err + } + + return nil +} + +func (a *UserInternalAPI) QueryOutdatedPolicy( + ctx context.Context, + req *api.QueryOutdatedPolicyRequest, + res *api.QueryOutdatedPolicyResponse, +) error { + var err error + res.UserLocalparts, err = a.DB.GetOutdatedPolicy(ctx, req.PolicyVersion) + if err != nil { + return err + } + + return nil +} + +func (a *UserInternalAPI) PerformUpdatePolicyVersion( + ctx context.Context, + req *api.UpdatePolicyVersionRequest, + res *api.UpdatePolicyVersionResponse, +) error { + return a.DB.UpdatePolicyVersion(ctx, req.PolicyVersion, req.Localpart, req.ServerNoticeUpdate) +} + +func (a *UserInternalAPI) SelectServerNoticeRoomID( + ctx context.Context, + req *api.QueryServerNoticeRoomRequest, + res *api.QueryServerNoticeRoomResponse, +) (err error) { + roomID, err := a.DB.SelectServerNoticeRoomID(ctx, req.Localpart) + if err != nil { + return err + } + res.RoomID = roomID + return nil +} + +func (a *UserInternalAPI) UpdateServerNoticeRoomID( + ctx context.Context, + req *api.UpdateServerNoticeRoomRequest, + res *api.UpdateServerNoticeRoomResponse, +) (err error) { + return a.DB.UpdateServerNoticeRoomID(ctx, req.Localpart, req.RoomID) +} diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 23c335cf22..d938a70cb2 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -44,6 +44,8 @@ const ( PerformSetDisplayNamePath = "/userapi/performSetDisplayName" PerformForgetThreePIDPath = "/userapi/performForgetThreePID" PerformSaveThreePIDAssociationPath = "/userapi/performSaveThreePIDAssociation" + PerformUpdatePolicyVersionPath = "/userapi/performUpdatePolicyVersion" + PerformUpdateServerNoticeRoomPath = "/userapi/performUpdateServerNoticeRoom" QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" @@ -61,6 +63,9 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryPolicyVersionPath = "/userapi/queryPolicyVersion" + QueryOutdatedPolicyUsersPath = "/userapi/queryOutdatedPolicy" + QueryServerNoticeRoomPath = "/userapi/queryServerNoticeRoom" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -391,3 +396,43 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context apiURL := h.apiURL + PerformSaveThreePIDAssociationPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QueryOutdatedPolicy(ctx context.Context, req *api.QueryOutdatedPolicyRequest, res *api.QueryOutdatedPolicyResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOutdatedPolicy") + defer span.Finish() + + apiURL := h.apiURL + QueryOutdatedPolicyUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformUpdatePolicyVersion(ctx context.Context, req *api.UpdatePolicyVersionRequest, res *api.UpdatePolicyVersionResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUpdatePolicyVersion") + defer span.Finish() + + apiURL := h.apiURL + PerformUpdatePolicyVersionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) SelectServerNoticeRoomID(ctx context.Context, req *api.QueryServerNoticeRoomRequest, res *api.QueryServerNoticeRoomResponse) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "SelectServerNoticeRoomID") + defer span.Finish() + + apiURL := h.apiURL + QueryServerNoticeRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryPolicyVersion(ctx context.Context, req *api.QueryPolicyVersionRequest, res *api.QueryPolicyVersionResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPolicyVersion") + defer span.Finish() + + apiURL := h.apiURL + QueryPolicyVersionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) UpdateServerNoticeRoomID(ctx context.Context, req *api.UpdateServerNoticeRoomRequest, res *api.UpdateServerNoticeRoomResponse) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "UpdateServerNoticeRoomID") + defer span.Finish() + + apiURL := h.apiURL + PerformUpdateServerNoticeRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ad532b901a..805d052597 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -457,4 +457,74 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} }), ) + internalAPIMux.Handle(QueryPolicyVersionPath, + httputil.MakeInternalAPI("queryPolicyVersion", func(req *http.Request) util.JSONResponse { + request := api.QueryPolicyVersionRequest{} + response := api.QueryPolicyVersionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.QueryPolicyVersion(req.Context(), &request, &response) + if err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryOutdatedPolicyUsersPath, + httputil.MakeInternalAPI("queryOutdatedPolicyUsers", func(req *http.Request) util.JSONResponse { + request := api.QueryOutdatedPolicyRequest{} + response := api.QueryOutdatedPolicyResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.QueryOutdatedPolicy(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformUpdatePolicyVersionPath, + httputil.MakeInternalAPI("performUpdatePolicyVersionPath", func(req *http.Request) util.JSONResponse { + request := api.UpdatePolicyVersionRequest{} + response := api.UpdatePolicyVersionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.PerformUpdatePolicyVersion(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryServerNoticeRoomPath, + httputil.MakeInternalAPI("queryServerNoticeRoom", func(req *http.Request) util.JSONResponse { + request := api.QueryServerNoticeRoomRequest{} + response := api.QueryServerNoticeRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.SelectServerNoticeRoomID(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformUpdateServerNoticeRoomPath, + httputil.MakeInternalAPI("performUpdateServerNoticeRoom", func(req *http.Request) util.JSONResponse { + request := api.UpdateServerNoticeRoomRequest{} + response := api.UpdateServerNoticeRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.UpdateServerNoticeRoomID(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index f7cd1810ac..fff18eb284 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -36,7 +36,7 @@ type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, policyVersion string, accountType api.AccountType) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetNewNumericLocalpart(ctx context.Context) (int64, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) @@ -126,9 +126,18 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } +type ConsentTracking interface { + GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) + GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) + UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) error + SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) + UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) +} + type Database interface { Account AccountData + ConsentTracking Device KeyBackup LoginToken diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index f86812f17b..a7a0ae3fa6 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/userutil" @@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- If the account is currently active is_deactivated BOOLEAN DEFAULT FALSE, -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) - account_type SMALLINT NOT NULL + account_type SMALLINT NOT NULL, + -- The policy version this user has accepted + policy_version TEXT, + -- The policy version the user received from the server notices room + policy_version_sent TEXT, + server_notice_room_id TEXT -- TODO: -- upgraded_ts, devices, any email reset stuff? ); ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -67,14 +73,38 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" +const selectPrivacyPolicySQL = "" + + "SELECT policy_version FROM account_accounts WHERE localpart = $1" + +const batchSelectPrivacyPolicySQL = "" + + "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $1)" + +const updatePolicyVersionSQL = "" + + "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" + +const updatePolicyVersionServerNoticeSQL = "" + + "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" + +const selectServerNoticeRoomSQL = "" + + "SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1" + +const updateServerNoticeRoomSQL = "" + + "UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2" + type accountsStatements struct { - insertAccountStmt *sql.Stmt - updatePasswordStmt *sql.Stmt - deactivateAccountStmt *sql.Stmt - selectAccountByLocalpartStmt *sql.Stmt - selectPasswordHashStmt *sql.Stmt - selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt + deactivateAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + selectPrivacyPolicyStmt *sql.Stmt + batchSelectPrivacyPolicyStmt *sql.Stmt + updatePolicyVersionStmt *sql.Stmt + updatePolicyVersionServerNoticeStmt *sql.Stmt + selectServerNoticeRoomStmt *sql.Stmt + updateServerNoticeRoomStmt *sql.Stmt + serverName gomatrixserverlib.ServerName } func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { @@ -92,6 +122,12 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, + {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, + {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, + {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, + {&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL}, + {&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL}, }.Prepare(db) } @@ -99,16 +135,16 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if accountType != api.AccountTypeAppService { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion) } if err != nil { return nil, err @@ -178,3 +214,71 @@ func (s *accountsStatements) SelectNewNumericLocalpart( err = stmt.QueryRowContext(ctx).Scan(&id) return id + 1, err } + +// selectPrivacyPolicy gets the current privacy policy a specific user accepted +func (s *accountsStatements) SelectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, localPart string, +) (policy string, err error) { + var policyNull sql.NullString + stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt) + err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull) + return policyNull.String, err +} + +// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version +func (s *accountsStatements) BatchSelectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, policyVersion string, +) (userIDs []string, err error) { + stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt) + rows, err := stmt.QueryContext(ctx, policyVersion) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return userIDs, err + } + userIDs = append(userIDs, userID) + } + return userIDs, rows.Err() +} + +// updatePolicyVersion sets the policy_version for a specific user +func (s *accountsStatements) UpdatePolicyVersion( + ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool, +) (err error) { + stmt := s.updatePolicyVersionStmt + if serverNotice { + stmt = s.updatePolicyVersionServerNoticeStmt + } + stmt = sqlutil.TxStmt(txn, stmt) + _, err = stmt.ExecContext(ctx, policyVersion, localpart) + return err +} + +// SelectServerNoticeRoomID queries the server notice room ID. +func (s *accountsStatements) SelectServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart string, +) (roomID string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt) + + roomIDNull := sql.NullString{} + row := stmt.QueryRowContext(ctx, localpart) + err = row.Scan(&roomIDNull) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // roomIDNull.String is either the roomID or an empty string + return roomIDNull.String, nil +} + +// UpdateServerNoticeRoomID sets the server notice room ID. +func (s *accountsStatements) UpdateServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart, roomID string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt) + _, err = stmt.ExecContext(ctx, roomID, localpart) + return +} diff --git a/userapi/storage/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go index 32d3235bef..c387599754 100644 --- a/userapi/storage/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go @@ -12,6 +12,7 @@ import ( func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpAddAccountType, DownAddAccountType) + goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/postgres/deltas/2022043014375800_add_policy_version.go b/userapi/storage/postgres/deltas/2022043014375800_add_policy_version.go new file mode 100644 index 0000000000..1638fb4feb --- /dev/null +++ b/userapi/storage/postgres/deltas/2022043014375800_add_policy_version.go @@ -0,0 +1,45 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadAddPolicyVersion(m *sqlutil.Migrations) { + m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) +} + +func UpAddPolicyVersion(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS policy_version_sent TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS server_notice_room_id TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + + return nil +} + +func DownAddPolicyVersion(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index b9afb5a56b..6427560031 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -46,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) + deltas.LoadAddPolicyVersion(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 0cf713dac7..a0531d857e 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -125,7 +126,7 @@ func (d *Database) SetPassword( // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType, ) (acc *api.Account, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // For guest accounts, we create a new numeric local part @@ -139,7 +140,7 @@ func (d *Database) CreateAccount( plaintextPassword = "" appserviceID = "" } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion, accountType) return err }) return @@ -148,7 +149,7 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -160,7 +161,8 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion, accountType); err != nil { + logrus.WithError(err).Error("d.Accounts.InsertAccount error") return nil, sqlutil.ErrUserExists } if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { @@ -763,3 +765,42 @@ func (d *Database) RemovePushers( func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) { return d.Stats.UserStatistics(ctx, nil) } + +// GetPrivacyPolicy returns the accepted privacy policy version, if any. +func (d *Database) GetPrivacyPolicy(ctx context.Context, localpart string) (policyVersion string, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + policyVersion, err = d.Accounts.SelectPrivacyPolicy(ctx, txn, localpart) + return err + }) + return +} + +// GetOutdatedPolicy queries all users which didn't accept the current policy version +func (d *Database) GetOutdatedPolicy(ctx context.Context, policyVersion string) (userIDs []string, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + userIDs, err = d.Accounts.BatchSelectPrivacyPolicy(ctx, txn, policyVersion) + return err + }) + return +} + +// UpdatePolicyVersion sets the accepted policy_version for a user. +func (d *Database) UpdatePolicyVersion(ctx context.Context, policyVersion, localpart string, serverNotice bool) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdatePolicyVersion(ctx, txn, policyVersion, localpart, serverNotice) + }) + return +} + +// SelectServerNoticeRoomID returns the server notice room, if one is set. +func (d *Database) SelectServerNoticeRoomID(ctx context.Context, localpart string) (roomID string, err error) { + return d.Accounts.SelectServerNoticeRoomID(ctx, nil, localpart) +} + +// UpdateServerNoticeRoomID updates the server notice room +func (d *Database) UpdateServerNoticeRoomID(ctx context.Context, localpart, roomID string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdateServerNoticeRoomID(ctx, txn, localpart, roomID) + }) + return +} diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 6c5fe30719..7a700a3067 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/userutil" @@ -43,14 +44,19 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- If the account is currently active is_deactivated BOOLEAN DEFAULT 0, -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) - account_type INTEGER NOT NULL + account_type INTEGER NOT NULL, + -- The policy version this user has accepted + policy_version TEXT, + -- The policy version the user received from the server notices room + policy_version_sent TEXT, + server_notice_room_id TEXT -- TODO: -- upgraded_ts, devices, any email reset stuff? ); ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type, policy_version) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -67,15 +73,39 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" +const selectPrivacyPolicySQL = "" + + "SELECT policy_version FROM account_accounts WHERE localpart = $1" + +const batchSelectPrivacyPolicySQL = "" + + "SELECT localpart FROM account_accounts WHERE (policy_version IS NULL OR policy_version <> $1) AND (policy_version_sent IS NULL OR policy_version_sent <> $2)" + +const updatePolicyVersionSQL = "" + + "UPDATE account_accounts SET policy_version = $1 WHERE localpart = $2" + +const updatePolicyVersionServerNoticeSQL = "" + + "UPDATE account_accounts SET policy_version_sent = $1 WHERE localpart = $2" + +const selectServerNoticeRoomSQL = "" + + "SELECT server_notice_room_id FROM account_accounts WHERE localpart = $1" + +const updateServerNoticeRoomSQL = "" + + "UPDATE account_accounts SET server_notice_room_id = $1 WHERE localpart = $2" + type accountsStatements struct { - db *sql.DB - insertAccountStmt *sql.Stmt - updatePasswordStmt *sql.Stmt - deactivateAccountStmt *sql.Stmt - selectAccountByLocalpartStmt *sql.Stmt - selectPasswordHashStmt *sql.Stmt - selectNewNumericLocalpartStmt *sql.Stmt - serverName gomatrixserverlib.ServerName + db *sql.DB + insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt + deactivateAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + selectPrivacyPolicyStmt *sql.Stmt + batchSelectPrivacyPolicyStmt *sql.Stmt + updatePolicyVersionStmt *sql.Stmt + updatePolicyVersionServerNoticeStmt *sql.Stmt + selectServerNoticeRoomStmt *sql.Stmt + updateServerNoticeRoomStmt *sql.Stmt + serverName gomatrixserverlib.ServerName } func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { @@ -94,6 +124,12 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, {&s.selectPasswordHashStmt, selectPasswordHashSQL}, {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + {&s.selectPrivacyPolicyStmt, selectPrivacyPolicySQL}, + {&s.batchSelectPrivacyPolicyStmt, batchSelectPrivacyPolicySQL}, + {&s.updatePolicyVersionStmt, updatePolicyVersionSQL}, + {&s.updatePolicyVersionServerNoticeStmt, updatePolicyVersionServerNoticeSQL}, + {&s.selectServerNoticeRoomStmt, selectServerNoticeRoomSQL}, + {&s.updateServerNoticeRoomStmt, updateServerNoticeRoomSQL}, }.Prepare(db) } @@ -101,16 +137,16 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - stmt := s.insertAccountStmt + stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType, policyVersion) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType, policyVersion) } if err != nil { return nil, err @@ -183,3 +219,72 @@ func (s *accountsStatements) SelectNewNumericLocalpart( } return id + 1, err } + +// selectPrivacyPolicy gets the current privacy policy a specific user accepted + +func (s *accountsStatements) SelectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, localPart string, +) (policy string, err error) { + var policyNull sql.NullString + stmt := sqlutil.TxStmt(txn, s.selectPrivacyPolicyStmt) + err = stmt.QueryRowContext(ctx, localPart).Scan(&policyNull) + return policyNull.String, err +} + +// batchSelectPrivacyPolicy queries all users which didn't accept the current policy version +func (s *accountsStatements) BatchSelectPrivacyPolicy( + ctx context.Context, txn *sql.Tx, policyVersion string, +) (userIDs []string, err error) { + stmt := sqlutil.TxStmt(txn, s.batchSelectPrivacyPolicyStmt) + rows, err := stmt.QueryContext(ctx, policyVersion, policyVersion) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "BatchSelectPrivacyPolicy: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return userIDs, err + } + userIDs = append(userIDs, userID) + } + return userIDs, rows.Err() +} + +// updatePolicyVersion sets the policy_version for a specific user +func (s *accountsStatements) UpdatePolicyVersion( + ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool, +) (err error) { + stmt := s.updatePolicyVersionStmt + if serverNotice { + stmt = s.updatePolicyVersionServerNoticeStmt + } + stmt = sqlutil.TxStmt(txn, stmt) + _, err = stmt.ExecContext(ctx, policyVersion, localpart) + return err +} + +// SelectServerNoticeRoomID queries the server notice room ID. +func (s *accountsStatements) SelectServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart string, +) (roomID string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectServerNoticeRoomStmt) + + roomIDNull := sql.NullString{} + row := stmt.QueryRowContext(ctx, localpart) + err = row.Scan(&roomIDNull) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // roomIDNull.String is either the roomID or an empty string + return roomIDNull.String, nil +} + +// UpdateServerNoticeRoomID sets the server notice room ID. +func (s *accountsStatements) UpdateServerNoticeRoomID( + ctx context.Context, txn *sql.Tx, localpart, roomID string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.updateServerNoticeRoomStmt) + _, err = stmt.ExecContext(ctx, roomID, localpart) + return +} diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index c69614e834..24ef265e78 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -12,6 +12,7 @@ import ( func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) goose.AddMigration(UpAddAccountType, DownAddAccountType) + goose.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 9b058dedd9..bcced55fd8 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -4,15 +4,9 @@ import ( "database/sql" "fmt" - "github.com/pressly/goose" - "github.com/matrix-org/dendrite/internal/sqlutil" ) -func init() { - goose.AddMigration(UpAddAccountType, DownAddAccountType) -} - func LoadAddAccountType(m *sqlutil.Migrations) { m.AddMigration(UpAddAccountType, DownAddAccountType) } diff --git a/userapi/storage/sqlite3/deltas/2022043014375800_add_policy_version.go b/userapi/storage/sqlite3/deltas/2022043014375800_add_policy_version.go new file mode 100644 index 0000000000..251ec4e40b --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022043014375800_add_policy_version.go @@ -0,0 +1,44 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadAddPolicyVersion(m *sqlutil.Migrations) { + m.AddMigration(UpAddPolicyVersion, DownAddPolicyVersion) +} + +func UpAddPolicyVersion(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN policy_version_sent TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts ADD COLUMN server_notice_room_id TEXT;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddPolicyVersion(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN policy_version_sent;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.Exec("ALTER TABLE account_accounts DROP COLUMN server_notice_room_id;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index a822f687d5..ca0e75fb99 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -18,11 +18,10 @@ import ( "fmt" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" @@ -47,6 +46,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) + deltas.LoadAddPolicyVersion(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 5bee880d39..32583fb0f1 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -78,7 +78,7 @@ func Test_Accounts(t *testing.T) { aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) - accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account @@ -102,7 +102,7 @@ func Test_Accounts(t *testing.T) { first, err := db.GetNewNumericLocalpart(ctx) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated - _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", "testing", "", "v1.0", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") second, err := db.GetNewNumericLocalpart(ctx) assert.NoError(t, err) @@ -350,7 +350,7 @@ func Test_Profile(t *testing.T) { defer close() // create account, which also creates a profile - _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", "v1.0", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 2fe9556707..39e4ea8cda 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -32,12 +32,18 @@ type AccountDataTable interface { } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string, accountType api.AccountType) (*api.Account, error) UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) DeactivateAccount(ctx context.Context, localpart string) (err error) SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) + + SelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, localPart string) (policy string, err error) + BatchSelectPrivacyPolicy(ctx context.Context, txn *sql.Tx, policyVersion string) (userIDs []string, err error) + UpdatePolicyVersion(ctx context.Context, txn *sql.Tx, policyVersion, localpart string, serverNotice bool) (err error) + SelectServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart string) (roomID string, err error) + UpdateServerNoticeRoomID(ctx context.Context, txn *sql.Tx, localpart, roomID string) (err error) } type DevicesTable interface { diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index 11521c8b00..1b07c14126 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -88,7 +88,7 @@ func mustMakeAccountAndDevice( appServiceID = util.RandomString(16) } - _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) + _, err := accDB.InsertAccount(ctx, nil, localpart, "", "", appServiceID, accType) if err != nil { t.Fatalf("unable to create account: %v", err) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 40e37c5d6e..364744591a 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -75,7 +75,7 @@ func TestQueryProfile(t *testing.T) { // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -154,7 +154,7 @@ func TestLoginToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) }