mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-31 00:01:33 +01:00 
			
		
		
		
	Merge pull request #613 from matrix-org/markjh/yield
Load the current id in the IdGenerator constructor
This commit is contained in:
		
						commit
						d50ca1b1ed
					
				| @ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore, | |||||||
|             db_conn, "presence_stream", "stream_id" |             db_conn, "presence_stream", "stream_id" | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) |         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") | ||||||
|         self._state_groups_id_gen = IdGenerator("state_groups", "id", self) |         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") | ||||||
|         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) |         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") | ||||||
|         self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) |         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") | ||||||
|         self._pushers_id_gen = IdGenerator("pushers", "id", self) |         self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") | ||||||
|         self._push_rule_id_gen = IdGenerator("push_rules", "id", self) |         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") | ||||||
|         self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) |         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") | ||||||
| 
 | 
 | ||||||
|         events_max = self._stream_id_gen.get_max_token() |         events_max = self._stream_id_gen.get_max_token() | ||||||
|         event_cache_prefill, min_event_val = self._get_cache_dict( |         event_cache_prefill, min_event_val = self._get_cache_dict( | ||||||
|  | |||||||
| @ -163,12 +163,12 @@ class AccountDataStore(SQLBaseStore): | |||||||
|             ) |             ) | ||||||
|             self._update_max_stream_id(txn, next_id) |             self._update_max_stream_id(txn, next_id) | ||||||
| 
 | 
 | ||||||
|         with (yield self._account_data_id_gen.get_next(self)) as next_id: |         with self._account_data_id_gen.get_next() as next_id: | ||||||
|             yield self.runInteraction( |             yield self.runInteraction( | ||||||
|                 "add_room_account_data", add_account_data_txn, next_id |                 "add_room_account_data", add_account_data_txn, next_id | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         result = yield self._account_data_id_gen.get_max_token() |         result = self._account_data_id_gen.get_max_token() | ||||||
|         defer.returnValue(result) |         defer.returnValue(result) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
| @ -202,12 +202,12 @@ class AccountDataStore(SQLBaseStore): | |||||||
|             ) |             ) | ||||||
|             self._update_max_stream_id(txn, next_id) |             self._update_max_stream_id(txn, next_id) | ||||||
| 
 | 
 | ||||||
|         with (yield self._account_data_id_gen.get_next(self)) as next_id: |         with self._account_data_id_gen.get_next() as next_id: | ||||||
|             yield self.runInteraction( |             yield self.runInteraction( | ||||||
|                 "add_user_account_data", add_account_data_txn, next_id |                 "add_user_account_data", add_account_data_txn, next_id | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         result = yield self._account_data_id_gen.get_max_token() |         result = self._account_data_id_gen.get_max_token() | ||||||
|         defer.returnValue(result) |         defer.returnValue(result) | ||||||
| 
 | 
 | ||||||
|     def _update_max_stream_id(self, txn, next_id): |     def _update_max_stream_id(self, txn, next_id): | ||||||
|  | |||||||
| @ -75,8 +75,8 @@ class EventsStore(SQLBaseStore): | |||||||
|                 yield stream_orderings |                 yield stream_orderings | ||||||
|             stream_ordering_manager = stream_ordering_manager() |             stream_ordering_manager = stream_ordering_manager() | ||||||
|         else: |         else: | ||||||
|             stream_ordering_manager = yield self._stream_id_gen.get_next_mult( |             stream_ordering_manager = self._stream_id_gen.get_next_mult( | ||||||
|                 self, len(events_and_contexts) |                 len(events_and_contexts) | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         with stream_ordering_manager as stream_orderings: |         with stream_ordering_manager as stream_orderings: | ||||||
| @ -109,7 +109,7 @@ class EventsStore(SQLBaseStore): | |||||||
|             stream_ordering = self.min_stream_token |             stream_ordering = self.min_stream_token | ||||||
| 
 | 
 | ||||||
|         if stream_ordering is None: |         if stream_ordering is None: | ||||||
|             stream_ordering_manager = yield self._stream_id_gen.get_next(self) |             stream_ordering_manager = self._stream_id_gen.get_next() | ||||||
|         else: |         else: | ||||||
|             @contextmanager |             @contextmanager | ||||||
|             def stream_ordering_manager(): |             def stream_ordering_manager(): | ||||||
|  | |||||||
| @ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState", | |||||||
| class PresenceStore(SQLBaseStore): | class PresenceStore(SQLBaseStore): | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
|     def update_presence(self, presence_states): |     def update_presence(self, presence_states): | ||||||
|         stream_ordering_manager = yield self._presence_id_gen.get_next_mult( |         stream_ordering_manager = self._presence_id_gen.get_next_mult( | ||||||
|             self, len(presence_states) |             len(presence_states) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         with stream_ordering_manager as stream_orderings: |         with stream_ordering_manager as stream_orderings: | ||||||
|  | |||||||
| @ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore): | |||||||
| 
 | 
 | ||||||
|         if txn.rowcount == 0: |         if txn.rowcount == 0: | ||||||
|             # We didn't update a row with the given rule_id so insert one |             # We didn't update a row with the given rule_id so insert one | ||||||
|             push_rule_id = self._push_rule_id_gen.get_next_txn(txn) |             push_rule_id = self._push_rule_id_gen.get_next() | ||||||
| 
 | 
 | ||||||
|             self._simple_insert_txn( |             self._simple_insert_txn( | ||||||
|                 txn, |                 txn, | ||||||
| @ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore): | |||||||
|         defer.returnValue(ret) |         defer.returnValue(ret) | ||||||
| 
 | 
 | ||||||
|     def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): |     def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): | ||||||
|         new_id = self._push_rules_enable_id_gen.get_next_txn(txn) |         new_id = self._push_rules_enable_id_gen.get_next() | ||||||
|         self._simple_upsert_txn( |         self._simple_upsert_txn( | ||||||
|             txn, |             txn, | ||||||
|             "push_rules_enable", |             "push_rules_enable", | ||||||
|  | |||||||
| @ -84,7 +84,7 @@ class PusherStore(SQLBaseStore): | |||||||
|                    app_display_name, device_display_name, |                    app_display_name, device_display_name, | ||||||
|                    pushkey, pushkey_ts, lang, data, profile_tag=""): |                    pushkey, pushkey_ts, lang, data, profile_tag=""): | ||||||
|         try: |         try: | ||||||
|             next_id = yield self._pushers_id_gen.get_next() |             next_id = self._pushers_id_gen.get_next() | ||||||
|             yield self._simple_upsert( |             yield self._simple_upsert( | ||||||
|                 "pushers", |                 "pushers", | ||||||
|                 dict( |                 dict( | ||||||
|  | |||||||
| @ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore): | |||||||
|                 "insert_receipt_conv", graph_to_linear |                 "insert_receipt_conv", graph_to_linear | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         stream_id_manager = yield self._receipts_id_gen.get_next(self) |         stream_id_manager = self._receipts_id_gen.get_next() | ||||||
|         with stream_id_manager as stream_id: |         with stream_id_manager as stream_id: | ||||||
|             have_persisted = yield self.runInteraction( |             have_persisted = yield self.runInteraction( | ||||||
|                 "insert_linearized_receipt", |                 "insert_linearized_receipt", | ||||||
| @ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore): | |||||||
|             room_id, receipt_type, user_id, event_ids, data |             room_id, receipt_type, user_id, event_ids, data | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         max_persisted_id = yield self._stream_id_gen.get_max_token() |         max_persisted_id = self._stream_id_gen.get_max_token() | ||||||
| 
 | 
 | ||||||
|         defer.returnValue((stream_id, max_persisted_id)) |         defer.returnValue((stream_id, max_persisted_id)) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore): | |||||||
|         Raises: |         Raises: | ||||||
|             StoreError if there was a problem adding this. |             StoreError if there was a problem adding this. | ||||||
|         """ |         """ | ||||||
|         next_id = yield self._access_tokens_id_gen.get_next() |         next_id = self._access_tokens_id_gen.get_next() | ||||||
| 
 | 
 | ||||||
|         yield self._simple_insert( |         yield self._simple_insert( | ||||||
|             "access_tokens", |             "access_tokens", | ||||||
| @ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): | |||||||
|         Raises: |         Raises: | ||||||
|             StoreError if there was a problem adding this. |             StoreError if there was a problem adding this. | ||||||
|         """ |         """ | ||||||
|         next_id = yield self._refresh_tokens_id_gen.get_next() |         next_id = self._refresh_tokens_id_gen.get_next() | ||||||
| 
 | 
 | ||||||
|         yield self._simple_insert( |         yield self._simple_insert( | ||||||
|             "refresh_tokens", |             "refresh_tokens", | ||||||
| @ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore): | |||||||
|     def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): |     def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): | ||||||
|         now = int(self.clock.time()) |         now = int(self.clock.time()) | ||||||
| 
 | 
 | ||||||
|         next_id = self._access_tokens_id_gen.get_next_txn(txn) |         next_id = self._access_tokens_id_gen.get_next() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             if was_guest: |             if was_guest: | ||||||
|  | |||||||
| @ -83,7 +83,7 @@ class StateStore(SQLBaseStore): | |||||||
|             if event.is_state(): |             if event.is_state(): | ||||||
|                 state_events[(event.type, event.state_key)] = event |                 state_events[(event.type, event.state_key)] = event | ||||||
| 
 | 
 | ||||||
|             state_group = self._state_groups_id_gen.get_next_txn(txn) |             state_group = self._state_groups_id_gen.get_next() | ||||||
|             self._simple_insert_txn( |             self._simple_insert_txn( | ||||||
|                 txn, |                 txn, | ||||||
|                 table="state_groups", |                 table="state_groups", | ||||||
|  | |||||||
| @ -142,12 +142,12 @@ class TagsStore(SQLBaseStore): | |||||||
|             ) |             ) | ||||||
|             self._update_revision_txn(txn, user_id, room_id, next_id) |             self._update_revision_txn(txn, user_id, room_id, next_id) | ||||||
| 
 | 
 | ||||||
|         with (yield self._account_data_id_gen.get_next(self)) as next_id: |         with self._account_data_id_gen.get_next() as next_id: | ||||||
|             yield self.runInteraction("add_tag", add_tag_txn, next_id) |             yield self.runInteraction("add_tag", add_tag_txn, next_id) | ||||||
| 
 | 
 | ||||||
|         self.get_tags_for_user.invalidate((user_id,)) |         self.get_tags_for_user.invalidate((user_id,)) | ||||||
| 
 | 
 | ||||||
|         result = yield self._account_data_id_gen.get_max_token() |         result = self._account_data_id_gen.get_max_token() | ||||||
|         defer.returnValue(result) |         defer.returnValue(result) | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |     @defer.inlineCallbacks | ||||||
| @ -164,12 +164,12 @@ class TagsStore(SQLBaseStore): | |||||||
|             txn.execute(sql, (user_id, room_id, tag)) |             txn.execute(sql, (user_id, room_id, tag)) | ||||||
|             self._update_revision_txn(txn, user_id, room_id, next_id) |             self._update_revision_txn(txn, user_id, room_id, next_id) | ||||||
| 
 | 
 | ||||||
|         with (yield self._account_data_id_gen.get_next(self)) as next_id: |         with self._account_data_id_gen.get_next() as next_id: | ||||||
|             yield self.runInteraction("remove_tag", remove_tag_txn, next_id) |             yield self.runInteraction("remove_tag", remove_tag_txn, next_id) | ||||||
| 
 | 
 | ||||||
|         self.get_tags_for_user.invalidate((user_id,)) |         self.get_tags_for_user.invalidate((user_id,)) | ||||||
| 
 | 
 | ||||||
|         result = yield self._account_data_id_gen.get_max_token() |         result = self._account_data_id_gen.get_max_token() | ||||||
|         defer.returnValue(result) |         defer.returnValue(result) | ||||||
| 
 | 
 | ||||||
|     def _update_revision_txn(self, txn, user_id, room_id, next_id): |     def _update_revision_txn(self, txn, user_id, room_id, next_id): | ||||||
|  | |||||||
| @ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore): | |||||||
|     def _prep_send_transaction(self, txn, transaction_id, destination, |     def _prep_send_transaction(self, txn, transaction_id, destination, | ||||||
|                                origin_server_ts): |                                origin_server_ts): | ||||||
| 
 | 
 | ||||||
|         next_id = self._transaction_id_gen.get_next_txn(txn) |         next_id = self._transaction_id_gen.get_next() | ||||||
| 
 | 
 | ||||||
|         # First we find out what the prev_txns should be. |         # First we find out what the prev_txns should be. | ||||||
|         # Since we know that we are only sending one transaction at a time, |         # Since we know that we are only sending one transaction at a time, | ||||||
|  | |||||||
| @ -13,51 +13,30 @@ | |||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| 
 | 
 | ||||||
| from twisted.internet import defer |  | ||||||
| 
 |  | ||||||
| from collections import deque | from collections import deque | ||||||
| import contextlib | import contextlib | ||||||
| import threading | import threading | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class IdGenerator(object): | class IdGenerator(object): | ||||||
|     def __init__(self, table, column, store): |     def __init__(self, db_conn, table, column): | ||||||
|         self.table = table |         self.table = table | ||||||
|         self.column = column |         self.column = column | ||||||
|         self.store = store |  | ||||||
|         self._lock = threading.Lock() |         self._lock = threading.Lock() | ||||||
|         self._next_id = None |         cur = db_conn.cursor() | ||||||
|  |         self._next_id = self._load_next_id(cur) | ||||||
|  |         cur.close() | ||||||
|  | 
 | ||||||
|  |     def _load_next_id(self, txn): | ||||||
|  |         txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) | ||||||
|  |         val, = txn.fetchone() | ||||||
|  |         return val + 1 if val else 1 | ||||||
| 
 | 
 | ||||||
|     @defer.inlineCallbacks |  | ||||||
|     def get_next(self): |     def get_next(self): | ||||||
|         if self._next_id is None: |  | ||||||
|             yield self.store.runInteraction( |  | ||||||
|                 "IdGenerator_%s" % (self.table,), |  | ||||||
|                 self.get_next_txn, |  | ||||||
|             ) |  | ||||||
| 
 |  | ||||||
|         with self._lock: |         with self._lock: | ||||||
|             i = self._next_id |             i = self._next_id | ||||||
|             self._next_id += 1 |             self._next_id += 1 | ||||||
|             defer.returnValue(i) |             return i | ||||||
| 
 |  | ||||||
|     def get_next_txn(self, txn): |  | ||||||
|         with self._lock: |  | ||||||
|             if self._next_id: |  | ||||||
|                 i = self._next_id |  | ||||||
|                 self._next_id += 1 |  | ||||||
|                 return i |  | ||||||
|             else: |  | ||||||
|                 txn.execute( |  | ||||||
|                     "SELECT MAX(%s) FROM %s" % (self.column, self.table,) |  | ||||||
|                 ) |  | ||||||
| 
 |  | ||||||
|                 val, = txn.fetchone() |  | ||||||
|                 cur = val or 0 |  | ||||||
|                 cur += 1 |  | ||||||
|                 self._next_id = cur + 1 |  | ||||||
| 
 |  | ||||||
|                 return cur |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class StreamIdGenerator(object): | class StreamIdGenerator(object): | ||||||
| @ -69,7 +48,7 @@ class StreamIdGenerator(object): | |||||||
|     persistence of events can complete out of order. |     persistence of events can complete out of order. | ||||||
| 
 | 
 | ||||||
|     Usage: |     Usage: | ||||||
|         with stream_id_gen.get_next_txn(txn) as stream_id: |         with stream_id_gen.get_next() as stream_id: | ||||||
|             # ... persist event ... |             # ... persist event ... | ||||||
|     """ |     """ | ||||||
|     def __init__(self, db_conn, table, column): |     def __init__(self, db_conn, table, column): | ||||||
| @ -79,15 +58,21 @@ class StreamIdGenerator(object): | |||||||
|         self._lock = threading.Lock() |         self._lock = threading.Lock() | ||||||
| 
 | 
 | ||||||
|         cur = db_conn.cursor() |         cur = db_conn.cursor() | ||||||
|         self._current_max = self._get_or_compute_current_max(cur) |         self._current_max = self._load_current_max(cur) | ||||||
|         cur.close() |         cur.close() | ||||||
| 
 | 
 | ||||||
|         self._unfinished_ids = deque() |         self._unfinished_ids = deque() | ||||||
| 
 | 
 | ||||||
|     def get_next(self, store): |     def _load_current_max(self, txn): | ||||||
|  |         txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) | ||||||
|  |         rows = txn.fetchall() | ||||||
|  |         val, = rows[0] | ||||||
|  |         return int(val) if val else 1 | ||||||
|  | 
 | ||||||
|  |     def get_next(self): | ||||||
|         """ |         """ | ||||||
|         Usage: |         Usage: | ||||||
|             with yield stream_id_gen.get_next as stream_id: |             with stream_id_gen.get_next() as stream_id: | ||||||
|                 # ... persist event ... |                 # ... persist event ... | ||||||
|         """ |         """ | ||||||
|         with self._lock: |         with self._lock: | ||||||
| @ -106,10 +91,10 @@ class StreamIdGenerator(object): | |||||||
| 
 | 
 | ||||||
|         return manager() |         return manager() | ||||||
| 
 | 
 | ||||||
|     def get_next_mult(self, store, n): |     def get_next_mult(self, n): | ||||||
|         """ |         """ | ||||||
|         Usage: |         Usage: | ||||||
|             with yield stream_id_gen.get_next(store, n) as stream_ids: |             with stream_id_gen.get_next(n) as stream_ids: | ||||||
|                 # ... persist events ... |                 # ... persist events ... | ||||||
|         """ |         """ | ||||||
|         with self._lock: |         with self._lock: | ||||||
| @ -139,13 +124,3 @@ class StreamIdGenerator(object): | |||||||
|                 return self._unfinished_ids[0] - 1 |                 return self._unfinished_ids[0] - 1 | ||||||
| 
 | 
 | ||||||
|             return self._current_max |             return self._current_max | ||||||
| 
 |  | ||||||
|     def _get_or_compute_current_max(self, txn): |  | ||||||
|         with self._lock: |  | ||||||
|             txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) |  | ||||||
|             rows = txn.fetchall() |  | ||||||
|             val, = rows[0] |  | ||||||
| 
 |  | ||||||
|             self._current_max = int(val) if val else 1 |  | ||||||
| 
 |  | ||||||
|             return self._current_max |  | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user