diff --git a/.circleci/config.yml b/.circleci/config.yml index 8a6d1197..2a1388cf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -38,7 +38,7 @@ commands: command: | cargo fmt -- --check # https://github.com/bodil/sized-chunks/issues/11 - cargo audit --ignore RUSTSEC-2020-0041 --ignore RUSTSEC-2021-0078 --ignore RUSTSEC-2021-0079 --ignore RUSTSEC-2020-0159 --ignore RUSTSEC-2020-0071 + cargo audit --ignore RUSTSEC-2020-0041 --ignore RUSTSEC-2021-0078 --ignore RUSTSEC-2021-0079 --ignore RUSTSEC-2020-0159 --ignore RUSTSEC-2020-0071 --ignore RUSTSEC-2021-0124 python-check: steps: - run: @@ -46,6 +46,7 @@ commands: command: | flake8 src/tokenserver/verify.py flake8 tools/integration_tests + flake8 tools/tokenserver rust-clippy: steps: - run: @@ -113,6 +114,16 @@ commands: environment: SYNCSTORAGE_RS_IMAGE: app:build + run-tokenserver-scripts-tests: + steps: + - run: + name: Tokenserver scripts tests + command: > + pip3 install -r tools/tokenserver/requirements.txt && + python3 tools/tokenserver/run_tests.py + environment: + SYNCSTORAGE_RS_IMAGE: app:build + run-e2e-spanner-tests: steps: - run: @@ -211,6 +222,7 @@ jobs: - write-version - cargo-build - run-tests + - run-tokenserver-scripts-tests #- save-sccache-cache - run: name: Build Docker image diff --git a/tools/tokenserver/__init__.py b/tools/tokenserver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/tokenserver/add_node.py b/tools/tokenserver/add_node.py new file mode 100644 index 00000000..e86b01db --- /dev/null +++ b/tools/tokenserver/add_node.py @@ -0,0 +1,80 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. + +""" + +Script to add a new node to the system. 
import logging
import optparse

from database import Database, SERVICE_NAME
import util


logger = logging.getLogger("tokenserver.scripts.add_node")


def add_node(node, capacity, **kwds):
    """Add the specified node to the system.

    Returns True on success, False if the database update failed;
    failures are logged rather than raised.
    """
    logger.info("Adding node %s to service %s", node, SERVICE_NAME)
    try:
        database = Database()
        database.add_node(node, capacity, **kwds)
    except Exception:
        logger.exception("Error while adding node")
        return False
    else:
        logger.info("Finished adding node %s", node)
        return True


def main(args=None):
    """Main entry-point for running this script.

    This function parses command-line arguments and passes them on
    to the add_node() function.  Returns a process exit status:
    0 on success, 1 on bad usage or on failure to add the node.
    """
    usage = "usage: %prog [options] node_name capacity"
    descr = "Add a new node to the tokenserver database"
    parser = optparse.OptionParser(usage=usage, description=descr)
    parser.add_option("", "--available", type="int",
                      help="How many user slots the node has available")
    parser.add_option("", "--current-load", type="int",
                      help="How many user slots the node has occupied")
    parser.add_option("", "--downed", action="store_true",
                      help="Mark the node as down in the db")
    parser.add_option("", "--backoff", action="store_true",
                      help="Mark the node as backed-off in the db")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if len(args) != 2:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    node_name = args[0]
    capacity = int(args[1])

    # Only forward options the user actually supplied, so Database.add_node
    # keeps its own defaults for the rest.
    kwds = {}
    if opts.available is not None:
        kwds["available"] = opts.available
    if opts.current_load is not None:
        kwds["current_load"] = opts.current_load
    if opts.backoff is not None:
        kwds["backoff"] = opts.backoff
    if opts.downed is not None:
        kwds["downed"] = opts.downed

    # BUGFIX: previously this always returned 0; propagate failure so the
    # script's exit status reflects whether the node was actually added.
    if not add_node(node_name, capacity, **kwds):
        return 1
    return 0


if __name__ == "__main__":
    util.run_script(main)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""

Script to allocate a specific user to a node.

This script allocates the specified user to a node. A particular node
may be specified, or the best available node used by default.

The allocated node is printed to stdout.

"""

import logging
import optparse

from database import Database
import util


logger = logging.getLogger("tokenserver.scripts.allocate_user")


def allocate_user(email, node=None):
    """Assign the given user to a node, creating the user if necessary.

    Returns True on success, False if the database update failed;
    failures are logged rather than raised.
    """
    logger.info("Allocating node for user %s", email)
    try:
        db = Database()
        existing = db.get_user(email)
        if existing is None:
            db.allocate_user(email, node=node)
        else:
            db.update_user(existing, node=node)
    except Exception:
        logger.exception("Error while updating node")
        return False
    logger.info("Finished updating node %s", node)
    return True


def main(args=None):
    """Main entry-point for running this script.

    This function parses command-line arguments and passes them on
    to the allocate_user() function.
    """
    usage = "usage: %prog [options] email [node_name]"
    descr = "Allocate a user to a node. You may specify a particular node, "\
            "or omit to use the best available node."
    parser = optparse.OptionParser(usage=usage, description=descr)
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if len(args) not in (1, 2):
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    email = args[0]
    node_name = args[1] if len(args) == 2 else None

    allocate_user(email, node_name)
    return 0


if __name__ == "__main__":
    util.run_script(main)


# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""

Script to emit total-user-count metrics for exec dashboard.

"""

import json
import logging
import optparse
import os
import socket
import sys
import time
from datetime import datetime, timedelta, tzinfo

from database import Database
import util

logger = logging.getLogger("tokenserver.scripts.count_users")

ZERO = timedelta(0)


class UTC(tzinfo):
    """A fixed zero-offset tzinfo representing UTC."""

    def utcoffset(self, dt):
        return ZERO

    def tzname(self, dt):
        return "UTC"

    def dst(self, dt):
        return ZERO


utc = UTC()


def count_users(outfile, timestamp=None):
    """Write one JSON line with the total user count to *outfile*.

    ``timestamp`` is in milliseconds; when omitted it defaults to the
    most recent midnight.
    """
    if timestamp is None:
        # NOTE(review): this feeds a gmtime() tuple to mktime(), which
        # interprets it in *local* time, so the default boundary is off by
        # the local UTC offset on non-UTC hosts — presumably UTC midnight
        # was intended; confirm before changing.
        ts = time.gmtime()
        midnight = (ts[0], ts[1], ts[2], 0, 0, 0, ts[6], ts[7], ts[8])
        timestamp = int(time.mktime(midnight)) * 1000
    db = Database()
    logger.debug("Counting users created before %i", timestamp)
    total = db.count_users(timestamp)
    logger.debug("Found %d users", total)
    # Output has heka-filter-compatible JSON object.
    ts_sec = timestamp / 1000
    record = {
        "hostname": socket.gethostname(),
        "pid": os.getpid(),
        "op": "sync_count_users",
        "total_users": total,
        "time": datetime.fromtimestamp(ts_sec, utc).isoformat(),
        "v": 0
    }
    json.dump(record, outfile)
    outfile.write("\n")
def main(args=None):
    """Main entry-point for running this script.

    This function parses command-line arguments and passes them on
    to the count_users() function.
    """
    usage = "usage: %prog [options]"
    descr = "Count total users in the tokenserver database"
    parser = optparse.OptionParser(usage=usage, description=descr)
    parser.add_option("-t", "--timestamp", type="int",
                      help="Max creation timestamp; default previous midnight")
    parser.add_option("-o", "--output",
                      help="Output file; default stderr")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if len(args) != 0:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    if opts.output in (None, "-"):
        count_users(sys.stdout, opts.timestamp)
    else:
        with open(opts.output, "a") as outfile:
            count_users(outfile, opts.timestamp)

    return 0


if __name__ == "__main__":
    util.run_script(main)


# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import math
import os
from sqlalchemy import create_engine
from sqlalchemy.sql import text as sqltext

from util import get_timestamp

# The maximum possible generation number.
# Used as a tombstone to mark users that have been "retired" from the db.
MAX_GENERATION = 9223372036854775807

# The node columns that update_node() accepts as keyword arguments.
NODE_FIELDS = ("capacity", "available", "current_load", "downed", "backoff")

_GET_USER_RECORDS = sqltext("""\
select
    uid, nodes.node, generation, keys_changed_at, client_state, created_at,
    replaced_at
from
    users left outer join nodes on users.nodeid = nodes.id
where
    email = :email and users.service = :service
order by
    created_at desc, uid desc
limit
    20
""")

_CREATE_USER_RECORD = sqltext("""\
insert into
    users
    (service, email, nodeid, generation, keys_changed_at, client_state,
     created_at, replaced_at)
values
    (:service, :email, :nodeid, :generation, :keys_changed_at,
     :client_state, :timestamp, NULL)
""")

# The `where` clause on this statement is designed as an extra layer of
# protection, to ensure that concurrent updates don't accidentally move
# timestamp fields backwards in time. The handling of `keys_changed_at`
# is additionally weird because we want to treat the default `NULL` value
# as zero.
_UPDATE_USER_RECORD_IN_PLACE = sqltext("""\
update
    users
set
    generation = COALESCE(:generation, generation),
    keys_changed_at = COALESCE(:keys_changed_at, keys_changed_at)
where
    service = :service and email = :email and
    generation <= COALESCE(:generation, generation) and
    COALESCE(keys_changed_at, 0) <=
        COALESCE(:keys_changed_at, keys_changed_at, 0) and
    replaced_at is null
""")


_REPLACE_USER_RECORDS = sqltext("""\
update
    users
set
    replaced_at = :timestamp
where
    service = :service and email = :email
    and replaced_at is null and created_at < :timestamp
""")


# Mark all records for the user as replaced,
# and set a large generation number to block future logins.
_RETIRE_USER_RECORDS = sqltext("""\
update
    users
set
    replaced_at = :timestamp,
    generation = :generation
where
    email = :email
    and replaced_at is null
""")


_GET_OLD_USER_RECORDS_FOR_SERVICE = sqltext("""\
select
    uid, email, generation, keys_changed_at, client_state,
    nodes.node, nodes.downed, created_at, replaced_at
from
    users left outer join nodes on users.nodeid = nodes.id
where
    users.service = :service
and
    replaced_at is not null and replaced_at < :timestamp
order by
    replaced_at desc, uid desc
limit
    :limit
offset
    :offset
""")


_GET_ALL_USER_RECORDS_FOR_SERVICE = sqltext("""\
select
    uid, nodes.node, created_at, replaced_at
from
    users left outer join nodes on users.nodeid = nodes.id
where
    email = :email and users.service = :service
order by
    created_at asc, uid desc
""")


_REPLACE_USER_RECORD = sqltext("""\
update
    users
set
    replaced_at = :timestamp
where
    service = :service
and
    uid = :uid
""")


_DELETE_USER_RECORD = sqltext("""\
delete from
    users
where
    service = :service
and
    uid = :uid
""")


_FREE_SLOT_ON_NODE = sqltext("""\
update
    nodes
set
    available = available + 1, current_load = current_load - 1
where
    id = (SELECT nodeid FROM users WHERE service=:service AND uid=:uid)
""")


_COUNT_USER_RECORDS = sqltext("""\
select
    count(email)
from
    users
where
    replaced_at is null
    and created_at <= :timestamp
""")


_GET_BEST_NODE = sqltext("""\
select
    id, node
from
    nodes
where
    service = :service
    and available > 0
    and capacity > current_load
    and downed = 0
    and backoff = 0
order by
    log(current_load) / log(capacity)
limit 1
""")


_RELEASE_NODE_CAPACITY = sqltext("""\
update
    nodes
set
    available = least(capacity * :capacity_release_rate,
                      capacity - current_load)
where
    service = :service
    and available <= 0
    and capacity > current_load
    and downed = 0
""")


_ADD_USER_TO_NODE = sqltext("""\
update
    nodes
set
    current_load = current_load + 1,
    available = greatest(available - 1, 0)
where
    service = :service
    and node = :node
""")


_GET_SERVICE_ID = sqltext("""\
select
    id
from
    services
where
    service = :service
""")


_GET_NODE = sqltext("""\
select
    *
from
    nodes
where
    service = :service
    and node = :node
""")

SERVICE_NAME = 'sync-1.5'


class Database:
    """Access layer for the tokenserver users/nodes/services tables.

    Connects using the SYNC_TOKENSERVER__DATABASE_URL environment variable
    and operates on the 'sync-1.5' service only.
    """

    def __init__(self):
        engine = create_engine(os.environ['SYNC_TOKENSERVER__DATABASE_URL'])
        self.database = engine. \
            execution_options(isolation_level="AUTOCOMMIT"). \
            connect()
        self.service_id = self._get_service_id(SERVICE_NAME)
        # BUGFIX: environment variables are strings; without the float()
        # cast a configured NODE_CAPACITY_RELEASE_RATE would later break
        # arithmetic such as math.ceil(capacity * rate) in add_node().
        self.capacity_release_rate = float(
            os.environ.get("NODE_CAPACITY_RELEASE_RATE", 0.1))

    def _execute_sql(self, *args, **kwds):
        # Single funnel for all SQL so a sharded implementation could hook it.
        return self.database.execute(*args, **kwds)

    def close(self):
        self.database.close()

    def get_user(self, email):
        """Return the current user record for `email`, or None.

        The returned dict includes previously-seen client-state values and
        may transparently re-allocate the user to a new node if their
        current record is replaced or node-less (and not retired).
        """
        params = {'service': self.service_id, 'email': email}
        res = self._execute_sql(_GET_USER_RECORDS, **params)
        try:
            # The query fetches rows ordered by created_at, but we want
            # to ensure that they're ordered by (generation, created_at).
            # This is almost always true, except for strange race conditions
            # during row creation. Sorting them is an easy way to enforce
            # this without bloating the db index.
            rows = res.fetchall()
            rows.sort(key=lambda r: (r.generation, r.created_at),
                      reverse=True)
            if not rows:
                return None
            # The first row is the most up-to-date user record.
            # The rest give previously-seen client-state values.
            cur_row = rows[0]
            old_rows = rows[1:]
            user = {
                'email': email,
                'uid': cur_row.uid,
                'node': cur_row.node,
                'generation': cur_row.generation,
                'keys_changed_at': cur_row.keys_changed_at or 0,
                'client_state': cur_row.client_state,
                'old_client_states': {},
                'first_seen_at': cur_row.created_at,
            }
            # If the current row is marked as replaced or is missing a node,
            # and they haven't been retired, then assign them a new node.
            if cur_row.replaced_at is not None or cur_row.node is None:
                if cur_row.generation < MAX_GENERATION:
                    user = self.allocate_user(email,
                                              cur_row.generation,
                                              cur_row.client_state,
                                              cur_row.keys_changed_at)
            for old_row in old_rows:
                # Collect any previously-seen client-state values.
                if old_row.client_state != user['client_state']:
                    user['old_client_states'][old_row.client_state] = True
                # Make sure each old row is marked as replaced.
                # They might not be, due to races in row creation.
                if old_row.replaced_at is None:
                    timestamp = cur_row.created_at
                    self.replace_user_record(old_row.uid, timestamp)
                # Track backwards to the oldest timestamp at which we saw
                # them.
                user['first_seen_at'] = old_row.created_at
            return user
        finally:
            res.close()

    def allocate_user(self, email, generation=0, client_state='',
                      keys_changed_at=0, node=None, timestamp=None):
        """Create a fresh user record, assigning a node if none is given."""
        if timestamp is None:
            timestamp = get_timestamp()
        if node is None:
            nodeid, node = self.get_best_node()
        else:
            nodeid = self.get_node_id(node)
        params = {
            'service': self.service_id,
            'email': email,
            'nodeid': nodeid,
            'generation': generation,
            'keys_changed_at': keys_changed_at,
            'client_state': client_state,
            'timestamp': timestamp
        }
        res = self._execute_sql(_CREATE_USER_RECORD, **params)
        return {
            'email': email,
            'uid': res.lastrowid,
            'node': node,
            'generation': generation,
            'keys_changed_at': keys_changed_at,
            'client_state': client_state,
            'old_client_states': {},
            'first_seen_at': timestamp,
        }

    def update_user(self, user, generation=None, client_state=None,
                    keys_changed_at=None, node=None):
        """Update the given user record in the db, mutating `user` in place.

        A change of client_state or node creates a new user row and marks
        older rows replaced; otherwise the row is updated in place.
        Raises if a previously-seen client-state string is supplied.
        """
        if client_state is None and node is None:
            # No need for a node-reassignment, just update the row in place.
            # Note that if we're changing keys_changed_at without changing
            # client_state, it's because we're seeing an existing value of
            # keys_changed_at for the first time.
            params = {
                'service': self.service_id,
                'email': user['email'],
                'generation': generation,
                'keys_changed_at': keys_changed_at
            }
            res = self._execute_sql(_UPDATE_USER_RECORD_IN_PLACE, **params)
            res.close()

            user['generation'] = max([x
                                      for x
                                      in [generation, user['generation']]
                                      if x is not None])
            user['keys_changed_at'] = max([x
                                           for x
                                           in [keys_changed_at,
                                               user['keys_changed_at']]
                                           if x is not None])
        else:
            # Reject previously-seen client-state strings.
            if client_state is None:
                client_state = user['client_state']
            else:
                if client_state == user['client_state']:
                    raise Exception('previously seen client-state string')
                if client_state in user['old_client_states']:
                    raise Exception('previously seen client-state string')
            # Need to create a new record for new user state.
            # If the node is not explicitly changing, try to keep them on the
            # same node, but if e.g. it no longer exists them allocate them to
            # a new one.
            if node is not None:
                nodeid = self.get_node_id(node)
                user['node'] = node
            else:
                try:
                    nodeid = self.get_node_id(user['node'])
                except ValueError:
                    nodeid, node = self.get_best_node()
                    user['node'] = node
            if generation is not None:
                generation = max(user['generation'], generation)
            else:
                generation = user['generation']
            if keys_changed_at is not None:
                keys_changed_at = max(user['keys_changed_at'],
                                      keys_changed_at)
            else:
                keys_changed_at = user['keys_changed_at']
            now = get_timestamp()
            params = {
                'service': self.service_id, 'email': user['email'],
                'nodeid': nodeid, 'generation': generation,
                'keys_changed_at': keys_changed_at,
                'client_state': client_state, 'timestamp': now,
            }
            res = self._execute_sql(_CREATE_USER_RECORD, **params)
            # BUGFIX: read lastrowid *before* closing the result; the
            # original read it after close(), which relies on undefined
            # driver behaviour.
            new_uid = res.lastrowid
            res.close()
            user['uid'] = new_uid
            user['generation'] = generation
            user['keys_changed_at'] = keys_changed_at
            user['old_client_states'][user['client_state']] = True
            user['client_state'] = client_state
            # mark old records as having been replaced.
            # if we crash here, they are unmarked and we may fail to
            # garbage collect them for a while, but the active state
            # will be undamaged.
            self.replace_user_records(user['email'], now)

    def retire_user(self, email):
        """Mark all of a user's records replaced and block future logins."""
        now = get_timestamp()
        params = {
            'email': email, 'timestamp': now, 'generation': MAX_GENERATION
        }
        # Pass through explicit engine to help with sharded implementation,
        # since we can't shard by service name here.
        res = self._execute_sql(_RETIRE_USER_RECORDS, **params)
        res.close()

    def count_users(self, timestamp=None):
        """Count non-replaced users created at or before `timestamp` (ms)."""
        if timestamp is None:
            timestamp = get_timestamp()
        res = self._execute_sql(_COUNT_USER_RECORDS, timestamp=timestamp)
        row = res.fetchone()
        res.close()
        return row[0]

    #
    # Methods for low-level user record management.
    #

    def get_user_records(self, email):
        """Get all the user's records, including the old ones."""
        params = {'service': self.service_id, 'email': email}
        res = self._execute_sql(_GET_ALL_USER_RECORDS_FOR_SERVICE, **params)
        try:
            for row in res:
                yield row
        finally:
            res.close()

    def get_old_user_records(self, grace_period=-1, limit=100,
                             offset=0):
        """Get user records that were replaced outside the grace period."""
        if grace_period < 0:
            grace_period = 60 * 60 * 24 * 7  # one week, in seconds
        grace_period = int(grace_period * 1000)  # convert seconds -> millis
        params = {
            "service": self.service_id,
            "timestamp": get_timestamp() - grace_period,
            "limit": limit,
            "offset": offset
        }
        res = self._execute_sql(_GET_OLD_USER_RECORDS_FOR_SERVICE, **params)
        try:
            for row in res:
                yield row
        finally:
            res.close()

    def replace_user_records(self, email, timestamp=None):
        """Mark all existing records for a user as replaced."""
        if timestamp is None:
            timestamp = get_timestamp()
        params = {
            'service': self.service_id, 'email': email,
            'timestamp': timestamp
        }
        res = self._execute_sql(_REPLACE_USER_RECORDS, **params)
        res.close()

    def replace_user_record(self, uid, timestamp=None):
        """Mark an existing service record as replaced."""
        if timestamp is None:
            timestamp = get_timestamp()
        params = {
            'service': self.service_id, 'uid': uid, 'timestamp': timestamp
        }
        res = self._execute_sql(_REPLACE_USER_RECORD, **params)
        res.close()

    def delete_user_record(self, uid):
        """Delete the user record with the given uid."""
        params = {'service': self.service_id, 'uid': uid}
        # Free the slot on the node first, then remove the row itself.
        res = self._execute_sql(_FREE_SLOT_ON_NODE, **params)
        res.close()
        res = self._execute_sql(_DELETE_USER_RECORD, **params)
        res.close()

    #
    # Nodes management
    #

    def _get_service_id(self, service):
        res = self._execute_sql(_GET_SERVICE_ID, service=service)
        row = res.fetchone()
        res.close()
        if row is None:
            raise Exception('unknown service: ' + service)
        return row.id

    def add_service(self, service_name, pattern, **kwds):
        """Add definition for a new service."""
        res = self._execute_sql(sqltext("""
            insert into services (service, pattern)
            values (:servicename, :pattern)
        """), servicename=service_name, pattern=pattern, **kwds)
        # BUGFIX: read lastrowid before close(), not after.
        new_id = res.lastrowid
        res.close()
        return new_id

    def add_node(self, node, capacity, **kwds):
        """Add definition for a new node."""
        available = kwds.get('available')
        # We release only a fraction of the node's capacity to start.
        if available is None:
            available = math.ceil(capacity * self.capacity_release_rate)
        cols = ["service", "node", "available", "capacity",
                "current_load", "downed", "backoff"]
        args = [":" + v for v in cols]
        # Handle test cases that require nodeid to be 800
        if "nodeid" in kwds:
            cols.append("id")
            args.append(":nodeid")
        query = """
            insert into nodes ({cols})
            values ({args})
        """.format(cols=", ".join(cols), args=", ".join(args))
        res = self._execute_sql(
            sqltext(query),
            nodeid=kwds.get('nodeid'),
            service=self.service_id,
            node=node,
            capacity=capacity,
            available=available,
            current_load=kwds.get('current_load', 0),
            downed=kwds.get('downed', 0),
            backoff=kwds.get('backoff', 0),
        )
        res.close()

    def update_node(self, node, **kwds):
        """Updates node fields in the db.

        Only the columns listed in NODE_FIELDS may be supplied; any other
        keyword raises ValueError.
        """
        values = {}
        # Build the column list in NODE_FIELDS order so the generated SQL
        # is deterministic (the original used a set intersection, whose
        # iteration order is arbitrary).
        cols = [col for col in NODE_FIELDS if col in kwds]
        for col in NODE_FIELDS:
            try:
                values[col] = kwds.pop(col)
            except KeyError:
                pass

        args = [v + " = :" + v for v in cols]
        query = """
            update nodes
            set """
        query += ", ".join(args)
        query += """
            where service = :service and node = :node
        """
        values['service'] = self.service_id
        values['node'] = node
        if kwds:
            raise ValueError("unknown fields: " + str(kwds.keys()))
        con = self._execute_sql(sqltext(query), **values)
        con.close()

    def get_node_id(self, node):
        """Get numeric id for a node."""
        res = self._execute_sql(
            sqltext("""
                select id from nodes
                where service=:service and node=:node
            """),
            service=self.service_id, node=node
        )
        row = res.fetchone()
        res.close()
        if row is None:
            raise ValueError("unknown node: " + node)
        return row[0]

    def remove_node(self, node, timestamp=None):
        """Remove definition for a node."""
        nodeid = self.get_node_id(node)
        res = self._execute_sql(sqltext(
            """
            delete from nodes where id=:nodeid
            """),
            nodeid=nodeid
        )
        res.close()
        self.unassign_node(node, timestamp, nodeid=nodeid)

    def unassign_node(self, node, timestamp=None, nodeid=None):
        """Clear any assignments to a node."""
        if timestamp is None:
            timestamp = get_timestamp()
        if nodeid is None:
            nodeid = self.get_node_id(node)
        res = self._execute_sql(
            sqltext("""
                update users
                set replaced_at=:timestamp
                where nodeid=:nodeid
            """),
            nodeid=nodeid, timestamp=timestamp
        )
        res.close()

    def get_best_node(self):
        """Returns the 'least loaded' node currently available, increments the
        active count on that node, and decrements the slots currently
        available.
        """
        # We may have to re-try the query if we need to release more capacity.
        # This loop allows a maximum of five retries before bailing out.
        for _ in range(5):
            res = self._execute_sql(_GET_BEST_NODE, service=self.service_id)
            row = res.fetchone()
            res.close()
            if row is None:
                # Try to release additional capacity from any nodes
                # that are not fully occupied.
                res = self._execute_sql(
                    _RELEASE_NODE_CAPACITY,
                    capacity_release_rate=self.capacity_release_rate,
                    service=self.service_id
                )
                # BUGFIX: read rowcount before close(), not after.
                released = res.rowcount
                res.close()
                if released == 0:
                    break
            else:
                break

        # Did we succeed in finding a node?
        if row is None:
            raise Exception('unable to get a node')

        nodeid = row.id
        node = str(row.node)

        # Update the node to reflect the new assignment.
        # This is a little racy with concurrent assignments, but no big
        # deal.
        con = self._execute_sql(_ADD_USER_TO_NODE,
                                service=self.service_id,
                                node=node)
        con.close()

        return nodeid, node

    def get_node(self, node):
        """Return the full node row for `node`, raising if unknown."""
        res = self._execute_sql(_GET_NODE, service=self.service_id, node=node)
        row = res.fetchone()
        res.close()
        if row is None:
            raise Exception('unknown node: ' + node)
        return row


# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""

Script to process account-related events from an SQS queue.

This script polls an SQS queue for events indicating activity on an upstream
account, as documented here:

    https://github.com/mozilla/fxa-auth-server/blob/master/docs/service_notifications.md

The following event types are currently supported:

  * "delete": the account was deleted; we mark their records as retired
    so they'll be cleaned up by our garbage-collection process.

  * "reset": the account password was reset; we update our copy of their
    generation number to disconnect other devices.

  * "passwordChange": the account password was changed; we update our copy
    of their generation number to disconnect other devices.

Note that this is a purely optional administrative task, highly specific to
Mozilla's internal Firefox-Accounts-supported deployment.

"""
import json
import logging
import optparse

import boto
import boto.ec2
import boto.sqs
import boto.sqs.message
import boto.utils

import util
from database import Database


logger = logging.getLogger("tokenserver.scripts.process_account_deletions")


def process_account_events(queue_name, aws_region=None, queue_wait_time=20):
    """Process account events from an SQS queue.

    This function polls the specified SQS queue for account-related events,
    processing each as it is found. It polls indefinitely and does not return;
    to interrupt execution you'll need to e.g. SIGINT the process.
    """
    logger.info("Processing account events from %s", queue_name)
    try:
        # Connect to the SQS queue.
        # If no region is given, infer it from the instance metadata.
        if aws_region is None:
            logger.debug("Finding default region from instance metadata")
            aws_info = boto.utils.get_instance_metadata()
            aws_region = aws_info["placement"]["availability-zone"][:-1]
        logger.debug("Connecting to queue %r in %r", queue_name, aws_region)
        conn = boto.sqs.connect_to_region(aws_region)
        queue = conn.get_queue(queue_name)
        # We must force boto not to b64-decode the message contents, ugh.
        queue.set_message_class(boto.sqs.message.RawMessage)
        # Poll for messages indefinitely.
        while True:
            msg = queue.read(wait_time_seconds=queue_wait_time)
            if msg is None:
                continue
            process_account_event(msg.get_body())
            # This intentionally deletes the event even if it was some
            # unrecognized type. No point leaving a backlog.
            queue.delete_message(msg)
    except Exception:
        logger.exception("Error while processing account events")
        raise


def process_account_event(body):
    """Parse and process a single account event."""
    database = Database()
    # Try very hard not to error out if there's junk in the queue.
    email = None
    event_type = None
    generation = None
    try:
        body = json.loads(body)
        event = json.loads(body['Message'])
        event_type = event["event"]
        uid = event["uid"]
        # Older versions of the fxa-auth-server would send an email-like
        # identifier the "uid" field, but that doesn't make sense for any
        # relier other than tokenserver. Newer versions send just the raw uid
        # in the "uid" field, and include the domain in a separate "iss"
        # field.
        if "iss" in event:
            email = "%s@%s" % (uid, event["iss"])
        else:
            if "@" not in uid:
                raise ValueError("uid field does not contain issuer info")
            email = uid
        if event_type in ("reset", "passwordChange",):
            generation = event["generation"]
    except (ValueError, KeyError) as e:
        logger.exception("Invalid account message: %s", e)
    else:
        if email is not None:
            if event_type == "delete":
                # Mark the user as retired.
                # Actual cleanup is done by a separate process.
                logger.info("Processing account delete for %r", email)
                database.retire_user(email)
            elif event_type == "reset":
                logger.info("Processing account reset for %r", email)
                update_generation_number(database, email, generation)
            elif event_type == "passwordChange":
                logger.info("Processing password change for %r", email)
                update_generation_number(database, email, generation)
            else:
                logger.warning("Dropping unknown event type %r",
                               event_type)


def update_generation_number(database, email, generation):
    """Update the maximum recorded generation number for the given user.

    When the FxA server sends us an update to the user's generation
    number, we want to update our high-water-mark in the DB in order to
    immediately lock out disconnected devices. However, since we don't
    know the new value of the client state that goes with it, we can't just
    record the new generation number in the DB. If we did, the first
    device that tried to sync with the new generation number would appear
    to have an incorrect client state value, and would be rejected.

    Instead, we take advantage of the fact that it's a timestamp, and write
    it into the DB at one millisecond less than its current value. This
    ensures that we lock out any devices with an older generation number
    while avoiding errors with client state handling.

    This does leave a tiny edge-case where we can fail to lock out older
    devices, if the generation number changes twice in less than a
    millisecond. This is acceptably unlikely in practice, and we'll recover
    as soon as we see an updated generation number as part of a sync.
    """
    user = database.get_user(email)
    if user is not None:
        database.update_user(user, generation - 1)


def main(args=None):
    """Main entry-point for running this script.

    This function parses command-line arguments and passes them on
    to the process_account_events() function.
    """
    usage = "usage: %prog [options] queue_name"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("", "--aws-region",
                      help="aws region in which the queue can be found")
    parser.add_option("", "--queue-wait-time", type="int", default=20,
                      help="Number of seconds to wait for jobs on the queue")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if len(args) != 1:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    queue_name = args[0]

    process_account_events(queue_name, opts.aws_region, opts.queue_wait_time)
    return 0


if __name__ == "__main__":
    util.run_script(main)
terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +""" + +Script to purge user records that have been replaced. + +This script purges any obsolete user records from the database. +Obsolete records are those that have been replaced by a newer record for +the same user. + +Note that this is a purely optional administrative task, since replaced records +are handled internally by the assignment backend. But it should help reduce +overheads, improve performance etc if run regularly. + +""" + +import binascii +import hawkauthlib +import logging +import optparse +import random +import requests +import time +import tokenlib + +import util +from database import Database +from util import format_key_id + + +logger = logging.getLogger("tokenserver.scripts.purge_old_records") + +PATTERN = "{node}/1.5/{uid}" + + +def purge_old_records(secret, grace_period=-1, max_per_loop=10, max_offset=0, + request_timeout=60): + """Purge old records from the database. + + This function queries all of the old user records in the database, deletes + the Tokenserver database record for each of the users, and issues a delete + request to each user's storage node. The result is a gradual pruning of + expired items from each database. + + `max_offset` is used to select a random offset into the list of purgeable + records. With multiple tasks running concurrently, this will provide each + a (likely) different set of records to work on. A cheap, imperfect + randomization. + """ + logger.info("Purging old user records") + try: + database = Database() + # Process batches of items, until we run out. 
+ while True: + offset = random.randint(0, max_offset) + kwds = { + "grace_period": grace_period, + "limit": max_per_loop, + "offset": offset, + } + rows = list(database.get_old_user_records(**kwds)) + logger.info("Fetched %d rows at offset %d", len(rows), offset) + for row in rows: + # Don't attempt to purge data from downed nodes. + # Instead wait for them to either come back up or to be + # completely removed from service. + if row.node is None: + logger.info("Deleting user record for uid %s on %s", + row.uid, row.node) + database.delete_user_record(row.uid) + elif not row.downed: + logger.info("Purging uid %s on %s", row.uid, row.node) + delete_service_data(row, secret, timeout=request_timeout) + database.delete_user_record(row.uid) + if len(rows) < max_per_loop: + break + except Exception: + logger.exception("Error while purging old user records") + return False + else: + logger.info("Finished purging old user records") + return True + + +def delete_service_data(user, secret, timeout=60): + """Send a data-deletion request to the user's service node. + + This is a little bit of hackery to cause the user's service node to + remove any data it still has stored for the user. We simulate a DELETE + request from the user's own account. 
+ """ + token = tokenlib.make_token({ + "uid": user.uid, + "node": user.node, + "fxa_uid": user.email.split("@", 1)[0], + "fxa_kid": format_key_id( + user.keys_changed_at or user.generation, + binascii.unhexlify(user.client_state) + ), + }, secret=secret) + secret = tokenlib.get_derived_secret(token, secret=secret) + endpoint = PATTERN.format(uid=user.uid, node=user.node) + auth = HawkAuth(token, secret) + resp = requests.delete(endpoint, auth=auth, timeout=timeout) + if resp.status_code >= 400 and resp.status_code != 404: + resp.raise_for_status() + + +class HawkAuth(requests.auth.AuthBase): + """Hawk-signing auth helper class.""" + + def __init__(self, token, secret): + self.token = token + self.secret = secret + + def __call__(self, req): + hawkauthlib.sign_request(req, self.token, self.secret) + return req + + +def main(args=None): + """Main entry-point for running this script. + + This function parses command-line arguments and passes them on + to the purge_old_records() function. + """ + usage = "usage: %prog [options] secret" + parser = optparse.OptionParser(usage=usage) + parser.add_option("", "--purge-interval", type="int", default=3600, + help="Interval to sleep between purging runs") + parser.add_option("", "--grace-period", type="int", default=86400, + help="Number of seconds grace to allow on replacement") + parser.add_option("", "--max-per-loop", type="int", default=10, + help="Maximum number of items to fetch in one go") + # N.B., if the number of purgeable rows is <<< max_offset then most + # selects will return zero rows. Choose this value accordingly. 
+ parser.add_option("", "--max-offset", type="int", default=0, + help="Use random offset from 0 to max_offset") + parser.add_option("", "--request-timeout", type="int", default=60, + help="Timeout for service deletion requests") + parser.add_option("", "--oneshot", action="store_true", + help="Do a single purge run and then exit") + parser.add_option("-v", "--verbose", action="count", dest="verbosity", + help="Control verbosity of log messages") + + opts, args = parser.parse_args(args) + if len(args) != 2: + parser.print_usage() + return 1 + + secret = args[1] + + util.configure_script_logging(opts) + + purge_old_records(secret, + grace_period=opts.grace_period, + max_per_loop=opts.max_per_loop, + max_offset=opts.max_offset, + request_timeout=opts.request_timeout) + if not opts.oneshot: + while True: + # Randomize sleep interval +/- thirty percent to desynchronize + # instances of this script running on multiple webheads. + sleep_time = opts.purge_interval + sleep_time += random.randint(-0.3 * sleep_time, 0.3 * sleep_time) + logger.debug("Sleeping for %d seconds", sleep_time) + time.sleep(sleep_time) + purge_old_records(grace_period=opts.grace_period, + max_per_loop=opts.max_per_loop, + max_offset=opts.max_offset, + request_timeout=opts.request_timeout) + return 0 + + +if __name__ == "__main__": + util.run_script(main) diff --git a/tools/tokenserver/remove_node.py b/tools/tokenserver/remove_node.py new file mode 100644 index 00000000..6789063f --- /dev/null +++ b/tools/tokenserver/remove_node.py @@ -0,0 +1,73 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +""" + +Script to remove a node from the system. + +This script nukes any references to the named node - it is removed from +the "nodes" table and any users currently assigned to that node have their +assignments cleared. 
+ +""" + +import logging +import optparse + +import util +from database import Database + +logger = logging.getLogger("tokenserver.scripts.remove_node") + + +def remove_node(node): + """Remove the named node from the system.""" + logger.info("Removing node %s", node) + try: + database = Database() + found = False + try: + database.remove_node(node) + except ValueError: + logger.debug(" not found") + else: + found = True + logger.debug(" removed") + except Exception: + logger.exception("Error while removing node") + return False + else: + if not found: + logger.info("Node %s was not found", node) + else: + logger.info("Finished removing node %s", node) + return True + + +def main(args=None): + """Main entry-point for running this script. + + This function parses command-line arguments and passes them on + to the remove_node() function. + """ + usage = "usage: %prog [options] node_name" + descr = "Remove a node from the tokenserver database" + parser = optparse.OptionParser(usage=usage, description=descr) + parser.add_option("-v", "--verbose", action="count", dest="verbosity", + help="Control verbosity of log messages") + + opts, args = parser.parse_args(args) + if len(args) != 1: + parser.print_usage() + return 1 + + util.configure_script_logging(opts) + + node_name = args[0] + + remove_node(node_name) + return 0 + + +if __name__ == "__main__": + util.run_script(main) diff --git a/tools/tokenserver/requirements.txt b/tools/tokenserver/requirements.txt new file mode 100644 index 00000000..f116d8c0 --- /dev/null +++ b/tools/tokenserver/requirements.txt @@ -0,0 +1,6 @@ +boto +hawkauthlib +mysqlclient +pyramid +sqlalchemy +testfixtures diff --git a/tools/tokenserver/run_tests.py b/tools/tokenserver/run_tests.py new file mode 100644 index 00000000..dcab490d --- /dev/null +++ b/tools/tokenserver/run_tests.py @@ -0,0 +1,25 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. 
If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys +import unittest + +from test_database import TestDatabase +from test_process_account_events import TestProcessAccountEvents +from test_purge_old_records import TestPurgeOldRecords +from test_scripts import TestScripts + +if __name__ == "__main__": + loader = unittest.TestLoader() + test_cases = [TestDatabase, TestPurgeOldRecords, TestProcessAccountEvents, + TestScripts] + + res = 0 + for test_case in test_cases: + suite = loader.loadTestsFromTestCase(test_case) + runner = unittest.TextTestRunner() + if not runner.run(suite).wasSuccessful(): + res = 1 + + sys.exit(res) diff --git a/tools/tokenserver/test_database.py b/tools/tokenserver/test_database.py new file mode 100644 index 00000000..9266f241 --- /dev/null +++ b/tools/tokenserver/test_database.py @@ -0,0 +1,446 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. + +import time +import unittest + +from collections import defaultdict +from database import MAX_GENERATION, Database +from util import get_timestamp + + +class TestDatabase(unittest.TestCase): + def setUp(self): + super(TestDatabase, self).setUp() + self.database = Database() + # Start each test with a blank slate. + cursor = self.database._execute_sql(('DELETE FROM users'), ()) + cursor.close() + + cursor = self.database._execute_sql(('DELETE FROM nodes'), ()) + cursor.close() + + self.database.add_node('https://phx12', 100) + + def tearDown(self): + super(TestDatabase, self).tearDown() + # And clean up at the end, for good measure. 
+ cursor = self.database._execute_sql(('DELETE FROM users'), ()) + cursor.close() + + cursor = self.database._execute_sql(('DELETE FROM nodes'), ()) + cursor.close() + + self.database.close() + + def test_node_allocation(self): + user = self.database.get_user('test1@example.com') + self.assertEquals(user, None) + + user = self.database.allocate_user('test1@example.com') + wanted = 'https://phx12' + self.assertEqual(user['node'], wanted) + + user = self.database.get_user('test1@example.com') + self.assertEqual(user['node'], wanted) + + def test_allocation_to_least_loaded_node(self): + self.database.add_node('https://phx13', 100) + user1 = self.database.allocate_user('test1@mozilla.com') + user2 = self.database.allocate_user('test2@mozilla.com') + self.assertNotEqual(user1['node'], user2['node']) + + def test_allocation_is_not_allowed_to_downed_nodes(self): + self.database.update_node('https://phx12', + downed=True) + with self.assertRaises(Exception): + self.database.allocate_user('test1@mozilla.com') + + def test_allocation_is_not_allowed_to_backoff_nodes(self): + self.database.update_node('https://phx12', + backoff=True) + with self.assertRaises(Exception): + self.database.allocate_user('test1@mozilla.com') + + def test_update_generation_number(self): + user = self.database.allocate_user('test1@example.com') + self.assertEqual(user['generation'], 0) + self.assertEqual(user['client_state'], '') + orig_uid = user['uid'] + orig_node = user['node'] + + # Changing generation should leave other properties unchanged. 
+ self.database.update_user(user, generation=42) + self.assertEqual(user['uid'], orig_uid) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 42) + self.assertEqual(user['client_state'], '') + + user = self.database.get_user('test1@example.com') + self.assertEqual(user['uid'], orig_uid) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 42) + self.assertEqual(user['client_state'], '') + + # It's not possible to move generation number backwards. + self.database.update_user(user, generation=17) + self.assertEqual(user['uid'], orig_uid) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 42) + self.assertEqual(user['client_state'], '') + + user = self.database.get_user('test1@example.com') + self.assertEqual(user['uid'], orig_uid) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 42) + self.assertEqual(user['client_state'], '') + + def test_update_client_state(self): + user = self.database.allocate_user('test1@example.com') + self.assertEqual(user['generation'], 0) + self.assertEqual(user['client_state'], '') + self.assertEqual(set(user['old_client_states']), set(())) + seen_uids = set((user['uid'],)) + orig_node = user['node'] + + # Changing client-state allocates a new userid. 
+ self.database.update_user(user, client_state='aaa') + self.assertTrue(user['uid'] not in seen_uids) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 0) + self.assertEqual(user['client_state'], 'aaa') + self.assertEqual(set(user['old_client_states']), set(('',))) + + user = self.database.get_user('test1@example.com') + self.assertTrue(user['uid'] not in seen_uids) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 0) + self.assertEqual(user['client_state'], 'aaa') + self.assertEqual(set(user['old_client_states']), set(('',))) + + seen_uids.add(user['uid']) + + # It's possible to change client-state and generation at once. + self.database.update_user(user, + client_state='bbb', generation=12) + self.assertTrue(user['uid'] not in seen_uids) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 12) + self.assertEqual(user['client_state'], 'bbb') + self.assertEqual(set(user['old_client_states']), set(('', 'aaa'))) + + user = self.database.get_user('test1@example.com') + self.assertTrue(user['uid'] not in seen_uids) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 12) + self.assertEqual(user['client_state'], 'bbb') + self.assertEqual(set(user['old_client_states']), set(('', 'aaa'))) + + # You can't got back to an old client_state. 
+ orig_uid = user['uid'] + with self.assertRaises(Exception): + self.database.update_user(user, + client_state='aaa') + + user = self.database.get_user('test1@example.com') + self.assertEqual(user['uid'], orig_uid) + self.assertEqual(user['node'], orig_node) + self.assertEqual(user['generation'], 12) + self.assertEqual(user['client_state'], 'bbb') + self.assertEqual(set(user['old_client_states']), set(('', 'aaa'))) + + def test_user_retirement(self): + self.database.allocate_user('test@mozilla.com') + user1 = self.database.get_user('test@mozilla.com') + self.database.retire_user('test@mozilla.com') + user2 = self.database.get_user('test@mozilla.com') + self.assertTrue(user2['generation'] > user1['generation']) + + def test_cleanup_of_old_records(self): + # Create 6 user records for the first user. + # Do a sleep halfway through so we can test use of grace period. + email1 = 'test1@mozilla.com' + user1 = self.database.allocate_user(email1) + self.database.update_user(user1, client_state='a') + self.database.update_user(user1, client_state='b') + self.database.update_user(user1, client_state='c') + break_time = time.time() + time.sleep(0.1) + self.database.update_user(user1, client_state='d') + self.database.update_user(user1, client_state='e') + records = list(self.database.get_user_records(email1)) + self.assertEqual(len(records), 6) + # Create 3 user records for the second user. + email2 = 'test2@mozilla.com' + user2 = self.database.allocate_user(email2) + self.database.update_user(user2, client_state='a') + self.database.update_user(user2, client_state='b') + records = list(self.database.get_user_records(email2)) + self.assertEqual(len(records), 3) + # That should be a total of 7 old records. 
+ old_records = list(self.database.get_old_user_records(0)) + self.assertEqual(len(old_records), 7) + # And with max_offset of 3, the first record should be id 4 + old_records = list(self.database.get_old_user_records(0, + 100, 3)) + # The 'limit' parameter should be respected. + old_records = list(self.database.get_old_user_records(0, 2)) + self.assertEqual(len(old_records), 2) + # The default grace period is too big to pick them up. + old_records = list(self.database.get_old_user_records()) + self.assertEqual(len(old_records), 0) + # The grace period can select a subset of the records. + grace = time.time() - break_time + old_records = list(self.database.get_old_user_records(grace)) + self.assertEqual(len(old_records), 3) + # Old records can be successfully deleted: + for record in old_records: + self.database.delete_user_record(record.uid) + old_records = list(self.database.get_old_user_records(0)) + self.assertEqual(len(old_records), 4) + + def test_node_reassignment_when_records_are_replaced(self): + self.database.allocate_user('test@mozilla.com', + generation=42, + keys_changed_at=12, + client_state='aaa') + user1 = self.database.get_user('test@mozilla.com') + self.database.replace_user_records('test@mozilla.com') + user2 = self.database.get_user('test@mozilla.com') + # They should have got a new uid. + self.assertNotEqual(user2['uid'], user1['uid']) + # But their account metadata should have been preserved. 
+ self.assertEqual(user2['generation'], user1['generation']) + self.assertEqual(user2['keys_changed_at'], user1['keys_changed_at']) + self.assertEqual(user2['client_state'], user1['client_state']) + + def test_node_reassignment_not_done_for_retired_users(self): + self.database.allocate_user('test@mozilla.com', + generation=42, client_state='aaa') + user1 = self.database.get_user('test@mozilla.com') + self.database.retire_user('test@mozilla.com') + user2 = self.database.get_user('test@mozilla.com') + self.assertEqual(user2['uid'], user1['uid']) + self.assertEqual(user2['generation'], MAX_GENERATION) + self.assertEqual(user2['client_state'], user2['client_state']) + + def test_recovery_from_racy_record_creation(self): + timestamp = get_timestamp() + # Simulate race for forcing creation of two rows with same timestamp. + user1 = self.database.allocate_user('test@mozilla.com', + timestamp=timestamp) + user2 = self.database.allocate_user('test@mozilla.com', + timestamp=timestamp) + self.assertNotEqual(user1['uid'], user2['uid']) + # Neither is marked replaced initially. + old_records = list( + self.database.get_old_user_records(0) + ) + self.assertEqual(len(old_records), 0) + # Reading current details will detect the problem and fix it. + self.database.get_user('test@mozilla.com') + old_records = list( + self.database.get_old_user_records(0) + ) + self.assertEqual(len(old_records), 1) + + def test_that_race_recovery_respects_generation_number_monotonicity(self): + timestamp = get_timestamp() + # Simulate race between clients with different generation numbers, + # in which the out-of-date client gets a higher timestamp. + user1 = self.database.allocate_user('test@mozilla.com', + generation=1, + timestamp=timestamp) + user2 = self.database.allocate_user('test@mozilla.com', + generation=2, + timestamp=timestamp - 1) + self.assertNotEqual(user1['uid'], user2['uid']) + # Reading current details should promote the higher-generation one. 
+ user = self.database.get_user('test@mozilla.com') + self.assertEqual(user['generation'], 2) + self.assertEqual(user['uid'], user2['uid']) + # And the other record should get marked as replaced. + old_records = list( + self.database.get_old_user_records(0) + ) + self.assertEqual(len(old_records), 1) + + def test_node_reassignment_and_removal(self): + NODE1 = 'https://phx12' + NODE2 = 'https://phx13' + # note that NODE1 is created by default for all tests. + self.database.add_node(NODE2, 100) + # Assign four users, we should get two on each node. + user1 = self.database.allocate_user('test1@mozilla.com') + user2 = self.database.allocate_user('test2@mozilla.com') + user3 = self.database.allocate_user('test3@mozilla.com') + user4 = self.database.allocate_user('test4@mozilla.com') + node_counts = defaultdict(lambda: 0) + for user in (user1, user2, user3, user4): + node_counts[user['node']] += 1 + self.assertEqual(node_counts[NODE1], 2) + self.assertEqual(node_counts[NODE2], 2) + # Clear the assignments for NODE1, and re-assign. + # The users previously on NODE1 should balance across both nodes, + # giving 1 on NODE1 and 3 on NODE2. + self.database.unassign_node(NODE1) + node_counts = defaultdict(lambda: 0) + for user in (user1, user2, user3, user4): + new_user = self.database.get_user(user['email']) + if user['node'] == NODE2: + self.assertEqual(new_user['node'], NODE2) + node_counts[new_user['node']] += 1 + self.assertEqual(node_counts[NODE1], 1) + self.assertEqual(node_counts[NODE2], 3) + # Remove NODE2. Everyone should wind up on NODE1. + self.database.remove_node(NODE2) + for user in (user1, user2, user3, user4): + new_user = self.database.get_user(user['email']) + self.assertEqual(new_user['node'], NODE1) + # The old users records pointing to NODE2 should have a NULL 'node' + # property since it has been removed from the db. 
+ null_node_count = 0 + for row in self.database.get_old_user_records(0): + if row.node is None: + null_node_count += 1 + else: + self.assertEqual(row.node, NODE1) + self.assertEqual(null_node_count, 3) + + def test_that_race_recovery_respects_generation_after_reassignment(self): + timestamp = get_timestamp() + # Simulate race between clients with different generation numbers, + # in which the out-of-date client gets a higher timestamp. + user1 = self.database.allocate_user('test@mozilla.com', + generation=1, + timestamp=timestamp) + user2 = self.database.allocate_user('test@mozilla.com', + generation=2, + timestamp=timestamp - 1) + self.assertNotEqual(user1['uid'], user2['uid']) + # Force node re-assignment by marking all records as replaced. + self.database.replace_user_records('test@mozilla.com', + timestamp=timestamp + 1) + # The next client to show up should get a new assignment, marked + # with the correct generation number. + user = self.database.get_user('test@mozilla.com') + self.assertEqual(user['generation'], 2) + self.assertNotEqual(user['uid'], user1['uid']) + self.assertNotEqual(user['uid'], user2['uid']) + + def test_that_we_can_allocate_users_to_a_specific_node(self): + node = 'https://phx13' + self.database.add_node(node, 50) + # The new node is not selected by default, because of lower capacity. + user = self.database.allocate_user('test1@mozilla.com') + self.assertNotEqual(user['node'], node) + # But we can force it using keyword argument. + user = self.database.allocate_user('test2@mozilla.com', + node=node) + self.assertEqual(user['node'], node) + + def test_that_we_can_move_users_to_a_specific_node(self): + node = 'https://phx13' + self.database.add_node(node, 50) + # The new node is not selected by default, because of lower capacity. + user = self.database.allocate_user('test@mozilla.com') + self.assertNotEqual(user['node'], node) + # But we can move them there explicitly using keyword argument. 
+ self.database.update_user(user, node=node) + self.assertEqual(user['node'], node) + # Sanity-check by re-reading it from the db. + user = self.database.get_user('test@mozilla.com') + self.assertEqual(user['node'], node) + # Check that it properly respects client-state and generation. + self.database.update_user(user, generation=12) + self.database.update_user(user, client_state='XXX') + self.database.update_user(user, generation=42, + client_state='YYY', node='https://phx12') + self.assertEqual(user['node'], 'https://phx12') + self.assertEqual(user['generation'], 42) + self.assertEqual(user['client_state'], 'YYY') + self.assertEqual(sorted(user['old_client_states']), ['', 'XXX']) + # Sanity-check by re-reading it from the db. + user = self.database.get_user('test@mozilla.com') + self.assertEqual(user['node'], 'https://phx12') + self.assertEqual(user['generation'], 42) + self.assertEqual(user['client_state'], 'YYY') + self.assertEqual(sorted(user['old_client_states']), ['', 'XXX']) + + def test_that_record_cleanup_frees_slots_on_the_node(self): + node = 'https://phx12' + self.database.update_node(node, capacity=10, available=1, + current_load=9) + # We should only be able to allocate one more user to that node. + user = self.database.allocate_user('test1@mozilla.com') + self.assertEqual(user['node'], node) + with self.assertRaises(Exception): + self.database.allocate_user('test2@mozilla.com') + # But when we clean up the user's record, it frees up the slot. 
+ self.database.retire_user('test1@mozilla.com') + self.database.delete_user_record(user['uid']) + user = self.database.allocate_user('test2@mozilla.com') + self.assertEqual(user['node'], node) + + def test_gradual_release_of_node_capacity(self): + node1 = 'https://phx12' + self.database.update_node(node1, capacity=8, available=1, + current_load=4) + node2 = 'https://phx13' + self.database.add_node(node2, capacity=6, + available=1, current_load=4) + # Two allocations should succeed without update, one on each node. + user = self.database.allocate_user('test1@mozilla.com') + self.assertEqual(user['node'], node1) + user = self.database.allocate_user('test2@mozilla.com') + self.assertEqual(user['node'], node2) + # The next allocation attempt will release 10% more capacity, + # which is one more slot for each node. + user = self.database.allocate_user('test3@mozilla.com') + self.assertEqual(user['node'], node1) + user = self.database.allocate_user('test4@mozilla.com') + self.assertEqual(user['node'], node2) + # Now node2 is full, so further allocations all go to node1. + user = self.database.allocate_user('test5@mozilla.com') + self.assertEqual(user['node'], node1) + user = self.database.allocate_user('test6@mozilla.com') + self.assertEqual(user['node'], node1) + # Until it finally reaches capacity. + with self.assertRaises(Exception): + self.database.allocate_user('test7@mozilla.com') + + def test_count_users(self): + user = self.database.allocate_user('test1@example.com') + self.assertEqual(self.database.count_users(), 1) + old_timestamp = get_timestamp() + time.sleep(0.01) + # Adding users increases the count. + user = self.database.allocate_user('rfkelly@mozilla.com') + self.assertEqual(self.database.count_users(), 2) + # Updating a user doesn't change the count. + self.database.update_user(user, client_state='aaa') + self.assertEqual(self.database.count_users(), 2) + # Looking back in time doesn't count newer users. 
+ self.assertEqual(self.database.count_users(old_timestamp), 1) + # Retiring a user decreases the count. + self.database.retire_user('test1@example.com') + self.assertEqual(self.database.count_users(), 1) + + def test_first_seen_at(self): + EMAIL = 'test1@example.com' + user0 = self.database.allocate_user(EMAIL) + user1 = self.database.get_user(EMAIL) + self.assertEqual(user1['uid'], user0['uid']) + self.assertEqual(user1['first_seen_at'], user0['first_seen_at']) + # It should stay consistent if we re-allocate the user's node. + time.sleep(0.1) + self.database.update_user(user1, client_state='aaa') + user2 = self.database.get_user(EMAIL) + self.assertNotEqual(user2['uid'], user0['uid']) + self.assertEqual(user2['first_seen_at'], user0['first_seen_at']) + # Until we purge their old node-assignment records. + self.database.delete_user_record(user0['uid']) + user3 = self.database.get_user(EMAIL) + self.assertEqual(user3['uid'], user2['uid']) + self.assertNotEqual(user3['first_seen_at'], user2['first_seen_at']) diff --git a/tools/tokenserver/test_process_account_events.py b/tools/tokenserver/test_process_account_events.py new file mode 100644 index 00000000..62bdd57d --- /dev/null +++ b/tools/tokenserver/test_process_account_events.py @@ -0,0 +1,244 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +import json +import os +import unittest + +from pyramid import testing +from testfixtures import LogCapture + +from database import Database +from process_account_events import process_account_event + + +PATTERN = "{node}/1.5/{uid}" +EMAIL = "test@example.com" +UID = "test" +ISS = "example.com" + + +def message_body(**kwds): + return json.dumps({ + "Message": json.dumps(kwds) + }) + + +class TestProcessAccountEvents(unittest.TestCase): + + def get_ini(self): + return os.path.join(os.path.dirname(__file__), + 'test_sql.ini') + + def setUp(self): + self.database = Database() + self.database.add_node("https://phx12", 100) + self.logs = LogCapture() + + def tearDown(self): + self.logs.uninstall() + testing.tearDown() + + cursor = self.database._execute_sql('DELETE FROM nodes') + cursor.close() + + cursor = self.database._execute_sql('DELETE FROM users') + cursor.close + + def assertMessageWasLogged(self, msg): + """Check that a metric was logged during the request.""" + for r in self.logs.records: + if msg in r.getMessage(): + break + else: + assert False, "message %r was not logged" % (msg,) + + def clearLogs(self): + del self.logs.records[:] + + def test_delete_user(self): + self.database.allocate_user(EMAIL) + user = self.database.get_user(EMAIL) + self.database.update_user(user, client_state="abcdef") + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 2) + self.assertTrue(records[0]["replaced_at"] is not None) + + process_account_event(message_body( + event="delete", + uid=UID, + iss=ISS, + )) + + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 2) + for row in records: + self.assertTrue(row["replaced_at"] is not None) + + def test_delete_user_by_legacy_uid_format(self): + self.database.allocate_user(EMAIL) + user = self.database.get_user(EMAIL) + self.database.update_user(user, client_state="abcdef") + records = list(self.database.get_user_records(EMAIL)) + 
self.assertEquals(len(records), 2) + self.assertTrue(records[0]["replaced_at"] is not None) + + process_account_event(message_body( + event="delete", + uid=EMAIL, + )) + + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 2) + for row in records: + self.assertTrue(row["replaced_at"] is not None) + + def test_delete_user_who_is_not_in_the_db(self): + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 0) + + process_account_event(message_body( + event="delete", + uid=UID, + iss=ISS + )) + + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 0) + + def test_reset_user(self): + self.database.allocate_user(EMAIL, generation=12) + + process_account_event(message_body( + event="reset", + uid=UID, + iss=ISS, + generation=43, + )) + + user = self.database.get_user(EMAIL) + self.assertEquals(user["generation"], 42) + + def test_reset_user_by_legacy_uid_format(self): + self.database.allocate_user(EMAIL, generation=12) + + process_account_event(message_body( + event="reset", + uid=EMAIL, + generation=43, + )) + + user = self.database.get_user(EMAIL) + self.assertEquals(user["generation"], 42) + + def test_reset_user_who_is_not_in_the_db(self): + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 0) + + process_account_event(message_body( + event="reset", + uid=UID, + iss=ISS, + generation=43, + )) + + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 0) + + def test_password_change(self): + self.database.allocate_user(EMAIL, generation=12) + + process_account_event(message_body( + event="passwordChange", + uid=UID, + iss=ISS, + generation=43, + )) + + user = self.database.get_user(EMAIL) + self.assertEquals(user["generation"], 42) + + def test_password_change_user_not_in_db(self): + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 0) + + 
process_account_event(message_body( + event="passwordChange", + uid=UID, + iss=ISS, + generation=43, + )) + + records = list(self.database.get_user_records(EMAIL)) + self.assertEquals(len(records), 0) + + def test_malformed_events(self): + + # Unknown event type. + process_account_event(message_body( + event="party", + uid=UID, + iss=ISS, + generation=43, + )) + self.assertMessageWasLogged("Dropping unknown event type") + self.clearLogs() + + # Missing event type. + process_account_event(message_body( + uid=UID, + iss=ISS, + generation=43, + )) + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Missing uid. + process_account_event(message_body( + event="delete", + iss=ISS, + )) + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Missing generation for reset events. + process_account_event(message_body( + event="reset", + uid=UID, + iss=ISS, + )) + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Missing generation for passwordChange events. + process_account_event(message_body( + event="passwordChange", + uid=UID, + iss=ISS, + )) + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Missing issuer with nonemail uid + process_account_event(message_body( + event="delete", + uid=UID, + )) + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Non-JSON garbage. + process_account_event("wat") + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Non-JSON garbage in Message field. + process_account_event('{ "Message": "wat" }') + self.assertMessageWasLogged("Invalid account message") + self.clearLogs() + + # Badly-typed JSON value in Message field. 
+        process_account_event('{ "Message": "[1, 2, 3]" }')
+        self.assertMessageWasLogged("Invalid account message")
+        self.clearLogs()
diff --git a/tools/tokenserver/test_purge_old_records.py b/tools/tokenserver/test_purge_old_records.py
new file mode 100644
index 00000000..451fe194
--- /dev/null
+++ b/tools/tokenserver/test_purge_old_records.py
@@ -0,0 +1,136 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this file,
+# You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import hawkauthlib
+import re
+import threading
+import tokenlib
+import unittest
+from wsgiref.simple_server import make_server
+
+from database import Database
+from purge_old_records import purge_old_records
+
+
+class TestPurgeOldRecords(unittest.TestCase):
+    """A testcase for proper functioning of the purge_old_records.py script.
+
+    This is a tricky one, because we have to actually run the script and
+    test that it does the right thing. We also run a mock downstream service
+    so we can test that data-deletion requests go through ok.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        cls.service_requests = []
+        cls.service_node = "http://localhost:8002"
+        cls.service = make_server("localhost", 8002, cls._service_app)
+        target = cls.service.serve_forever
+        cls.service_thread = threading.Thread(target=target)
+        # Note: If the following `start` causes the test thread to hang,
+        # you may need to specify
+        # `[app::pyramid.app] pyramid.worker_class = sync` in the test_*.ini
+        # files
+        cls.service_thread.start()
+        # This silences nuisance on-by-default logging output.
+        cls.service.RequestHandlerClass.log_request = lambda *a: None
+
+    def setUp(self):
+        super(TestPurgeOldRecords, self).setUp()
+
+        # Configure the node-assignment backend to talk to our test service.
+ self.database = Database() + self.database.add_node(self.service_node, 100) + + def tearDown(self): + cursor = self.database._execute_sql('DELETE FROM nodes') + cursor.close() + + cursor = self.database._execute_sql('DELETE FROM users') + cursor.close() + + del self.service_requests[:] + + @classmethod + def tearDownClass(cls): + cls.service.shutdown() + cls.service_thread.join() + + @classmethod + def _service_app(cls, environ, start_response): + cls.service_requests.append(environ) + start_response("200 OK", []) + return "" + + def test_purging_of_old_user_records(self): + # Make some old user records. + email = "test@mozilla.com" + user = self.database.allocate_user(email, client_state="aa", + generation=123) + self.database.update_user(user, client_state="bb", + generation=456, keys_changed_at=450) + self.database.update_user(user, client_state="cc", + generation=789) + user_records = list(self.database.get_user_records(email)) + self.assertEqual(len(user_records), 3) + user = self.database.get_user(email) + self.assertEquals(user["client_state"], "cc") + self.assertEquals(len(user["old_client_states"]), 2) + + # The default grace-period should prevent any cleanup. + node_secret = "SECRET" + self.assertTrue(purge_old_records(node_secret)) + user_records = list(self.database.get_user_records(email)) + self.assertEqual(len(user_records), 3) + self.assertEqual(len(self.service_requests), 0) + + # With no grace period, we should cleanup two old records. + self.assertTrue(purge_old_records(node_secret, grace_period=0)) + user_records = list(self.database.get_user_records(email)) + self.assertEqual(len(user_records), 1) + self.assertEqual(len(self.service_requests), 2) + + # Check that the proper delete requests were made to the service. + expected_kids = ["0000000000450-uw", "0000000000123-qg"] + for i, environ in enumerate(self.service_requests): + # They must be to the correct path. 
+ self.assertEquals(environ["REQUEST_METHOD"], "DELETE") + self.assertTrue(re.match("/1.5/[0-9]+", environ["PATH_INFO"])) + # They must have a correct request signature. + token = hawkauthlib.get_id(environ) + secret = tokenlib.get_derived_secret(token, secret=node_secret) + self.assertTrue(hawkauthlib.check_signature(environ, secret)) + userdata = tokenlib.parse_token(token, secret=node_secret) + self.assertTrue("uid" in userdata) + self.assertTrue("node" in userdata) + self.assertEqual(userdata["fxa_uid"], "test") + self.assertEqual(userdata["fxa_kid"], expected_kids[i]) + + # Check that the user's current state is unaffected + user = self.database.get_user(email) + self.assertEquals(user["client_state"], "cc") + self.assertEquals(len(user["old_client_states"]), 0) + + def test_purging_is_not_done_on_downed_nodes(self): + # Make some old user records. + node_secret = "SECRET" + email = "test@mozilla.com" + user = self.database.allocate_user(email, client_state="aa") + self.database.update_user(user, client_state="bb") + user_records = list(self.database.get_user_records(email)) + self.assertEqual(len(user_records), 2) + + # With the node down, we should not purge any records. + self.database.update_node(self.service_node, downed=1) + self.assertTrue(purge_old_records(node_secret, grace_period=0)) + user_records = list(self.database.get_user_records(email)) + self.assertEqual(len(user_records), 2) + self.assertEqual(len(self.service_requests), 0) + + # With the node back up, we should purge correctly. 
+ self.database.update_node(self.service_node, downed=0) + self.assertTrue(purge_old_records(node_secret, grace_period=0)) + user_records = list(self.database.get_user_records(email)) + self.assertEqual(len(user_records), 1) + self.assertEqual(len(self.service_requests), 1) diff --git a/tools/tokenserver/test_scripts.py b/tools/tokenserver/test_scripts.py new file mode 100644 index 00000000..42f0fbb3 --- /dev/null +++ b/tools/tokenserver/test_scripts.py @@ -0,0 +1,223 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. + +import json +import os +import unittest +import uuid + +from add_node import main as add_node_script +from allocate_user import main as allocate_user_script +from count_users import main as count_users_script +from database import Database +from remove_node import main as remove_node_script +from unassign_node import main as unassign_node_script +from update_node import main as update_node_script +from util import get_timestamp + + +class TestScripts(unittest.TestCase): + NODE_ID = 800 + NODE_URL = 'https://node1' + + def setUp(self): + self.database = Database() + + # Start each test with a blank slate. + cursor = self.database._execute_sql('DELETE FROM users') + cursor.close() + + cursor = self.database._execute_sql('DELETE FROM nodes') + cursor.close() + + # Ensure we have a node with enough capacity to run the tests. + self.database.add_node(self.NODE_URL, 100, id=self.NODE_ID) + + def tearDown(self): + # And clean up at the end, for good measure. 
+ cursor = self.database._execute_sql('DELETE FROM users') + cursor.close() + + cursor = self.database._execute_sql('DELETE FROM nodes') + cursor.close() + + self.database.close() + + def test_add_node(self): + add_node_script( + args=['--current-load', '9', 'test_node', '100'] + ) + res = self.database.get_node('test_node') + # The node should have the expected attributes + self.assertEqual(res.capacity, 100) + self.assertEqual(res.available, 10) + self.assertEqual(res.current_load, 9) + self.assertEqual(res.downed, 0) + self.assertEqual(res.backoff, 0) + self.assertEqual(res.service, self.database.service_id) + + def test_add_node_with_explicit_available(self): + args = ['--current-load', '9', '--available', '5', 'test_node', '100'] + add_node_script(args=args) + res = self.database.get_node('test_node') + # The node should have the expected attributes + self.assertEqual(res.capacity, 100) + self.assertEqual(res.available, 5) + self.assertEqual(res.current_load, 9) + self.assertEqual(res.downed, 0) + self.assertEqual(res.backoff, 0) + self.assertEqual(res.service, self.database.service_id) + + def test_add_downed_node(self): + add_node_script( + args=['--downed', 'test_node', '100'] + ) + res = self.database.get_node('test_node') + # The node should have the expected attributes + self.assertEqual(res.capacity, 100) + self.assertEqual(res.available, 10) + self.assertEqual(res.current_load, 0) + self.assertEqual(res.downed, 1) + self.assertEqual(res.backoff, 0) + self.assertEqual(res.service, self.database.service_id) + + def test_add_backoff_node(self): + add_node_script( + args=['--backoff', 'test_node', '100'] + ) + res = self.database.get_node('test_node') + # The node should have the expected attributes + self.assertEqual(res.capacity, 100) + self.assertEqual(res.available, 10) + self.assertEqual(res.current_load, 0) + self.assertEqual(res.downed, 0) + self.assertEqual(res.backoff, 1) + self.assertEqual(res.service, self.database.service_id) + + def 
test_allocate_user_user_already_exists(self): + email = 'test@test.com' + self.database.allocate_user(email) + node = 'https://node2' + self.database.add_node(node, 100) + allocate_user_script(args=[email, node]) + user = self.database.get_user(email) + # The user should be assigned to the given node + self.assertEqual(user['node'], node) + # Another user should not have been created + count = self.database.count_users() + self.assertEqual(count, 1) + + def test_allocate_user_given_node(self): + email = 'test@test.com' + node = 'https://node2' + self.database.add_node(node, 100) + allocate_user_script(args=[email, node]) + user = self.database.get_user(email) + # A new user should be created and assigned to the given node + self.assertEqual(user['node'], node) + + def test_allocate_user_not_given_node(self): + email = 'test@test.com' + self.database.add_node('https://node2', 100, + current_load=10) + self.database.add_node('https://node3', 100, + current_load=20) + self.database.add_node('https://node4', 100, + current_load=30) + allocate_user_script(args=[email]) + user = self.database.get_user(email) + # The user should be assigned to the least-loaded node + self.assertEqual(user['node'], 'https://node1') + + def test_count_users(self): + self.database.allocate_user('test1@test.com') + self.database.allocate_user('test2@test.com') + self.database.allocate_user('test3@test.com') + + timestamp = get_timestamp() + filename = '/tmp/' + str(uuid.uuid4()) + try: + count_users_script( + args=['--output', filename, '--timestamp', str(timestamp)] + ) + + with open(filename) as f: + info = json.loads(f.readline()) + self.assertEqual(info['total_users'], 3) + self.assertEqual(info['op'], 'sync_count_users') + finally: + os.remove(filename) + + filename = '/tmp/' + str(uuid.uuid4()) + try: + args = ['--output', filename, '--timestamp', + str(timestamp - 10000)] + count_users_script(args=args) + + with open(filename) as f: + info = json.loads(f.readline()) + 
self.assertEqual(info['total_users'], 0) + self.assertEqual(info['op'], 'sync_count_users') + finally: + os.remove(filename) + + def test_remove_node(self): + self.database.add_node('https://node2', 100) + self.database.allocate_user('test1@test.com', + node='https://node2') + self.database.allocate_user('test2@test.com', + node=self.NODE_URL) + self.database.allocate_user('test3@test.com', + node=self.NODE_URL) + + remove_node_script(args=['https://node2']) + + # The node should have been removed from the database + args = ['https://node2'] + self.assertRaises(ValueError, self.database.get_node_id, *args) + # The first user should have been assigned to a new node + user = self.database.get_user('test1@test.com') + self.assertEqual(user['node'], self.NODE_URL) + # The second and third users should still be on the first node + user = self.database.get_user('test2@test.com') + self.assertEqual(user['node'], self.NODE_URL) + user = self.database.get_user('test3@test.com') + self.assertEqual(user['node'], self.NODE_URL) + + def test_unassign_node(self): + self.database.add_node('https://node2', 100) + self.database.allocate_user('test1@test.com', + node='https://node2') + self.database.allocate_user('test2@test.com', + node='https://node2') + self.database.allocate_user('test3@test.com', + node=self.NODE_URL) + + unassign_node_script(args=['https://node2']) + self.database.remove_node('https://node2') + # All of the users should now be assigned to the first node + user = self.database.get_user('test1@test.com') + self.assertEqual(user['node'], self.NODE_URL) + user = self.database.get_user('test2@test.com') + self.assertEqual(user['node'], self.NODE_URL) + user = self.database.get_user('test3@test.com') + self.assertEqual(user['node'], self.NODE_URL) + + def test_update_node(self): + self.database.add_node('https://node2', 100) + update_node_script(args=[ + '--capacity', '150', + '--available', '125', + '--current-load', '25', + '--downed', + '--backoff', + 
'https://node2' + ]) + node = self.database.get_node('https://node2') + # Ensure the node has the expected attributes + self.assertEqual(node['capacity'], 150) + self.assertEqual(node['available'], 125) + self.assertEqual(node['current_load'], 25) + self.assertEqual(node['downed'], 1) + self.assertEqual(node['backoff'], 1) diff --git a/tools/tokenserver/unassign_node.py b/tools/tokenserver/unassign_node.py new file mode 100644 index 00000000..0f5b1a82 --- /dev/null +++ b/tools/tokenserver/unassign_node.py @@ -0,0 +1,72 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +""" + +Script to remove a node from the system. + +This script clears any assignments to the named node. + +""" + +import logging +import optparse + +from database import Database +import util + + +logger = logging.getLogger("tokenserver.scripts.unassign_node") + + +def unassign_node(node): + """Clear any assignments to the named node.""" + logger.info("Unassignment node %s", node) + try: + database = Database() + found = False + try: + database.unassign_node(node) + except ValueError: + logger.debug(" not found") + else: + found = True + logger.debug(" unassigned") + except Exception: + logger.exception("Error while unassigning node") + return False + else: + if not found: + logger.info("Node %s was not found", node) + else: + logger.info("Finished unassigning node %s", node) + return True + + +def main(args=None): + """Main entry-point for running this script. + + This function parses command-line arguments and passes them on + to the unassign_node() function. 
+    """
+    usage = "usage: %prog [options] node_name"
+    descr = "Clear all assignments to node in the tokenserver database"
+    parser = optparse.OptionParser(usage=usage, description=descr)
+    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
+                      help="Control verbosity of log messages")
+
+    opts, args = parser.parse_args(args)
+    if len(args) != 1:
+        parser.print_usage()
+        return 1
+
+    util.configure_script_logging(opts)
+
+    node_name = args[0]
+
+    unassign_node(node_name)
+    return 0
+
+
+if __name__ == "__main__":
+    util.run_script(main)
diff --git a/tools/tokenserver/update_node.py b/tools/tokenserver/update_node.py
new file mode 100644
index 00000000..45a18479
--- /dev/null
+++ b/tools/tokenserver/update_node.py
@@ -0,0 +1,83 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this file,
+# You can obtain one at http://mozilla.org/MPL/2.0/.
+"""
+
+Script to update node status in the db.
+
+"""
+
+import logging
+import optparse
+
+from database import Database
+import util
+
+
+logger = logging.getLogger("tokenserver.scripts.update_node")
+
+
+def update_node(node, **kwds):
+    """Update details of a node."""
+    logger.info("Updating node %s", node)
+    logger.debug("Value: %r", kwds)
+    try:
+        database = Database()
+        database.update_node(node, **kwds)
+    except Exception:
+        logger.exception("Error while updating node")
+        return False
+    else:
+        logger.info("Finished updating node %s", node)
+        return True
+
+
+def main(args=None):
+    """Main entry-point for running this script.
+
+    This function parses command-line arguments and passes them on
+    to the update_node() function.
+ """ + usage = "usage: %prog [options] node_name" + descr = "Update node details in the tokenserver database" + parser = optparse.OptionParser(usage=usage, description=descr) + parser.add_option("", "--capacity", type="int", + help="How many user slots the node has overall") + parser.add_option("", "--available", type="int", + help="How many user slots the node has available") + parser.add_option("", "--current-load", type="int", + help="How many user slots the node has occupied") + parser.add_option("", "--downed", action="store_true", + help="Mark the node as down in the db") + parser.add_option("", "--backoff", action="store_true", + help="Mark the node as backed-off in the db") + parser.add_option("-v", "--verbose", action="count", dest="verbosity", + help="Control verbosity of log messages") + + opts, args = parser.parse_args(args) + if len(args) != 1: + parser.print_usage() + return 1 + + util.configure_script_logging(opts) + + node_name = args[0] + + kwds = {} + if opts.capacity is not None: + kwds["capacity"] = opts.capacity + if opts.available is not None: + kwds["available"] = opts.available + if opts.current_load is not None: + kwds["current_load"] = opts.current_load + if opts.backoff is not None: + kwds["backoff"] = opts.backoff + if opts.downed is not None: + kwds["downed"] = opts.downed + + update_node(node_name, **kwds) + return 0 + + +if __name__ == "__main__": + util.run_script(main) diff --git a/tools/tokenserver/util.py b/tools/tokenserver/util.py new file mode 100644 index 00000000..2810da51 --- /dev/null +++ b/tools/tokenserver/util.py @@ -0,0 +1,58 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +""" + +Admin/managment scripts for TokenServer. 
+ +""" + +import sys +import time +import logging + +from browserid.utils import encode_bytes as encode_bytes_b64 + + +def run_script(main): + """Simple wrapper for running scripts in __main__ section.""" + try: + exitcode = main() + except KeyboardInterrupt: + exitcode = 1 + sys.exit(exitcode) + + +def configure_script_logging(opts=None): + """Configure stdlib logging to produce output from the script. + + This basically configures logging to send messages to stderr, with + formatting that's more for human readability than machine parsing. + It also takes care of the --verbosity command-line option. + """ + if not opts or not opts.verbosity: + loglevel = logging.WARNING + elif opts.verbosity == 1: + loglevel = logging.INFO + else: + loglevel = logging.DEBUG + + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + handler.setLevel(loglevel) + logger = logging.getLogger("") + logger.addHandler(handler) + logger.setLevel(loglevel) + + +def format_key_id(keys_changed_at, key_hash): + """Format an FxA key ID from a timestamp and key hash.""" + return "{:013d}-{}".format( + keys_changed_at, + encode_bytes_b64(key_hash), + ) + + +def get_timestamp(): + """Get current timestamp in milliseconds.""" + return int(time.time() * 1000)