mirror of https://github.com/matrix-org/synapse.git
	Refactor _get_events
This commit is contained in:
parent 36ea26c5c0
commit cdb3757942
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import StoreError
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util import unwrap_deferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -28,7 +29,6 @@ from twisted.internet import defer

from collections import namedtuple, OrderedDict
import functools
import itertools
import simplejson as json
import sys
import time
@@ -870,35 +870,43 @@ class SQLBaseStore(object):

    @defer.inlineCallbacks
    def _get_events(self, event_ids, check_redacted=True,
                    get_prev_content=False, desc="_get_events"):
        N = 50  # Only fetch 100 events at a time.
                    get_prev_content=False, allow_rejected=False, txn=None):
        if not event_ids:
            defer.returnValue([])

        ds = [
            self._fetch_events(
                event_ids[i*N:(i+1)*N],
        event_map = self._get_events_from_cache(
            event_ids,
            check_redacted=check_redacted,
            get_prev_content=get_prev_content,
            allow_rejected=allow_rejected,
        )
            for i in range(1 + len(event_ids) / N)
        ]

        res = yield defer.gatherResults(ds, consumeErrors=True)
        missing_events = [e for e in event_ids if e not in event_map]

        defer.returnValue(
            list(itertools.chain(*res))
        missing_events = yield self._fetch_events(
            txn,
            missing_events,
            check_redacted=check_redacted,
            get_prev_content=get_prev_content,
            allow_rejected=allow_rejected,
        )

        event_map.update(missing_events)

        defer.returnValue([
            event_map[e_id] for e_id in event_ids
            if e_id in event_map and event_map[e_id]
        ])

    def _get_events_txn(self, txn, event_ids, check_redacted=True,
                        get_prev_content=False):
        N = 50  # Only fetch 100 events at a time.
        return list(itertools.chain(*[
            self._fetch_events_txn(
                txn, event_ids[i*N:(i+1)*N],
                        get_prev_content=False, allow_rejected=False):
        return unwrap_deferred(self._get_events(
            event_ids,
            check_redacted=check_redacted,
            get_prev_content=get_prev_content,
            )
            for i in range(1 + len(event_ids) / N)
        ]))
            allow_rejected=allow_rejected,
            txn=txn,
        ))

    def _invalidate_get_event_cache(self, event_id):
        for check_redacted in (False, True):
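
The rewritten _get_events above follows a cache-then-fetch shape: consult the event cache first, fetch only the missing IDs, then return results in the caller's original order while dropping IDs that resolved to nothing. A minimal, self-contained sketch of that pattern (the dict cache and stub loader below are illustrative stand-ins for Synapse's _get_events_from_cache and _fetch_events, not the real code):

# Illustrative sketch only -- not Synapse code.
_cache = {}

def _fetch_from_db(event_ids):
    # Stub: pretend every requested ID exists in the database.
    return {e_id: {"event_id": e_id, "type": "m.room.message"} for e_id in event_ids}

def get_events(event_ids):
    if not event_ids:
        return []

    # 1. Serve whatever the cache already holds.
    event_map = {e_id: _cache[e_id] for e_id in event_ids if e_id in _cache}

    # 2. Fetch only the IDs the cache could not answer, then merge.
    missing = [e_id for e_id in event_ids if e_id not in event_map]
    fetched = _fetch_from_db(missing)
    event_map.update(fetched)
    _cache.update(fetched)

    # 3. Preserve the caller's ordering and drop empty entries, mirroring
    #    the final list comprehension in _get_events.
    return [
        event_map[e_id] for e_id in event_ids
        if e_id in event_map and event_map[e_id]
    ]

print(get_events(["$a:example.org", "$b:example.org"]))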
@@ -909,68 +917,24 @@ class SQLBaseStore(object):
    def _get_event_txn(self, txn, event_id, check_redacted=True,
                       get_prev_content=False, allow_rejected=False):

        start_time = time.time() * 1000

        def update_counter(desc, last_time):
            curr_time = self._get_event_counters.update(desc, last_time)
            sql_getevents_timer.inc_by(curr_time - last_time, desc)
            return curr_time

        try:
            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)

            if allow_rejected or not ret.rejected_reason:
                return ret
            else:
                return None
        except KeyError:
            pass
        finally:
            start_time = update_counter("event_cache", start_time)

        sql = (
            "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
            "FROM event_json as e "
            "LEFT JOIN rejections as rej USING (event_id) "
            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
            "WHERE e.event_id = ? "
            "LIMIT 1 "
        )

        txn.execute(sql, (event_id,))

        res = txn.fetchone()

        if not res:
            return None

        internal_metadata, js, redacted, rejected_reason = res

        start_time = update_counter("select_event", start_time)

        result = self._get_event_from_row_txn(
            txn, internal_metadata, js, redacted,
        events = self._get_events_txn(
            txn, [event_id],
            check_redacted=check_redacted,
            get_prev_content=get_prev_content,
            rejected_reason=rejected_reason,
            allow_rejected=allow_rejected,
        )
        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)

        if allow_rejected or not rejected_reason:
            return result
        else:
            return None

    def _fetch_events_txn(self, txn, events, check_redacted=True,
                          get_prev_content=False, allow_rejected=False):
        if not events:
            return []
        return events[0] if events else None

    def _get_events_from_cache(self, events, check_redacted, get_prev_content,
                               allow_rejected):
        event_map = {}

        for event_id in events:
            try:
                ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
                ret = self._get_event_cache.get(
                    event_id, check_redacted, get_prev_content
                )

                if allow_rejected or not ret.rejected_reason:
                    event_map[event_id] = ret
@@ -979,136 +943,82 @@ class SQLBaseStore(object):
            except KeyError:
                pass

        missing_events = [
            e for e in events
            if e not in event_map
        ]

        if missing_events:
            sql = (
                "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
                " FROM event_json as e"
                " LEFT JOIN rejections as rej USING (event_id)"
                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                " WHERE e.event_id IN (%s)"
            ) % (",".join(["?"]*len(missing_events)),)

            txn.execute(sql, missing_events)
            rows = txn.fetchall()

            res = [
                self._get_event_from_row_txn(
                    txn, row[0], row[1], row[2],
                    check_redacted=check_redacted,
                    get_prev_content=get_prev_content,
                    rejected_reason=row[3],
                )
                for row in rows
            ]

            event_map.update({
                e.event_id: e
                for e in res if e
            })

            for e in res:
                self._get_event_cache.prefill(
                    e.event_id, check_redacted, get_prev_content, e
                )

        return [
            event_map[e_id] for e_id in events
            if e_id in event_map and event_map[e_id]
        ]
        return event_map

    @defer.inlineCallbacks
    def _fetch_events(self, events, check_redacted=True,
    def _fetch_events(self, txn, events, check_redacted=True,
                      get_prev_content=False, allow_rejected=False):
        if not events:
            defer.returnValue([])
            defer.returnValue({})

        event_map = {}
        rows = []
        N = 2
        for i in range(1 + len(events) / N):
            evs = events[i*N:(i + 1)*N]
            if not evs:
                break

        for event_id in events:
            try:
                ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)

                if allow_rejected or not ret.rejected_reason:
                    event_map[event_id] = ret
                else:
                    event_map[event_id] = None
            except KeyError:
                pass

        missing_events = [
            e for e in events
            if e not in event_map
        ]

        if missing_events:
            sql = (
                "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
                " FROM event_json as e"
                " LEFT JOIN rejections as rej USING (event_id)"
                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                " WHERE e.event_id IN (%s)"
            ) % (",".join(["?"]*len(missing_events)),)
            ) % (",".join(["?"]*len(evs)),)

            rows = yield self._execute(
                "_fetch_events",
                None,
                sql,
                *missing_events
            )
            if txn:
                txn.execute(sql, evs)
                rows.extend(txn.fetchall())
            else:
                res = yield self._execute("_fetch_events", None, sql, *evs)
                rows.extend(res)

            res_ds = [
                self._get_event_from_row(
        res = []
        for row in rows:
            e = yield self._get_event_from_row(
                txn,
                row[0], row[1], row[2],
                check_redacted=check_redacted,
                get_prev_content=get_prev_content,
                rejected_reason=row[3],
            )
                for row in rows
            ]

            res = yield defer.gatherResults(res_ds, consumeErrors=True)

            event_map.update({
                e.event_id: e
                for e in res if e
            })
            res.append(e)

        for e in res:
            self._get_event_cache.prefill(
                e.event_id, check_redacted, get_prev_content, e
            )

        defer.returnValue([
            event_map[e_id] for e_id in events
            if e_id in event_map and event_map[e_id]
        ])
        defer.returnValue({
            e.event_id: e
            for e in res if e
        })

    @defer.inlineCallbacks
    def _get_event_from_row(self, internal_metadata, js, redacted,
    def _get_event_from_row(self, txn, internal_metadata, js, redacted,
                            check_redacted=True, get_prev_content=False,
                            rejected_reason=None):

        start_time = time.time() * 1000

        def update_counter(desc, last_time):
            curr_time = self._get_event_counters.update(desc, last_time)
            sql_getevents_timer.inc_by(curr_time - last_time, desc)
            return curr_time

        d = json.loads(js)
        start_time = update_counter("decode_json", start_time)

        internal_metadata = json.loads(internal_metadata)
        start_time = update_counter("decode_internal", start_time)

        def select(txn, *args, **kwargs):
            if txn:
                return self._simple_select_one_onecol_txn(txn, *args, **kwargs)
            else:
                return self._simple_select_one_onecol(
                    *args,
                    desc="_get_event_from_row", **kwargs
                )

        def get_event(txn, *args, **kwargs):
            if txn:
                return self._get_event_txn(txn, *args, **kwargs)
            else:
                return self.get_event(*args, **kwargs)

        if rejected_reason:
            rejected_reason = yield self._simple_select_one_onecol(
                desc="_get_event_from_row",
            rejected_reason = yield select(
                txn,
                table="rejections",
                keyvalues={"event_id": rejected_reason},
                retcol="reason",
@@ -1119,13 +1029,12 @@ class SQLBaseStore(object):
            internal_metadata_dict=internal_metadata,
            rejected_reason=rejected_reason,
        )
        start_time = update_counter("build_frozen_event", start_time)

        if check_redacted and redacted:
            ev = prune_event(ev)

            redaction_id = yield self._simple_select_one_onecol(
                desc="_get_event_from_row",
            redaction_id = yield select(
                txn,
                table="redactions",
                keyvalues={"redacts": ev.event_id},
                retcol="event_id",
@@ -1134,93 +1043,26 @@ class SQLBaseStore(object):
            ev.unsigned["redacted_by"] = redaction_id
            # Get the redaction event.

            because = yield self.get_event_txn(
            because = yield get_event(
                txn,
                redaction_id,
                check_redacted=False
            )

            if because:
                ev.unsigned["redacted_because"] = because
            start_time = update_counter("redact_event", start_time)

        if get_prev_content and "replaces_state" in ev.unsigned:
            prev = yield self.get_event(
            prev = yield get_event(
                txn,
                ev.unsigned["replaces_state"],
                get_prev_content=False,
            )
            if prev:
                ev.unsigned["prev_content"] = prev.get_dict()["content"]
            start_time = update_counter("get_prev_content", start_time)

        defer.returnValue(ev)

    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
                                check_redacted=True, get_prev_content=False,
                                rejected_reason=None):

        start_time = time.time() * 1000

        def update_counter(desc, last_time):
            curr_time = self._get_event_counters.update(desc, last_time)
            sql_getevents_timer.inc_by(curr_time - last_time, desc)
            return curr_time

        d = json.loads(js)
        start_time = update_counter("decode_json", start_time)

        internal_metadata = json.loads(internal_metadata)
        start_time = update_counter("decode_internal", start_time)

        if rejected_reason:
            rejected_reason = self._simple_select_one_onecol_txn(
                txn,
                table="rejections",
                keyvalues={"event_id": rejected_reason},
                retcol="reason",
            )

        ev = FrozenEvent(
            d,
            internal_metadata_dict=internal_metadata,
            rejected_reason=rejected_reason,
        )
        start_time = update_counter("build_frozen_event", start_time)

        if check_redacted and redacted:
            ev = prune_event(ev)

            redaction_id = self._simple_select_one_onecol_txn(
                txn,
                table="redactions",
                keyvalues={"redacts": ev.event_id},
                retcol="event_id",
            )

            ev.unsigned["redacted_by"] = redaction_id
            # Get the redaction event.

            because = self._get_event_txn(
                txn,
                redaction_id,
                check_redacted=False
            )

            if because:
                ev.unsigned["redacted_because"] = because
            start_time = update_counter("redact_event", start_time)

        if get_prev_content and "replaces_state" in ev.unsigned:
            prev = self._get_event_txn(
                txn,
                ev.unsigned["replaces_state"],
                get_prev_content=False,
            )
            if prev:
                ev.unsigned["prev_content"] = prev.get_dict()["content"]
            start_time = update_counter("get_prev_content", start_time)

        return ev

    def _parse_events(self, rows):
        return self.runInteraction(
            "_parse_events", self._parse_events_txn, rows
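
_fetch_events above batches its database lookups: it slices the requested IDs into chunks of N, builds a parameterised "WHERE e.event_id IN (?, ...)" clause per chunk, and runs it either on the supplied txn or via self._execute. A rough standalone illustration of the same chunked IN-clause technique using sqlite3 (the simplified event_json table and column set here are made up for the example):

import sqlite3

def fetch_rows_in_chunks(conn, ids, chunk_size=2):
    # Mirrors the N-sized slices and "?"-placeholder IN clause built in
    # _fetch_events, against a toy two-column table.
    rows = []
    for i in range(0, len(ids), chunk_size):
        chunk = ids[i:i + chunk_size]
        sql = (
            "SELECT event_id, json FROM event_json"
            " WHERE event_id IN (%s)"
        ) % (",".join(["?"] * len(chunk)),)
        rows.extend(conn.execute(sql, chunk).fetchall())
    return rows

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE event_json (event_id TEXT, json TEXT)")
conn.executemany(
    "INSERT INTO event_json VALUES (?, ?)",
    [("$%d:example.org" % i, "{}") for i in range(5)],
)
print(fetch_rows_in_chunks(conn, ["$0:example.org", "$3:example.org", "$4:example.org"]))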
@@ -85,7 +85,7 @@ class StateStore(SQLBaseStore):

        @defer.inlineCallbacks
        def c(vals):
            vals[:] = yield self._fetch_events(vals, get_prev_content=False)
            vals[:] = yield self._get_events(vals, get_prev_content=False)

        yield defer.gatherResults(
            [
@@ -29,6 +29,34 @@ def unwrapFirstError(failure):
    return failure.value.subFailure


def unwrap_deferred(d):
    """Given a deferred that we know has completed, return its value or raise
    the failure as an exception
    """
    if not d.called:
        raise RuntimeError("deferred has not finished")

    res = []

    def f(r):
        res.append(r)
        return r
    d.addCallback(f)

    if res:
        return res[0]

    def f(r):
        res.append(r)
        return r
    d.addErrback(f)

    if res:
        res[0].raiseException()
    else:
        raise RuntimeError("deferred did not call callbacks")


class Clock(object):
    """A small utility that obtains current time-of-day so that time may be
    mocked during unit-tests.
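
The new unwrap_deferred helper is what lets the synchronous _get_events_txn call into the deferred-returning _get_events: when a txn is passed in, the inlineCallbacks generator never has to wait for real I/O, so the deferred it returns has already fired and its value can be extracted directly. A usage sketch under that assumption (the helper body is copied from the diff above; the toy generator and values are invented for illustration):

from twisted.internet import defer

def unwrap_deferred(d):
    # Copied from the synapse.util addition above: pull the result out of
    # a deferred that is known to have already fired.
    if not d.called:
        raise RuntimeError("deferred has not finished")

    res = []

    def f(r):
        res.append(r)
        return r
    d.addCallback(f)

    if res:
        return res[0]

    def f(r):
        res.append(r)
        return r
    d.addErrback(f)

    if res:
        res[0].raiseException()
    else:
        raise RuntimeError("deferred did not call callbacks")

@defer.inlineCallbacks
def already_complete():
    # No real I/O happens here, so the generator finishes immediately and
    # the returned deferred has already fired -- the situation _get_events
    # is in when it is given a txn.
    value = yield defer.succeed(42)
    defer.returnValue(value)

print(unwrap_deferred(already_complete()))  # prints 42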