# This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this file, # You can obtain one at http://mozilla.org/MPL/2.0/. import math import os from sqlalchemy import create_engine from sqlalchemy.sql import text as sqltext from util import get_timestamp # The maximum possible generation number. # Used as a tombstone to mark users that have been "retired" from the db. MAX_GENERATION = 9223372036854775807 NODE_FIELDS = ("capacity", "available", "current_load", "downed", "backoff") _GET_USER_RECORDS = sqltext("""\ select uid, nodes.node, generation, keys_changed_at, client_state, created_at, replaced_at from users left outer join nodes on users.nodeid = nodes.id where email = :email and users.service = :service order by created_at desc, uid desc limit 20 """) _CREATE_USER_RECORD = sqltext("""\ insert into users (service, email, nodeid, generation, keys_changed_at, client_state, created_at, replaced_at) values (:service, :email, :nodeid, :generation, :keys_changed_at, :client_state, :timestamp, NULL) """) # The `where` clause on this statement is designed as an extra layer of # protection, to ensure that concurrent updates don't accidentally move # timestamp fields backwards in time. The handling of `keys_changed_at` # is additionally weird because we want to treat the default `NULL` value # as zero. 
_UPDATE_USER_RECORD_IN_PLACE = sqltext("""\
update
    users
set
    generation = COALESCE(:generation, generation),
    keys_changed_at = COALESCE(:keys_changed_at, keys_changed_at)
where
    service = :service
and
    email = :email
and
    generation <= COALESCE(:generation, generation)
and
    COALESCE(keys_changed_at, 0) <=
        COALESCE(:keys_changed_at, keys_changed_at, 0)
and
    replaced_at is null
""")

# Mark every still-active record for a user (older than :timestamp)
# as replaced.
_REPLACE_USER_RECORDS = sqltext("""\
update
    users
set
    replaced_at = :timestamp
where
    service = :service
and
    email = :email
and
    replaced_at is null
and
    created_at < :timestamp
""")

# Mark all records for the user as replaced,
# and set a large generation number to block future logins.
# Note: deliberately not filtered by service.
_RETIRE_USER_RECORDS = sqltext("""\
update
    users
set
    replaced_at = :timestamp,
    generation = :generation
where
    email = :email
and
    replaced_at is null
""")

# Fetch a page of records that were replaced before :timestamp.
_GET_OLD_USER_RECORDS_FOR_SERVICE = sqltext("""\
select
    uid, email, generation, keys_changed_at, client_state,
    nodes.node, nodes.downed, created_at, replaced_at
from
    users left outer join nodes on users.nodeid = nodes.id
where
    users.service = :service
and
    replaced_at is not null
and
    replaced_at < :timestamp
order by
    replaced_at desc, uid desc
limit
    :limit
offset
    :offset
""")

# As above, but restricted to a uid range.  Kept as a plain string so that
# `_build_old_user_query` can substitute the ::RANGE:: placeholder before
# wrapping it in sqltext().
_GET_OLD_USER_RECORDS_FOR_SERVICE_RANGE = """\
select
    uid, email, generation, keys_changed_at, client_state,
    nodes.node, nodes.downed, created_at, replaced_at
from
    users left outer join nodes on users.nodeid = nodes.id
where
    users.service = :service
and
    ::RANGE::
and
    replaced_at is not null
and
    replaced_at < :timestamp
order by
    replaced_at desc, uid desc
limit
    :limit
offset
    :offset
"""

# Fetch every record (active and replaced) for a user, oldest first.
_GET_ALL_USER_RECORDS_FOR_SERVICE = sqltext("""\
select
    uid, nodes.node, created_at, replaced_at
from
    users left outer join nodes on users.nodeid = nodes.id
where
    email = :email
and
    users.service = :service
order by
    created_at asc, uid desc
""")

# Mark a single record as replaced, by uid.
_REPLACE_USER_RECORD = sqltext("""\
update
    users
set
    replaced_at = :timestamp
where
    service = :service
and
    uid = :uid
""")

# Permanently delete a single record, by uid.
_DELETE_USER_RECORD = sqltext("""\
delete
from
    users
where
    service = :service
and
    uid = :uid
""")

# Return a user's slot to the node they were assigned to.
_FREE_SLOT_ON_NODE = sqltext("""\
update
    nodes
set
    available = available + 1,
    current_load = current_load - 1
where
    id = (SELECT nodeid FROM users WHERE service=:service AND uid=:uid)
""")

_COUNT_USER_RECORDS = sqltext("""\
select
    count(email)
from
    users
where
    replaced_at is null
and
    created_at <= :timestamp
""")

# Pick the "least loaded" healthy node that still has released capacity.
_GET_BEST_NODE = sqltext("""\
select
    id, node
from
    nodes
where
    service = :service
and
    available > 0
and
    capacity > current_load
and
    downed = 0
and
    backoff = 0
order by
    log(current_load) / log(capacity)
limit
    1
""")

# Release a fraction of remaining capacity on nodes that have run out of
# available slots but are not yet full.
_RELEASE_NODE_CAPACITY = sqltext("""\
update
    nodes
set
    available = least(capacity * :capacity_release_rate,
                      capacity - current_load)
where
    service = :service
and
    available <= 0
and
    capacity > current_load
and
    downed = 0
""")

_ADD_USER_TO_NODE = sqltext("""\
update
    nodes
set
    current_load = current_load + 1,
    available = greatest(available - 1, 0)
where
    service = :service
and
    node = :node
""")

_GET_SERVICE_ID = sqltext("""\
select
    id
from
    services
where
    service = :service
""")

_GET_NODE = sqltext("""\
select
    *
from
    nodes
where
    service = :service
and
    node = :node
""")

_GET_SPANNER_NODE = sqltext("""\
select
    id, node
from
    nodes
where
    id = :id
limit
    1
""")

SERVICE_NAME = 'sync-1.5'


class Database:
    """Access layer for tokenserver user and node records.

    All methods operate against a single autocommit connection created
    from the SYNC_TOKENSERVER__DATABASE_URL environment variable.
    """

    def __init__(self):
        engine = create_engine(os.environ['SYNC_TOKENSERVER__DATABASE_URL'])
        self.database = engine. \
            execution_options(isolation_level="AUTOCOMMIT"). \
            connect()
        # BUG FIX: environment variables are strings; without float() a
        # configured NODE_CAPACITY_RELEASE_RATE would make
        # `math.ceil(capacity * rate)` in add_node raise TypeError.
        # The default path (0.1) is unchanged.
        self.capacity_release_rate = float(
            os.environ.get("NODE_CAPACITY_RELEASE_RATE", 0.1))
        self.spanner_node_id = os.environ.get(
            "SYNC_TOKENSERVER__SPANNER_NODE_ID")
        self.spanner_node = None
        if self.spanner_node_id:
            self.spanner_node = self.get_spanner_node(self.spanner_node_id)

    def _execute_sql(self, *args, **kwds):
        """Execute a statement on the shared connection."""
        return self.database.execute(*args, **kwds)

    def close(self):
        """Close the underlying database connection."""
        self.database.close()

    def get_user(self, email):
        """Return the current state dict for a user, or None if unknown.

        If the user's newest record has been replaced or has lost its node
        (and the user is not retired), a fresh record is allocated.
        """
        params = {'service': self._get_service_id(SERVICE_NAME),
                  'email': email}
        res = self._execute_sql(_GET_USER_RECORDS, **params)
        try:
            # The query fetches rows ordered by created_at, but we want
            # to ensure that they're ordered by (generation, created_at).
            # This is almost always true, except for strange race conditions
            # during row creation.  Sorting them is an easy way to enforce
            # this without bloating the db index.
            rows = res.fetchall()
            rows.sort(key=lambda r: (r.generation, r.created_at),
                      reverse=True)
            if not rows:
                return None
            # The first row is the most up-to-date user record.
            # The rest give previously-seen client-state values.
            cur_row = rows[0]
            old_rows = rows[1:]
            user = {
                'email': email,
                'uid': cur_row.uid,
                'node': cur_row.node,
                'generation': cur_row.generation,
                'keys_changed_at': cur_row.keys_changed_at or 0,
                'client_state': cur_row.client_state,
                'old_client_states': {},
                'first_seen_at': cur_row.created_at,
            }
            # If the current row is marked as replaced or is missing a node,
            # and they haven't been retired, then assign them a new node.
            if cur_row.replaced_at is not None or cur_row.node is None:
                if cur_row.generation < MAX_GENERATION:
                    user = self.allocate_user(email, cur_row.generation,
                                              cur_row.client_state,
                                              cur_row.keys_changed_at)
            for old_row in old_rows:
                # Collect any previously-seen client-state values.
                if old_row.client_state != user['client_state']:
                    user['old_client_states'][old_row.client_state] = True
                # Make sure each old row is marked as replaced.
                # They might not be, due to races in row creation.
                if old_row.replaced_at is None:
                    timestamp = cur_row.created_at
                    self.replace_user_record(old_row.uid, timestamp)
                # Track backwards to the oldest timestamp at which we saw
                # them.
                user['first_seen_at'] = old_row.created_at
            return user
        finally:
            res.close()

    def allocate_user(self, email, generation=0, client_state='',
                      keys_changed_at=0, node=None, timestamp=None):
        """Create a new active record for a user and return its state dict.

        If `node` is not given, the least-loaded node is chosen.
        """
        if timestamp is None:
            timestamp = get_timestamp()
        if node is None:
            nodeid, node = self.get_best_node()
        else:
            nodeid = self.get_node_id(node)
        params = {
            'service': self._get_service_id(SERVICE_NAME),
            'email': email,
            'nodeid': nodeid,
            'generation': generation,
            'keys_changed_at': keys_changed_at,
            'client_state': client_state,
            'timestamp': timestamp
        }
        res = self._execute_sql(_CREATE_USER_RECORD, **params)
        # Capture the generated uid, then release the result promptly
        # (the original never closed this result).
        uid = res.lastrowid
        res.close()
        return {
            'email': email,
            'uid': uid,
            'node': node,
            'generation': generation,
            'keys_changed_at': keys_changed_at,
            'client_state': client_state,
            'old_client_states': {},
            'first_seen_at': timestamp,
        }

    def update_user(self, user, generation=None, client_state=None,
                    keys_changed_at=None, node=None):
        """Update a user's state, mutating the `user` dict in place.

        Raises an Exception on a previously-seen client-state string.
        """
        if client_state is None and node is None:
            # No need for a node-reassignment, just update the row in place.
            # Note that if we're changing keys_changed_at without changing
            # client_state, it's because we're seeing an existing value of
            # keys_changed_at for the first time.
            params = {
                'service': self._get_service_id(SERVICE_NAME),
                'email': user['email'],
                'generation': generation,
                'keys_changed_at': keys_changed_at
            }
            res = self._execute_sql(_UPDATE_USER_RECORD_IN_PLACE, **params)
            res.close()
            if generation is not None:
                user['generation'] = max(user['generation'], generation)
            user['keys_changed_at'] = max_keys_changed_at(
                user, keys_changed_at
            )
        else:
            # Reject previously-seen client-state strings.
            if client_state is None:
                client_state = user['client_state']
            else:
                if client_state == user['client_state']:
                    raise Exception('previously seen client-state string')
                if client_state in user['old_client_states']:
                    raise Exception('previously seen client-state string')
            # Need to create a new record for new user state.
            # If the node is not explicitly changing, try to keep them on the
            # same node, but if e.g. it no longer exists them allocate them to
            # a new one.
            if node is not None:
                nodeid = self.get_node_id(node)
                user['node'] = node
            else:
                try:
                    nodeid = self.get_node_id(user['node'])
                except ValueError:
                    nodeid, node = self.get_best_node()
                    user['node'] = node
            if generation is not None:
                generation = max(user['generation'], generation)
            else:
                generation = user['generation']
            keys_changed_at = max_keys_changed_at(user, keys_changed_at)
            now = get_timestamp()
            params = {
                'service': self._get_service_id(SERVICE_NAME),
                'email': user['email'],
                'nodeid': nodeid,
                'generation': generation,
                'keys_changed_at': keys_changed_at,
                'client_state': client_state,
                'timestamp': now,
            }
            res = self._execute_sql(_CREATE_USER_RECORD, **params)
            # BUG FIX: `lastrowid` must be read before closing the result;
            # the original closed first and then read the attribute.
            user['uid'] = res.lastrowid
            res.close()
            user['generation'] = generation
            user['keys_changed_at'] = keys_changed_at
            user['old_client_states'][user['client_state']] = True
            user['client_state'] = client_state
            # mark old records as having been replaced.
            # if we crash here, they are unmarked and we may fail to
            # garbage collect them for a while, but the active state
            # will be undamaged.
            self.replace_user_records(user['email'], now)

    def retire_user(self, email):
        """Replace all of a user's records and block future logins.

        Sets generation to MAX_GENERATION as a tombstone.  Note this is
        not filtered by service.
        """
        now = get_timestamp()
        params = {
            'email': email,
            'timestamp': now,
            'generation': MAX_GENERATION
        }
        res = self._execute_sql(_RETIRE_USER_RECORDS, **params)
        res.close()

    def count_users(self, timestamp=None):
        """Count active user records created at or before `timestamp`."""
        if timestamp is None:
            timestamp = get_timestamp()
        res = self._execute_sql(_COUNT_USER_RECORDS, timestamp=timestamp)
        row = res.fetchone()
        res.close()
        return row[0]

    #
    # Methods for low-level user record management.
    #

    def get_user_records(self, email):
        """Get all the user's records, including the old ones."""
        params = {'service': self._get_service_id(SERVICE_NAME),
                  'email': email}
        res = self._execute_sql(_GET_ALL_USER_RECORDS_FOR_SERVICE, **params)
        try:
            for row in res:
                yield row
        finally:
            res.close()

    def _build_old_user_query(self, uid_range, params, **kwargs):
        """Return the old-records query, optionally restricted to a
        (start, end) exclusive uid range.  Adds any range bounds to
        `params` as a side effect."""
        if uid_range:
            # construct the range from the passed arguments
            rstr = []
            try:
                if uid_range[0]:
                    rstr.append("uid > :start")
                    params["start"] = uid_range[0]
                if uid_range[1]:
                    rstr.append("uid < :end")
                    params["end"] = uid_range[1]
            except IndexError:
                # Tolerate a short range tuple; whatever bounds were
                # present have already been recorded.
                pass
            rrep = " and ".join(rstr)
            sql = sqltext(
                _GET_OLD_USER_RECORDS_FOR_SERVICE_RANGE.replace(
                    "::RANGE::", rrep))
        else:
            sql = _GET_OLD_USER_RECORDS_FOR_SERVICE
        return sql

    def get_old_user_records(self, grace_period=-1, limit=100, offset=0,
                             uid_range=None):
        """Get user records that were replaced outside the grace period."""
        if grace_period < 0:
            grace_period = 60 * 60 * 24 * 7  # one week, in seconds
        grace_period = int(grace_period * 1000)  # convert seconds -> millis
        params = {
            "service": self._get_service_id(SERVICE_NAME),
            "timestamp": get_timestamp() - grace_period,
            "limit": limit,
            "offset": offset
        }
        sql = self._build_old_user_query(uid_range, params)
        res = self._execute_sql(sql, **params)
        try:
            for row in res:
                yield row
        finally:
            res.close()

    def replace_user_records(self, email, timestamp=None):
        """Mark all existing records for a user as replaced."""
        if timestamp is None:
            timestamp = get_timestamp()
        params = {
            'service': self._get_service_id(SERVICE_NAME),
            'email': email,
            'timestamp': timestamp
        }
        res = self._execute_sql(_REPLACE_USER_RECORDS, **params)
        res.close()

    def replace_user_record(self, uid, timestamp=None):
        """Mark an existing service record as replaced."""
        if timestamp is None:
            timestamp = get_timestamp()
        params = {
            'service': self._get_service_id(SERVICE_NAME),
            'uid': uid,
            'timestamp': timestamp
        }
        res = self._execute_sql(_REPLACE_USER_RECORD, **params)
        res.close()

    def delete_user_record(self, uid):
        """Delete the user record with the given uid."""
        params = {'service': self._get_service_id(SERVICE_NAME),
                  'uid': uid}
        # The spanner node's capacity is not tracked per-user, so only
        # return the slot when running against conventional nodes.
        if not self.spanner_node_id:
            res = self._execute_sql(_FREE_SLOT_ON_NODE, **params)
            res.close()
        res = self._execute_sql(_DELETE_USER_RECORD, **params)
        res.close()

    #
    # Nodes management
    #

    def _get_service_id(self, service):
        """Look up (and cache on self) the numeric id for a service name.

        Raises an Exception for an unknown service.
        """
        if hasattr(self, 'service_id'):
            return self.service_id
        else:
            res = self._execute_sql(_GET_SERVICE_ID, service=service)
            row = res.fetchone()
            res.close()
            if row is None:
                raise Exception('unknown service: ' + service)
            self.service_id = row.id
            return row.id

    def add_service(self, service_name, pattern, **kwds):
        """Add definition for a new service.  Returns the new row id."""
        res = self._execute_sql(sqltext("""
            insert into services (service, pattern)
            values (:servicename, :pattern)
        """), servicename=service_name, pattern=pattern, **kwds)
        # BUG FIX: capture `lastrowid` before closing the result; the
        # original closed first and then returned the attribute.
        last_id = res.lastrowid
        res.close()
        return last_id

    def add_node(self, node, capacity, **kwds):
        """Add definition for a new node."""
        available = kwds.get('available')
        # We release only a fraction of the node's capacity to start.
        if available is None:
            available = math.ceil(capacity * self.capacity_release_rate)
        cols = ["service", "node", "available", "capacity",
                "current_load", "downed", "backoff"]
        args = [":" + v for v in cols]
        # Handle test cases that require nodeid to be 800
        if "nodeid" in kwds:
            cols.append("id")
            args.append(":nodeid")
        query = """
            insert into nodes ({cols})
            values ({args})
        """.format(cols=", ".join(cols), args=", ".join(args))
        res = self._execute_sql(
            sqltext(query),
            nodeid=kwds.get('nodeid'),
            service=self._get_service_id(SERVICE_NAME),
            node=node,
            capacity=capacity,
            available=available,
            current_load=kwds.get('current_load', 0),
            downed=kwds.get('downed', 0),
            backoff=kwds.get('backoff', 0),
        )
        res.close()

    def update_node(self, node, **kwds):
        """Updates node fields in the db.

        Only the columns in NODE_FIELDS may be updated; anything else in
        `kwds` raises ValueError.
        """
        values = {}
        cols = set(NODE_FIELDS) & kwds.keys()
        for col in NODE_FIELDS:
            try:
                values[col] = kwds.pop(col)
            except KeyError:
                pass
        args = [v + " = :" + v for v in cols]
        query = """
            update nodes
            set """
        query += ", ".join(args)
        query += """
            where service = :service and node = :node
        """
        values['service'] = self._get_service_id(SERVICE_NAME)
        values['node'] = node
        if kwds:
            raise ValueError("unknown fields: " + str(kwds.keys()))
        con = self._execute_sql(sqltext(query), **values)
        con.close()

    def get_node_id(self, node):
        """Get numeric id for a node.  Raises ValueError if unknown."""
        res = self._execute_sql(
            sqltext("""
                select id from nodes
                where service=:service and node=:node
            """),
            service=self._get_service_id(SERVICE_NAME),
            node=node
        )
        row = res.fetchone()
        res.close()
        if row is None:
            raise ValueError("unknown node: " + node)
        return row[0]

    def remove_node(self, node, timestamp=None):
        """Remove definition for a node and replace its assignments."""
        nodeid = self.get_node_id(node)
        res = self._execute_sql(sqltext(
            """
            delete from nodes where id=:nodeid
            """), nodeid=nodeid
        )
        res.close()
        self.unassign_node(node, timestamp, nodeid=nodeid)

    def unassign_node(self, node, timestamp=None, nodeid=None):
        """Clear any assignments to a node."""
        if timestamp is None:
            timestamp = get_timestamp()
        if nodeid is None:
            nodeid = self.get_node_id(node)
        res = self._execute_sql(
            sqltext("""
                update users
                set replaced_at=:timestamp
                where nodeid=:nodeid
            """), nodeid=nodeid, timestamp=timestamp
        )
        res.close()

    def get_best_node(self):
        """Returns the 'least loaded' node currently available, increments
        the active count on that node, and decrements the slots currently
        available.

        Raises an Exception if no node can be found even after releasing
        extra capacity.
        """
        # The spanner node is the best node.
        if self.spanner_node:
            return self.spanner_node_id, self.spanner_node
        # if, for whatever reason, we haven't gotten the spanner node yet...
        if self.spanner_node_id:
            self.spanner_node = self.get_spanner_node(self.spanner_node_id)
            return self.spanner_node_id, self.spanner_node
        else:
            # We may have to re-try the query if we need to release more
            # capacity.  This loop allows a maximum of five retries before
            # bailing out.
            for _ in range(5):
                res = self._execute_sql(
                    _GET_BEST_NODE,
                    service=self._get_service_id(SERVICE_NAME))
                row = res.fetchone()
                res.close()
                if row is None:
                    # Try to release additional capacity from any nodes
                    # that are not fully occupied.
                    res = self._execute_sql(
                        _RELEASE_NODE_CAPACITY,
                        capacity_release_rate=self.capacity_release_rate,
                        service=self._get_service_id(SERVICE_NAME)
                    )
                    # BUG FIX: read `rowcount` before closing the result;
                    # the original closed first and then checked it.
                    released = res.rowcount
                    res.close()
                    if released == 0:
                        break
                else:
                    break
            # Did we succeed in finding a node?
            if row is None:
                raise Exception('unable to get a node')
            nodeid = row.id
            node = str(row.node)
            # Update the node to reflect the new assignment.
            # This is a little racy with concurrent assignments, but no big
            # deal.
            con = self._execute_sql(_ADD_USER_TO_NODE,
                                    service=self._get_service_id(
                                        SERVICE_NAME),
                                    node=node)
            con.close()
            return nodeid, node

    def get_node(self, node):
        """Fetch the full row for a node by name.  Raises if unknown."""
        if node is None:
            raise Exception("NONE node")
        res = self._execute_sql(_GET_NODE,
                                service=self._get_service_id(SERVICE_NAME),
                                node=node)
        row = res.fetchone()
        res.close()
        if row is None:
            raise Exception('unknown node: ' + node)
        return row

    # somewhat simplified version that just gets the one Spanner node.
    def get_spanner_node(self, node):
        """Return the node name for the given spanner node id."""
        res = self._execute_sql(_GET_SPANNER_NODE, id=node)
        row = res.fetchone()
        res.close()
        if row is None:
            raise Exception(f'unknown node: {node}')
        return str(row.node)


def max_keys_changed_at(user, keys_changed_at):
    """Return the largest `keys_changed_at` between the user record and
    the specified value.

    May return `None` as the column is nullable.
    """
    it = (
        x for x in (keys_changed_at, user['keys_changed_at'])
        if x is not None
    )
    return max(it, default=None)