diff --git a/changelog.d/3550.misc b/changelog.d/3550.misc new file mode 100644 index 000000000000..2374dc0c44d5 --- /dev/null +++ b/changelog.d/3550.misc @@ -0,0 +1 @@ +Lazily load state on master process when using workers to reduce DB consumption diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f83a1581a6f1..c5ae9845c52c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -13,33 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc + from six import iteritems from frozendict import frozendict from twisted.internet import defer +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + -class EventContext(object): +class StatelessContext(object): """ Attributes: - current_state_ids (dict[(str, str), str]): - The current state map including the current event. - (type, state_key) -> event_id - - prev_state_ids (dict[(str, str), str]): - The current state map excluding the current event. - (type, state_key) -> event_id - state_group (int|None): state group id, if the state has been stored as a state group. This is usually only None if e.g. the event is an outlier. rejected (bool|str): A rejection reason if the event was rejected, else False - push_actions (list[(str, list[object])]): list of (user_id, actions) - tuples - prev_group (int): Previously persisted state group. ``None`` for an outlier. delta_ids (dict[(str, str), str]): Delta from ``prev_group``. @@ -49,9 +42,9 @@ class EventContext(object): the empty list? """ + __metaclass__ = abc.ABCMeta + __slots__ = [ - "current_state_ids", - "prev_state_ids", "state_group", "rejected", "prev_group", @@ -61,10 +54,6 @@ class EventContext(object): ] def __init__(self): - # The current state including the current event - self.current_state_ids = None - # The current state excluding the current event - self.prev_state_ids = None self.state_group = None self.rejected = False @@ -78,9 +67,56 @@ def __init__(self): self.app_service = None + @abc.abstractmethod + def get_current_state_ids(self, store): + """Gets the current state IDs + + Returns: + Deferred[dict[(str, str), str]|None]: Returns None if state_group + is None, which happens when the associated event is an outlier. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_prev_state_ids(self, store): + """Gets the prev state IDs + + Returns: + Deferred[dict[(str, str), str]|None]: Returns None if state_group + is None, which happens when the associated event is an outlier. + """ + raise NotImplementedError() + + +class EventContext(StatelessContext): + """This is the same as StatelessContext, except that + current_state_ids and prev_state_ids are already calculated. + + Attributes: + current_state_ids (dict[(str, str), str]|None): + The current state map including the current event. + (type, state_key) -> event_id + Is None if event is an outlier + + prev_state_ids (dict[(str, str), str]|None): + The current state map excluding the current event. + (type, state_key) -> event_id` + Is None if event is an outlier + """ + __slots__ = [ + "current_state_ids", + "prev_state_ids", + ] + + def __init__(self): + super(EventContext, self).__init__() + + self.current_state_ids = None + self.prev_state_ids = None + def serialize(self, event): """Converts self to a type that can be serialized as JSON, and then - deserialized by `deserialize` + deserialized by `DeserializedContext.deserialize` Args: event (FrozenEvent): The event that this context relates to @@ -110,46 +146,124 @@ def serialize(self, event): "app_service_id": self.app_service.id if self.app_service else None } + def get_current_state_ids(self, store): + """Implements StatelessContext""" + return defer.succeed(self.current_state_ids) + + def get_prev_state_ids(self, store): + """Implements StatelessContext""" + return defer.succeed(self.prev_state_ids) + + +class DeserializedContext(StatelessContext): + """A context that comes from a serialized version of a StatelessContext. + + It does not necessarily have current_state_ids and prev_state_ids precomputed + (unlike EventContext), but does cache the results of + `get_current_state_ids` and `get_prev_state_ids`. + + Attributes: + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + _prev_state_id (str|None): If set then the event associated with the + context overrode the _prev_state_id + _event_type (str): The type of the event the context is associated with + _event_state_key (str|None): The state_key of the event the context is + associated with + _current_state_ids (dict[(str, str), str]|None): + The current state map including the current event. + (type, state_key) -> event_id + _prev_state_ids (dict[(str, str), str]|None): + The current state map excluding the current event. + (type, state_key) -> event_id` + """ + + __slots__ = [ + "_current_state_ids", + "_prev_state_ids", + "_fetching_state_deferred", + "_prev_state_id", + "_event_type", + "_event_state_key", + ] + @staticmethod - @defer.inlineCallbacks def deserialize(store, input): """Converts a dict that was produced by `serialize` back into a - EventContext. + StatelessContext. Args: store (DataStore): Used to convert AS ID to AS object input (dict): A dict produced by `serialize` Returns: - EventContext + DeserializedContext """ - context = EventContext() + context = DeserializedContext() context.state_group = input["state_group"] context.rejected = input["rejected"] context.prev_group = input["prev_group"] context.delta_ids = _decode_state_dict(input["delta_ids"]) context.prev_state_events = input["prev_state_events"] - # We use the state_group and prev_state_id stuff to pull the - # current_state_ids out of the DB and construct prev_state_ids. - prev_state_id = input["prev_state_id"] - event_type = input["event_type"] - event_state_key = input["event_state_key"] + context._prev_state_id = input["prev_state_id"] + context._event_type = input["event_type"] + context._event_state_key = input["event_state_key"] - context.current_state_ids = yield store.get_state_ids_for_group( - context.state_group, - ) - if prev_state_id and event_state_key: - context.prev_state_ids = dict(context.current_state_ids) - context.prev_state_ids[(event_type, event_state_key)] = prev_state_id - else: - context.prev_state_ids = context.current_state_ids + context._fetching_state_deferred = None + context._current_state_ids = None + context._prev_state_ids = None app_service_id = input["app_service_id"] if app_service_id: context.app_service = store.get_app_service_by_id(app_service_id) - defer.returnValue(context) + return context + + @defer.inlineCallbacks + def get_current_state_ids(self, store): + """Implements StatelessContext""" + + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store, + ) + + yield make_deferred_yieldable(self._fetching_state_deferred) + + defer.returnValue(self._current_state_ids) + + @defer.inlineCallbacks + def get_prev_state_ids(self, store): + """Implements StatelessContext""" + + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store, + ) + + yield make_deferred_yieldable(self._fetching_state_deferred) + + defer.returnValue(self._prev_state_ids) + + @defer.inlineCallbacks + def _fill_out_state(self, store): + """Called to populate the _current_state_ids and _prev_state_ids + attributes by loading from the database. + """ + if self.state_group is None: + return + + self._current_state_ids = yield store.get_state_ids_for_group( + self.state_group, + ) + if self._prev_state_id and self._event_state_key is not None: + self._prev_state_ids = dict(self._current_state_ids) + + key = (self._event_type, self._event_state_key) + self._prev_state_ids[key] = self._prev_state_id + else: + self._prev_state_ids = self._current_state_ids def _encode_state_dict(state_dict): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index b6a8b3aa3b31..50ad1d329ae8 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -106,14 +106,20 @@ def ratelimit(self, requester, update=True): @defer.inlineCallbacks def maybe_kick_guest_users(self, event, context=None): + """ + Args: + event (FrozenEvent) + context (StatelessContext) + """ # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: + current_state_ids = yield context.get_current_state_ids(self.store) current_state = yield self.store.get_events( - list(context.current_state_ids.values()) + list(current_state_ids.values()) ) else: current_state = yield self.state_handler.get_current_state( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index abc07ea87c88..1473f7ec3c56 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -850,6 +850,13 @@ def persist_and_notify_client_event( calculated the push actions for the event, and checked auth. This should only be run on master. + + Args: + requester (Requester) + event (FrozenEvent) + context (StatelessContext) + ratelimit(bool) + extra_users (list[UserID]) """ assert not self.config.worker_app @@ -884,9 +891,11 @@ def is_inviter_member_event(e): e.sender == event.sender ) + current_state_ids = yield context.get_current_state_ids(self.store) + state_to_include_ids = [ e_id - for k, e_id in iteritems(context.current_state_ids) + for k, e_id in current_state_ids.iteritems() if k[0] in self.hs.config.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -922,8 +931,9 @@ def is_inviter_member_event(e): ) if event.type == EventTypes.Redaction: + prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, context.prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True, ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = { @@ -943,11 +953,13 @@ def is_inviter_member_event(e): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + if event.type == EventTypes.Create: + prev_state_ids = yield context.get_prev_state_ids(self.store) + if prev_state_ids: + raise AuthError( + 403, + "Changing the room create event is forbidden", + ) (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 2eede547921e..a054030b443e 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -24,7 +24,7 @@ SynapseError, ) from synapse.events import FrozenEvent -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import DeserializedContext from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import Requester, UserID from synapse.util.caches.response_cache import ResponseCache @@ -136,7 +136,9 @@ def _handle_request(self, request): event = FrozenEvent(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = yield EventContext.deserialize(self.store, content["context"]) + context = yield DeserializedContext.deserialize( + self.store, content["context"], + ) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4ff0fdc4abab..210a6a164254 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -32,7 +32,7 @@ from synapse.api.errors import SynapseError # these are only included to make the type annotations work from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events.snapshot import StatelessContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id @@ -90,7 +90,7 @@ def add_to_queue(self, room_id, events_and_contexts, backfilled): Args: room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts (list[(EventBase, StatelessContext)]): backfilled (bool): Returns: @@ -264,7 +264,7 @@ def persist_event(self, event, context, backfilled=False): Args: event (EventBase): - context (EventContext): + context (StatelessContext): backfilled (bool): Returns: @@ -301,7 +301,7 @@ def _persist_events(self, events_and_contexts, backfilled=False, """Persist events to db Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts (list[(EventBase, StatelessContext)]): backfilled (bool): delete_existing (bool): @@ -518,7 +518,7 @@ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ room_id (str): room to which the events are being added. Used for logging etc - events_context (list[(EventBase, EventContext)]): + events_context (list[(EventBase, StatelessContext)]): events and contexts which are being added to the room old_latest_event_ids (iterable[str]): @@ -549,7 +549,7 @@ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ if ctx.state_group in state_groups_map: continue - state_groups_map[ctx.state_group] = ctx.current_state_ids + state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self) # We need to map the event_ids to their state groups. First, let's # check if the event is one we're persisting, in which case we can @@ -670,7 +670,7 @@ def _persist_events_txn(self, txn, events_and_contexts, backfilled, Args: txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts (list[(EventBase, StatelessContext)]): events to persist backfilled (bool): True if the events were backfilled delete_existing (bool): True to purge existing table rows for the @@ -882,9 +882,9 @@ def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): Pick the earliest non-outlier if there is one, else the earliest one. Args: - events_and_contexts (list[(EventBase, EventContext)]): + events_and_contexts (list[(EventBase, StatelessContext)]): Returns: - list[(EventBase, EventContext)]: filtered list + list[(EventBase, StatelessContext)]: filtered list """ new_events_and_contexts = OrderedDict() for event, context in events_and_contexts: @@ -905,7 +905,7 @@ def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): Args: txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events + events_and_contexts (list[(EventBase, StatelessContext)]): events we are persisting backfilled (bool): True if the events were backfilled """ @@ -935,11 +935,11 @@ def _update_outliers_txn(self, txn, events_and_contexts): Args: txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events + events_and_contexts (list[(EventBase, StatelessContext)]): events we are persisting Returns: - list[(EventBase, EventContext)] new list, without events which + list[(EventBase, StatelessContext)] new list, without events which are already in the events table. """ txn.execute( @@ -1072,7 +1072,7 @@ def _store_event_txn(self, txn, events_and_contexts): Args: txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events + events_and_contexts (list[(EventBase, StatelessContext)]): events we are persisting """ @@ -1134,11 +1134,11 @@ def _store_rejected_events_txn(self, txn, events_and_contexts): Args: txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events + events_and_contexts (list[(EventBase, StatelessContext)]): events we are persisting Returns: - list[(EventBase, EventContext)] new list, without the rejected + list[(EventBase, StatelessContext)] new list, without the rejected events. """ # Remove the rejected events from the list now that we've added them @@ -1162,9 +1162,9 @@ def _update_metadata_tables_txn(self, txn, events_and_contexts, Args: txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events + events_and_contexts (list[(EventBase, StatelessContext)]): events we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all + all_events_and_contexts (list[(EventBase, StatelessContext)]): all events that we were going to persist. This includes events we've already persisted, etc, that wouldn't appear in events_and_context.