mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-11-04 02:01:03 +01:00 
			
		
		
		
	Merge pull request #613 from matrix-org/markjh/yield
Load the current id in the IdGenerator constructor
This commit is contained in:
		
						commit
						d50ca1b1ed
					
				@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore,
 | 
			
		||||
            db_conn, "presence_stream", "stream_id"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
 | 
			
		||||
        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
 | 
			
		||||
        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
 | 
			
		||||
        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
 | 
			
		||||
        self._pushers_id_gen = IdGenerator("pushers", "id", self)
 | 
			
		||||
        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
 | 
			
		||||
        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
 | 
			
		||||
        self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
 | 
			
		||||
        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
 | 
			
		||||
        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
 | 
			
		||||
        self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
 | 
			
		||||
        self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
 | 
			
		||||
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
 | 
			
		||||
        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
 | 
			
		||||
 | 
			
		||||
        events_max = self._stream_id_gen.get_max_token()
 | 
			
		||||
        event_cache_prefill, min_event_val = self._get_cache_dict(
 | 
			
		||||
 | 
			
		||||
@ -163,12 +163,12 @@ class AccountDataStore(SQLBaseStore):
 | 
			
		||||
            )
 | 
			
		||||
            self._update_max_stream_id(txn, next_id)
 | 
			
		||||
 | 
			
		||||
        with (yield self._account_data_id_gen.get_next(self)) as next_id:
 | 
			
		||||
        with self._account_data_id_gen.get_next() as next_id:
 | 
			
		||||
            yield self.runInteraction(
 | 
			
		||||
                "add_room_account_data", add_account_data_txn, next_id
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        result = yield self._account_data_id_gen.get_max_token()
 | 
			
		||||
        result = self._account_data_id_gen.get_max_token()
 | 
			
		||||
        defer.returnValue(result)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
@ -202,12 +202,12 @@ class AccountDataStore(SQLBaseStore):
 | 
			
		||||
            )
 | 
			
		||||
            self._update_max_stream_id(txn, next_id)
 | 
			
		||||
 | 
			
		||||
        with (yield self._account_data_id_gen.get_next(self)) as next_id:
 | 
			
		||||
        with self._account_data_id_gen.get_next() as next_id:
 | 
			
		||||
            yield self.runInteraction(
 | 
			
		||||
                "add_user_account_data", add_account_data_txn, next_id
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        result = yield self._account_data_id_gen.get_max_token()
 | 
			
		||||
        result = self._account_data_id_gen.get_max_token()
 | 
			
		||||
        defer.returnValue(result)
 | 
			
		||||
 | 
			
		||||
    def _update_max_stream_id(self, txn, next_id):
 | 
			
		||||
 | 
			
		||||
@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
 | 
			
		||||
                yield stream_orderings
 | 
			
		||||
            stream_ordering_manager = stream_ordering_manager()
 | 
			
		||||
        else:
 | 
			
		||||
            stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
 | 
			
		||||
                self, len(events_and_contexts)
 | 
			
		||||
            stream_ordering_manager = self._stream_id_gen.get_next_mult(
 | 
			
		||||
                len(events_and_contexts)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with stream_ordering_manager as stream_orderings:
 | 
			
		||||
@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore):
 | 
			
		||||
            stream_ordering = self.min_stream_token
 | 
			
		||||
 | 
			
		||||
        if stream_ordering is None:
 | 
			
		||||
            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
 | 
			
		||||
            stream_ordering_manager = self._stream_id_gen.get_next()
 | 
			
		||||
        else:
 | 
			
		||||
            @contextmanager
 | 
			
		||||
            def stream_ordering_manager():
 | 
			
		||||
 | 
			
		||||
@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState",
 | 
			
		||||
class PresenceStore(SQLBaseStore):
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def update_presence(self, presence_states):
 | 
			
		||||
        stream_ordering_manager = yield self._presence_id_gen.get_next_mult(
 | 
			
		||||
            self, len(presence_states)
 | 
			
		||||
        stream_ordering_manager = self._presence_id_gen.get_next_mult(
 | 
			
		||||
            len(presence_states)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        with stream_ordering_manager as stream_orderings:
 | 
			
		||||
 | 
			
		||||
@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore):
 | 
			
		||||
 | 
			
		||||
        if txn.rowcount == 0:
 | 
			
		||||
            # We didn't update a row with the given rule_id so insert one
 | 
			
		||||
            push_rule_id = self._push_rule_id_gen.get_next_txn(txn)
 | 
			
		||||
            push_rule_id = self._push_rule_id_gen.get_next()
 | 
			
		||||
 | 
			
		||||
            self._simple_insert_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore):
 | 
			
		||||
        defer.returnValue(ret)
 | 
			
		||||
 | 
			
		||||
    def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
 | 
			
		||||
        new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
 | 
			
		||||
        new_id = self._push_rules_enable_id_gen.get_next()
 | 
			
		||||
        self._simple_upsert_txn(
 | 
			
		||||
            txn,
 | 
			
		||||
            "push_rules_enable",
 | 
			
		||||
 | 
			
		||||
@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore):
 | 
			
		||||
                   app_display_name, device_display_name,
 | 
			
		||||
                   pushkey, pushkey_ts, lang, data, profile_tag=""):
 | 
			
		||||
        try:
 | 
			
		||||
            next_id = yield self._pushers_id_gen.get_next()
 | 
			
		||||
            next_id = self._pushers_id_gen.get_next()
 | 
			
		||||
            yield self._simple_upsert(
 | 
			
		||||
                "pushers",
 | 
			
		||||
                dict(
 | 
			
		||||
 | 
			
		||||
@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore):
 | 
			
		||||
                "insert_receipt_conv", graph_to_linear
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        stream_id_manager = yield self._receipts_id_gen.get_next(self)
 | 
			
		||||
        stream_id_manager = self._receipts_id_gen.get_next()
 | 
			
		||||
        with stream_id_manager as stream_id:
 | 
			
		||||
            have_persisted = yield self.runInteraction(
 | 
			
		||||
                "insert_linearized_receipt",
 | 
			
		||||
@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
 | 
			
		||||
            room_id, receipt_type, user_id, event_ids, data
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        max_persisted_id = yield self._stream_id_gen.get_max_token()
 | 
			
		||||
        max_persisted_id = self._stream_id_gen.get_max_token()
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((stream_id, max_persisted_id))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
 | 
			
		||||
        Raises:
 | 
			
		||||
            StoreError if there was a problem adding this.
 | 
			
		||||
        """
 | 
			
		||||
        next_id = yield self._access_tokens_id_gen.get_next()
 | 
			
		||||
        next_id = self._access_tokens_id_gen.get_next()
 | 
			
		||||
 | 
			
		||||
        yield self._simple_insert(
 | 
			
		||||
            "access_tokens",
 | 
			
		||||
@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
 | 
			
		||||
        Raises:
 | 
			
		||||
            StoreError if there was a problem adding this.
 | 
			
		||||
        """
 | 
			
		||||
        next_id = yield self._refresh_tokens_id_gen.get_next()
 | 
			
		||||
        next_id = self._refresh_tokens_id_gen.get_next()
 | 
			
		||||
 | 
			
		||||
        yield self._simple_insert(
 | 
			
		||||
            "refresh_tokens",
 | 
			
		||||
@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
 | 
			
		||||
    def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
 | 
			
		||||
        now = int(self.clock.time())
 | 
			
		||||
 | 
			
		||||
        next_id = self._access_tokens_id_gen.get_next_txn(txn)
 | 
			
		||||
        next_id = self._access_tokens_id_gen.get_next()
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if was_guest:
 | 
			
		||||
 | 
			
		||||
@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
 | 
			
		||||
            if event.is_state():
 | 
			
		||||
                state_events[(event.type, event.state_key)] = event
 | 
			
		||||
 | 
			
		||||
            state_group = self._state_groups_id_gen.get_next_txn(txn)
 | 
			
		||||
            state_group = self._state_groups_id_gen.get_next()
 | 
			
		||||
            self._simple_insert_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="state_groups",
 | 
			
		||||
 | 
			
		||||
@ -142,12 +142,12 @@ class TagsStore(SQLBaseStore):
 | 
			
		||||
            )
 | 
			
		||||
            self._update_revision_txn(txn, user_id, room_id, next_id)
 | 
			
		||||
 | 
			
		||||
        with (yield self._account_data_id_gen.get_next(self)) as next_id:
 | 
			
		||||
        with self._account_data_id_gen.get_next() as next_id:
 | 
			
		||||
            yield self.runInteraction("add_tag", add_tag_txn, next_id)
 | 
			
		||||
 | 
			
		||||
        self.get_tags_for_user.invalidate((user_id,))
 | 
			
		||||
 | 
			
		||||
        result = yield self._account_data_id_gen.get_max_token()
 | 
			
		||||
        result = self._account_data_id_gen.get_max_token()
 | 
			
		||||
        defer.returnValue(result)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
@ -164,12 +164,12 @@ class TagsStore(SQLBaseStore):
 | 
			
		||||
            txn.execute(sql, (user_id, room_id, tag))
 | 
			
		||||
            self._update_revision_txn(txn, user_id, room_id, next_id)
 | 
			
		||||
 | 
			
		||||
        with (yield self._account_data_id_gen.get_next(self)) as next_id:
 | 
			
		||||
        with self._account_data_id_gen.get_next() as next_id:
 | 
			
		||||
            yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
 | 
			
		||||
 | 
			
		||||
        self.get_tags_for_user.invalidate((user_id,))
 | 
			
		||||
 | 
			
		||||
        result = yield self._account_data_id_gen.get_max_token()
 | 
			
		||||
        result = self._account_data_id_gen.get_max_token()
 | 
			
		||||
        defer.returnValue(result)
 | 
			
		||||
 | 
			
		||||
    def _update_revision_txn(self, txn, user_id, room_id, next_id):
 | 
			
		||||
 | 
			
		||||
@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
 | 
			
		||||
    def _prep_send_transaction(self, txn, transaction_id, destination,
 | 
			
		||||
                               origin_server_ts):
 | 
			
		||||
 | 
			
		||||
        next_id = self._transaction_id_gen.get_next_txn(txn)
 | 
			
		||||
        next_id = self._transaction_id_gen.get_next()
 | 
			
		||||
 | 
			
		||||
        # First we find out what the prev_txns should be.
 | 
			
		||||
        # Since we know that we are only sending one transaction at a time,
 | 
			
		||||
 | 
			
		||||
@ -13,51 +13,30 @@
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from collections import deque
 | 
			
		||||
import contextlib
 | 
			
		||||
import threading
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IdGenerator(object):
 | 
			
		||||
    def __init__(self, table, column, store):
 | 
			
		||||
    def __init__(self, db_conn, table, column):
 | 
			
		||||
        self.table = table
 | 
			
		||||
        self.column = column
 | 
			
		||||
        self.store = store
 | 
			
		||||
        self._lock = threading.Lock()
 | 
			
		||||
        self._next_id = None
 | 
			
		||||
        cur = db_conn.cursor()
 | 
			
		||||
        self._next_id = self._load_next_id(cur)
 | 
			
		||||
        cur.close()
 | 
			
		||||
 | 
			
		||||
    def _load_next_id(self, txn):
 | 
			
		||||
        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
 | 
			
		||||
        val, = txn.fetchone()
 | 
			
		||||
        return val + 1 if val else 1
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_next(self):
 | 
			
		||||
        if self._next_id is None:
 | 
			
		||||
            yield self.store.runInteraction(
 | 
			
		||||
                "IdGenerator_%s" % (self.table,),
 | 
			
		||||
                self.get_next_txn,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            i = self._next_id
 | 
			
		||||
            self._next_id += 1
 | 
			
		||||
            defer.returnValue(i)
 | 
			
		||||
 | 
			
		||||
    def get_next_txn(self, txn):
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            if self._next_id:
 | 
			
		||||
            i = self._next_id
 | 
			
		||||
            self._next_id += 1
 | 
			
		||||
            return i
 | 
			
		||||
            else:
 | 
			
		||||
                txn.execute(
 | 
			
		||||
                    "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                val, = txn.fetchone()
 | 
			
		||||
                cur = val or 0
 | 
			
		||||
                cur += 1
 | 
			
		||||
                self._next_id = cur + 1
 | 
			
		||||
 | 
			
		||||
                return cur
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StreamIdGenerator(object):
 | 
			
		||||
@ -69,7 +48,7 @@ class StreamIdGenerator(object):
 | 
			
		||||
    persistence of events can complete out of order.
 | 
			
		||||
 | 
			
		||||
    Usage:
 | 
			
		||||
        with stream_id_gen.get_next_txn(txn) as stream_id:
 | 
			
		||||
        with stream_id_gen.get_next() as stream_id:
 | 
			
		||||
            # ... persist event ...
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, db_conn, table, column):
 | 
			
		||||
@ -79,15 +58,21 @@ class StreamIdGenerator(object):
 | 
			
		||||
        self._lock = threading.Lock()
 | 
			
		||||
 | 
			
		||||
        cur = db_conn.cursor()
 | 
			
		||||
        self._current_max = self._get_or_compute_current_max(cur)
 | 
			
		||||
        self._current_max = self._load_current_max(cur)
 | 
			
		||||
        cur.close()
 | 
			
		||||
 | 
			
		||||
        self._unfinished_ids = deque()
 | 
			
		||||
 | 
			
		||||
    def get_next(self, store):
 | 
			
		||||
    def _load_current_max(self, txn):
 | 
			
		||||
        txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
 | 
			
		||||
        rows = txn.fetchall()
 | 
			
		||||
        val, = rows[0]
 | 
			
		||||
        return int(val) if val else 1
 | 
			
		||||
 | 
			
		||||
    def get_next(self):
 | 
			
		||||
        """
 | 
			
		||||
        Usage:
 | 
			
		||||
            with yield stream_id_gen.get_next as stream_id:
 | 
			
		||||
            with stream_id_gen.get_next() as stream_id:
 | 
			
		||||
                # ... persist event ...
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
@ -106,10 +91,10 @@ class StreamIdGenerator(object):
 | 
			
		||||
 | 
			
		||||
        return manager()
 | 
			
		||||
 | 
			
		||||
    def get_next_mult(self, store, n):
 | 
			
		||||
    def get_next_mult(self, n):
 | 
			
		||||
        """
 | 
			
		||||
        Usage:
 | 
			
		||||
            with yield stream_id_gen.get_next(store, n) as stream_ids:
 | 
			
		||||
            with stream_id_gen.get_next(n) as stream_ids:
 | 
			
		||||
                # ... persist events ...
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
@ -139,13 +124,3 @@ class StreamIdGenerator(object):
 | 
			
		||||
                return self._unfinished_ids[0] - 1
 | 
			
		||||
 | 
			
		||||
            return self._current_max
 | 
			
		||||
 | 
			
		||||
    def _get_or_compute_current_max(self, txn):
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
 | 
			
		||||
            rows = txn.fetchall()
 | 
			
		||||
            val, = rows[0]
 | 
			
		||||
 | 
			
		||||
            self._current_max = int(val) if val else 1
 | 
			
		||||
 | 
			
		||||
            return self._current_max
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user