feat: add Tokenserver admin scripts (#1168)

Closes #1086
This commit is contained in:
Ethan Donowitz 2021-11-18 09:37:02 -05:00 committed by GitHub
parent 0996cb154f
commit 0ac30958de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 2627 additions and 1 deletions

View File

@ -38,7 +38,7 @@ commands:
command: |
cargo fmt -- --check
# https://github.com/bodil/sized-chunks/issues/11
cargo audit --ignore RUSTSEC-2020-0041 --ignore RUSTSEC-2021-0078 --ignore RUSTSEC-2021-0079 --ignore RUSTSEC-2020-0159 --ignore RUSTSEC-2020-0071
cargo audit --ignore RUSTSEC-2020-0041 --ignore RUSTSEC-2021-0078 --ignore RUSTSEC-2021-0079 --ignore RUSTSEC-2020-0159 --ignore RUSTSEC-2020-0071 --ignore RUSTSEC-2021-0124
python-check:
steps:
- run:
@ -46,6 +46,7 @@ commands:
command: |
flake8 src/tokenserver/verify.py
flake8 tools/integration_tests
flake8 tools/tokenserver
rust-clippy:
steps:
- run:
@ -113,6 +114,16 @@ commands:
environment:
SYNCSTORAGE_RS_IMAGE: app:build
run-tokenserver-scripts-tests:
steps:
- run:
name: Tokenserver scripts tests
command: >
pip3 install -r tools/tokenserver/requirements.txt &&
python3 tools/tokenserver/run_tests.py
environment:
SYNCSTORAGE_RS_IMAGE: app:build
run-e2e-spanner-tests:
steps:
- run:
@ -211,6 +222,7 @@ jobs:
- write-version
- cargo-build
- run-tests
- run-tokenserver-scripts-tests
#- save-sccache-cache
- run:
name: Build Docker image

View File

View File

@ -0,0 +1,80 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to add a new node to the system.
"""
import logging
import optparse
from database import Database, SERVICE_NAME
import util
logger = logging.getLogger("tokenserver.scripts.add_node")
def add_node(node, capacity, **kwds):
"""Add the specific node to the system."""
logger.info("Adding node %s to service %s", node, SERVICE_NAME)
try:
database = Database()
database.add_node(node, capacity, **kwds)
except Exception:
logger.exception("Error while adding node")
return False
else:
logger.info("Finished adding node %s", node)
return True
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the add_node() function.
"""
usage = "usage: %prog [options] node_name capacity"
descr = "Add a new node to the tokenserver database"
parser = optparse.OptionParser(usage=usage, description=descr)
parser.add_option("", "--available", type="int",
help="How many user slots the node has available")
parser.add_option("", "--current-load", type="int",
help="How many user slots the node has occupied")
parser.add_option("", "--downed", action="store_true",
help="Mark the node as down in the db")
parser.add_option("", "--backoff", action="store_true",
help="Mark the node as backed-off in the db")
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 2:
parser.print_usage()
return 1
util.configure_script_logging(opts)
node_name = args[0]
capacity = int(args[1])
kwds = {}
if opts.available is not None:
kwds["available"] = opts.available
if opts.current_load is not None:
kwds["current_load"] = opts.current_load
if opts.backoff is not None:
kwds["backoff"] = opts.backoff
if opts.downed is not None:
kwds["downed"] = opts.downed
add_node(node_name, capacity, **kwds)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,73 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to allocate a specific user to a node.
This script allocates the specified user to a node. A particular node
may be specified, or the best available node used by default.
The allocated node is printed to stdout.
"""
import logging
import optparse
from database import Database
import util
logger = logging.getLogger("tokenserver.scripts.allocate_user")
def allocate_user(email, node=None):
logger.info("Allocating node for user %s", email)
try:
database = Database()
user = database.get_user(email)
if user is None:
user = database.allocate_user(email, node=node)
else:
database.update_user(user, node=node)
except Exception:
logger.exception("Error while updating node")
return False
else:
logger.info("Finished updating node %s", node)
return True
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the allocate_user() function.
"""
usage = "usage: %prog [options] email [node_name]"
descr = "Allocate a user to a node. You may specify a particular node, "\
"or omit to use the best available node."
parser = optparse.OptionParser(usage=usage, description=descr)
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if not 1 <= len(args) <= 2:
parser.print_usage()
return 1
util.configure_script_logging(opts)
email = args[0]
if len(args) == 1:
node_name = None
else:
node_name = args[1]
allocate_user(email, node_name)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,98 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to emit total-user-count metrics for exec dashboard.
"""
import json
import logging
import optparse
import os
import socket
import sys
import time
from datetime import datetime, timedelta, tzinfo
from database import Database
import util
logger = logging.getLogger("tokenserver.scripts.count_users")
ZERO = timedelta(0)
class UTC(tzinfo):
def utcoffset(self, dt):
return ZERO
def tzname(self, dt):
return "UTC"
def dst(self, dt):
return ZERO
utc = UTC()
def count_users(outfile, timestamp=None):
if timestamp is None:
ts = time.gmtime()
midnight = (ts[0], ts[1], ts[2], 0, 0, 0, ts[6], ts[7], ts[8])
timestamp = int(time.mktime(midnight)) * 1000
database = Database()
logger.debug("Counting users created before %i", timestamp)
count = database.count_users(timestamp)
logger.debug("Found %d users", count)
# Output has heka-filter-compatible JSON object.
ts_sec = timestamp / 1000
output = {
"hostname": socket.gethostname(),
"pid": os.getpid(),
"op": "sync_count_users",
"total_users": count,
"time": datetime.fromtimestamp(ts_sec, utc).isoformat(),
"v": 0
}
json.dump(output, outfile)
outfile.write("\n")
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the add_node() function.
"""
usage = "usage: %prog [options]"
descr = "Count total users in the tokenserver database"
parser = optparse.OptionParser(usage=usage, description=descr)
parser.add_option("-t", "--timestamp", type="int",
help="Max creation timestamp; default previous midnight")
parser.add_option("-o", "--output",
help="Output file; default stderr")
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 0:
parser.print_usage()
return 1
util.configure_script_logging(opts)
if opts.output in (None, "-"):
count_users(sys.stdout, opts.timestamp)
else:
with open(opts.output, "a") as outfile:
count_users(outfile, opts.timestamp)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,641 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import math
import os
from sqlalchemy import create_engine
from sqlalchemy.sql import text as sqltext
from util import get_timestamp
# The maximum possible generation number.
# Used as a tombstone to mark users that have been "retired" from the db.
MAX_GENERATION = 9223372036854775807
NODE_FIELDS = ("capacity", "available", "current_load", "downed", "backoff")
_GET_USER_RECORDS = sqltext("""\
select
uid, nodes.node, generation, keys_changed_at, client_state, created_at,
replaced_at
from
users left outer join nodes on users.nodeid = nodes.id
where
email = :email and users.service = :service
order by
created_at desc, uid desc
limit
20
""")
_CREATE_USER_RECORD = sqltext("""\
insert into
users
(service, email, nodeid, generation, keys_changed_at, client_state,
created_at, replaced_at)
values
(:service, :email, :nodeid, :generation, :keys_changed_at,
:client_state, :timestamp, NULL)
""")
# The `where` clause on this statement is designed as an extra layer of
# protection, to ensure that concurrent updates don't accidentally move
# timestamp fields backwards in time. The handling of `keys_changed_at`
# is additionally weird because we want to treat the default `NULL` value
# as zero.
_UPDATE_USER_RECORD_IN_PLACE = sqltext("""\
update
users
set
generation = COALESCE(:generation, generation),
keys_changed_at = COALESCE(:keys_changed_at, keys_changed_at)
where
service = :service and email = :email and
generation <= COALESCE(:generation, generation) and
COALESCE(keys_changed_at, 0) <=
COALESCE(:keys_changed_at, keys_changed_at, 0) and
replaced_at is null
""")
_REPLACE_USER_RECORDS = sqltext("""\
update
users
set
replaced_at = :timestamp
where
service = :service and email = :email
and replaced_at is null and created_at < :timestamp
""")
# Mark all records for the user as replaced,
# and set a large generation number to block future logins.
_RETIRE_USER_RECORDS = sqltext("""\
update
users
set
replaced_at = :timestamp,
generation = :generation
where
email = :email
and replaced_at is null
""")
_GET_OLD_USER_RECORDS_FOR_SERVICE = sqltext("""\
select
uid, email, generation, keys_changed_at, client_state,
nodes.node, nodes.downed, created_at, replaced_at
from
users left outer join nodes on users.nodeid = nodes.id
where
users.service = :service
and
replaced_at is not null and replaced_at < :timestamp
order by
replaced_at desc, uid desc
limit
:limit
offset
:offset
""")
_GET_ALL_USER_RECORDS_FOR_SERVICE = sqltext("""\
select
uid, nodes.node, created_at, replaced_at
from
users left outer join nodes on users.nodeid = nodes.id
where
email = :email and users.service = :service
order by
created_at asc, uid desc
""")
_REPLACE_USER_RECORD = sqltext("""\
update
users
set
replaced_at = :timestamp
where
service = :service
and
uid = :uid
""")
_DELETE_USER_RECORD = sqltext("""\
delete from
users
where
service = :service
and
uid = :uid
""")
_FREE_SLOT_ON_NODE = sqltext("""\
update
nodes
set
available = available + 1, current_load = current_load - 1
where
id = (SELECT nodeid FROM users WHERE service=:service AND uid=:uid)
""")
_COUNT_USER_RECORDS = sqltext("""\
select
count(email)
from
users
where
replaced_at is null
and created_at <= :timestamp
""")
_GET_BEST_NODE = sqltext("""\
select
id, node
from
nodes
where
service = :service
and available > 0
and capacity > current_load
and downed = 0
and backoff = 0
order by
log(current_load) / log(capacity)
limit 1
""")
_RELEASE_NODE_CAPACITY = sqltext("""\
update
nodes
set
available = least(capacity * :capacity_release_rate,
capacity - current_load)
where
service = :service
and available <= 0
and capacity > current_load
and downed = 0
""")
_ADD_USER_TO_NODE = sqltext("""\
update
nodes
set
current_load = current_load + 1,
available = greatest(available - 1, 0)
where
service = :service
and node = :node
""")
_GET_SERVICE_ID = sqltext("""\
select
id
from
services
where
service = :service
""")
_GET_NODE = sqltext("""\
select
*
from
nodes
where
service = :service
and node = :node
""")
SERVICE_NAME = 'sync-1.5'
class Database:
def __init__(self):
engine = create_engine(os.environ['SYNC_TOKENSERVER__DATABASE_URL'])
self.database = engine. \
execution_options(isolation_level="AUTOCOMMIT"). \
connect()
self.service_id = self._get_service_id(SERVICE_NAME)
self.capacity_release_rate = os.environ. \
get("NODE_CAPACITY_RELEASE_RATE", 0.1)
def _execute_sql(self, *args, **kwds):
return self.database.execute(*args, **kwds)
def close(self):
self.database.close()
def get_user(self, email):
params = {'service': self.service_id, 'email': email}
res = self._execute_sql(_GET_USER_RECORDS, **params)
try:
# The query fetches rows ordered by created_at, but we want
# to ensure that they're ordered by (generation, created_at).
# This is almost always true, except for strange race conditions
# during row creation. Sorting them is an easy way to enforce
# this without bloating the db index.
rows = res.fetchall()
rows.sort(key=lambda r: (r.generation, r.created_at), reverse=True)
if not rows:
return None
# The first row is the most up-to-date user record.
# The rest give previously-seen client-state values.
cur_row = rows[0]
old_rows = rows[1:]
user = {
'email': email,
'uid': cur_row.uid,
'node': cur_row.node,
'generation': cur_row.generation,
'keys_changed_at': cur_row.keys_changed_at or 0,
'client_state': cur_row.client_state,
'old_client_states': {},
'first_seen_at': cur_row.created_at,
}
# If the current row is marked as replaced or is missing a node,
# and they haven't been retired, then assign them a new node.
if cur_row.replaced_at is not None or cur_row.node is None:
if cur_row.generation < MAX_GENERATION:
user = self.allocate_user(email,
cur_row.generation,
cur_row.client_state,
cur_row.keys_changed_at)
for old_row in old_rows:
# Collect any previously-seen client-state values.
if old_row.client_state != user['client_state']:
user['old_client_states'][old_row.client_state] = True
# Make sure each old row is marked as replaced.
# They might not be, due to races in row creation.
if old_row.replaced_at is None:
timestamp = cur_row.created_at
self.replace_user_record(old_row.uid, timestamp)
# Track backwards to the oldest timestamp at which we saw them.
user['first_seen_at'] = old_row.created_at
return user
finally:
res.close()
def allocate_user(self, email, generation=0, client_state='',
keys_changed_at=0, node=None, timestamp=None):
if timestamp is None:
timestamp = get_timestamp()
if node is None:
nodeid, node = self.get_best_node()
else:
nodeid = self.get_node_id(node)
params = {
'service': self.service_id,
'email': email,
'nodeid': nodeid,
'generation': generation,
'keys_changed_at': keys_changed_at,
'client_state': client_state,
'timestamp': timestamp
}
res = self._execute_sql(_CREATE_USER_RECORD, **params)
return {
'email': email,
'uid': res.lastrowid,
'node': node,
'generation': generation,
'keys_changed_at': keys_changed_at,
'client_state': client_state,
'old_client_states': {},
'first_seen_at': timestamp,
}
def update_user(self, user, generation=None, client_state=None,
keys_changed_at=None, node=None):
if client_state is None and node is None:
# No need for a node-reassignment, just update the row in place.
# Note that if we're changing keys_changed_at without changing
# client_state, it's because we're seeing an existing value of
# keys_changed_at for the first time.
params = {
'service': self.service_id,
'email': user['email'],
'generation': generation,
'keys_changed_at': keys_changed_at
}
res = self._execute_sql(_UPDATE_USER_RECORD_IN_PLACE, **params)
res.close()
user['generation'] = max([x
for x
in [generation, user['generation']]
if x is not None])
user['keys_changed_at'] = max([x
for x
in [keys_changed_at,
user['keys_changed_at']]
if x is not None])
else:
# Reject previously-seen client-state strings.
if client_state is None:
client_state = user['client_state']
else:
if client_state == user['client_state']:
raise Exception('previously seen client-state string')
if client_state in user['old_client_states']:
raise Exception('previously seen client-state string')
# Need to create a new record for new user state.
# If the node is not explicitly changing, try to keep them on the
# same node, but if e.g. it no longer exists them allocate them to
# a new one.
if node is not None:
nodeid = self.get_node_id(node)
user['node'] = node
else:
try:
nodeid = self.get_node_id(user['node'])
except ValueError:
nodeid, node = self.get_best_node()
user['node'] = node
if generation is not None:
generation = max(user['generation'], generation)
else:
generation = user['generation']
if keys_changed_at is not None:
keys_changed_at = max(user['keys_changed_at'], keys_changed_at)
else:
keys_changed_at = user['keys_changed_at']
now = get_timestamp()
params = {
'service': self.service_id, 'email': user['email'],
'nodeid': nodeid, 'generation': generation,
'keys_changed_at': keys_changed_at,
'client_state': client_state, 'timestamp': now,
}
res = self._execute_sql(_CREATE_USER_RECORD, **params)
res.close()
user['uid'] = res.lastrowid
user['generation'] = generation
user['keys_changed_at'] = keys_changed_at
user['old_client_states'][user['client_state']] = True
user['client_state'] = client_state
# mark old records as having been replaced.
# if we crash here, they are unmarked and we may fail to
# garbage collect them for a while, but the active state
# will be undamaged.
self.replace_user_records(user['email'], now)
def retire_user(self, email):
now = get_timestamp()
params = {
'email': email, 'timestamp': now, 'generation': MAX_GENERATION
}
# Pass through explicit engine to help with sharded implementation,
# since we can't shard by service name here.
res = self._execute_sql(_RETIRE_USER_RECORDS, **params)
res.close()
def count_users(self, timestamp=None):
if timestamp is None:
timestamp = get_timestamp()
res = self._execute_sql(_COUNT_USER_RECORDS, timestamp=timestamp)
row = res.fetchone()
res.close()
return row[0]
#
# Methods for low-level user record management.
#
def get_user_records(self, email):
"""Get all the user's records, including the old ones."""
params = {'service': self.service_id, 'email': email}
res = self._execute_sql(_GET_ALL_USER_RECORDS_FOR_SERVICE, **params)
try:
for row in res:
yield row
finally:
res.close()
def get_old_user_records(self, grace_period=-1, limit=100,
offset=0):
"""Get user records that were replaced outside the grace period."""
if grace_period < 0:
grace_period = 60 * 60 * 24 * 7 # one week, in seconds
grace_period = int(grace_period * 1000) # convert seconds -> millis
params = {
"service": self.service_id,
"timestamp": get_timestamp() - grace_period,
"limit": limit,
"offset": offset
}
res = self._execute_sql(_GET_OLD_USER_RECORDS_FOR_SERVICE, **params)
try:
for row in res:
yield row
finally:
res.close()
def replace_user_records(self, email, timestamp=None):
"""Mark all existing records for a user as replaced."""
if timestamp is None:
timestamp = get_timestamp()
params = {
'service': self.service_id, 'email': email, 'timestamp': timestamp
}
res = self._execute_sql(_REPLACE_USER_RECORDS, **params)
res.close()
def replace_user_record(self, uid, timestamp=None):
"""Mark an existing service record as replaced."""
if timestamp is None:
timestamp = get_timestamp()
params = {
'service': self.service_id, 'uid': uid, 'timestamp': timestamp
}
res = self._execute_sql(_REPLACE_USER_RECORD, **params)
res.close()
def delete_user_record(self, uid):
"""Delete the user record with the given uid."""
params = {'service': self.service_id, 'uid': uid}
res = self._execute_sql(_FREE_SLOT_ON_NODE, **params)
res.close()
res = self._execute_sql(_DELETE_USER_RECORD, **params)
res.close()
#
# Nodes management
#
def _get_service_id(self, service):
res = self._execute_sql(_GET_SERVICE_ID, service=service)
row = res.fetchone()
res.close()
if row is None:
raise Exception('unknown service: ' + service)
return row.id
def add_service(self, service_name, pattern, **kwds):
"""Add definition for a new service."""
res = self._execute_sql(sqltext("""
insert into services (service, pattern)
values (:servicename, :pattern)
"""), servicename=service_name, pattern=pattern, **kwds)
res.close()
return res.lastrowid
def add_node(self, node, capacity, **kwds):
"""Add definition for a new node."""
available = kwds.get('available')
# We release only a fraction of the node's capacity to start.
if available is None:
available = math.ceil(capacity * self.capacity_release_rate)
cols = ["service", "node", "available", "capacity",
"current_load", "downed", "backoff"]
args = [":" + v for v in cols]
# Handle test cases that require nodeid to be 800
if "nodeid" in kwds:
cols.append("id")
args.append(":nodeid")
query = """
insert into nodes ({cols})
values ({args})
""".format(cols=", ".join(cols), args=", ".join(args))
res = self._execute_sql(
sqltext(query),
nodeid=kwds.get('nodeid'),
service=self.service_id,
node=node,
capacity=capacity,
available=available,
current_load=kwds.get('current_load', 0),
downed=kwds.get('downed', 0),
backoff=kwds.get('backoff', 0),
)
res.close()
def update_node(self, node, **kwds):
"""Updates node fields in the db."""
values = {}
cols = NODE_FIELDS & kwds.keys()
for col in NODE_FIELDS:
try:
values[col] = kwds.pop(col)
except KeyError:
pass
args = [v + " = :" + v for v in cols]
query = """
update nodes
set """
query += ", ".join(args)
query += """
where service = :service and node = :node
"""
values['service'] = self.service_id
values['node'] = node
if kwds:
raise ValueError("unknown fields: " + str(kwds.keys()))
con = self._execute_sql(sqltext(query), **values)
con.close()
def get_node_id(self, node):
"""Get numeric id for a node."""
res = self._execute_sql(
sqltext("""
select id from nodes
where service=:service and node=:node
"""),
service=self.service_id, node=node
)
row = res.fetchone()
res.close()
if row is None:
raise ValueError("unknown node: " + node)
return row[0]
def remove_node(self, node, timestamp=None):
"""Remove definition for a node."""
nodeid = self.get_node_id(node)
res = self._execute_sql(sqltext(
"""
delete from nodes where id=:nodeid
"""),
nodeid=nodeid
)
res.close()
self.unassign_node(node, timestamp, nodeid=nodeid)
def unassign_node(self, node, timestamp=None, nodeid=None):
"""Clear any assignments to a node."""
if timestamp is None:
timestamp = get_timestamp()
if nodeid is None:
nodeid = self.get_node_id(node)
res = self._execute_sql(
sqltext("""
update users
set replaced_at=:timestamp
where nodeid=:nodeid
"""),
nodeid=nodeid, timestamp=timestamp
)
res.close()
def get_best_node(self):
"""Returns the 'least loaded' node currently available, increments the
active count on that node, and decrements the slots currently available
"""
# We may have to re-try the query if we need to release more capacity.
# This loop allows a maximum of five retries before bailing out.
for _ in range(5):
res = self._execute_sql(_GET_BEST_NODE, service=self.service_id)
row = res.fetchone()
res.close()
if row is None:
# Try to release additional capacity from any nodes
# that are not fully occupied.
res = self._execute_sql(
_RELEASE_NODE_CAPACITY,
capacity_release_rate=self.capacity_release_rate,
service=self.service_id
)
res.close()
if res.rowcount == 0:
break
else:
break
# Did we succeed in finding a node?
if row is None:
raise Exception('unable to get a node')
nodeid = row.id
node = str(row.node)
# Update the node to reflect the new assignment.
# This is a little racy with concurrent assignments, but no big
# deal.
con = self._execute_sql(_ADD_USER_TO_NODE,
service=self.service_id,
node=node)
con.close()
return nodeid, node
def get_node(self, node):
res = self._execute_sql(_GET_NODE, service=self.service_id, node=node)
row = res.fetchone()
res.close()
if row is None:
raise Exception('unknown node: ' + node)
return row

View File

@ -0,0 +1,179 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to process account-related events from an SQS queue.
This script polls an SQS queue for events indicating activity on an upstream
account, as documented here:
https://github.com/mozilla/fxa-auth-server/blob/master/docs/service_notifications.md
The following event types are currently supported:
* "delete": the account was deleted; we mark their records as retired
so they'll be cleaned up by our garbage-collection process.
* "reset": the account password was reset; we update our copy of their
generation number to disconnect other devices.
* "passwordChange": the account password was changed; we update our copy
of their generation number to disconnect other devices.
Note that this is a purely optional administrative task, highly specific to
Mozilla's internal Firefox-Accounts-supported deployment.
"""
import json
import logging
import optparse
import boto
import boto.ec2
import boto.sqs
import boto.sqs.message
import boto.utils
import util
from database import Database
logger = logging.getLogger("tokenserver.scripts.process_account_deletions")
def process_account_events(queue_name, aws_region=None, queue_wait_time=20):
"""Process account events from an SQS queue.
This function polls the specified SQS queue for account-realted events,
processing each as it is found. It polls indefinitely and does not return;
to interrupt execution you'll need to e.g. SIGINT the process.
"""
logger.info("Processing account events from %s", queue_name)
try:
# Connect to the SQS queue.
# If no region is given, infer it from the instance metadata.
if aws_region is None:
logger.debug("Finding default region from instance metadata")
aws_info = boto.utils.get_instance_metadata()
aws_region = aws_info["placement"]["availability-zone"][:-1]
logger.debug("Connecting to queue %r in %r", queue_name, aws_region)
conn = boto.sqs.connect_to_region(aws_region)
queue = conn.get_queue(queue_name)
# We must force boto not to b64-decode the message contents, ugh.
queue.set_message_class(boto.sqs.message.RawMessage)
# Poll for messages indefinitely.
while True:
msg = queue.read(wait_time_seconds=queue_wait_time)
if msg is None:
continue
process_account_event(msg.get_body())
# This intentionally deletes the event even if it was some
# unrecognized type. Not point leaving a backlog.
queue.delete_message(msg)
except Exception:
logger.exception("Error while processing account events")
raise
def process_account_event(body):
"""Parse and process a single account event."""
database = Database()
# Try very hard not to error out if there's junk in the queue.
email = None
event_type = None
generation = None
try:
body = json.loads(body)
event = json.loads(body['Message'])
event_type = event["event"]
uid = event["uid"]
# Older versions of the fxa-auth-server would send an email-like
# identifier the "uid" field, but that doesn't make sense for any
# relier other than tokenserver. Newer versions send just the raw uid
# in the "uid" field, and include the domain in a separate "iss" field.
if "iss" in event:
email = "%s@%s" % (uid, event["iss"])
else:
if "@" not in uid:
raise ValueError("uid field does not contain issuer info")
email = uid
if event_type in ("reset", "passwordChange",):
generation = event["generation"]
except (ValueError, KeyError) as e:
logger.exception("Invalid account message: %s", e)
else:
if email is not None:
if event_type == "delete":
# Mark the user as retired.
# Actual cleanup is done by a separate process.
logger.info("Processing account delete for %r", email)
database.retire_user(email)
elif event_type == "reset":
logger.info("Processing account reset for %r", email)
update_generation_number(database, email, generation)
elif event_type == "passwordChange":
logger.info("Processing password change for %r", email)
update_generation_number(database, email, generation)
else:
logger.warning("Dropping unknown event type %r",
event_type)
def update_generation_number(database, email, generation):
"""Update the maximum recorded generation number for the given user.
When the FxA server sends us an update to the user's generation
number, we want to update our high-water-mark in the DB in order to
immediately lock out disconnected devices. However, since we don't
know the new value of the client state that goes with it, we can't just
record the new generation number in the DB. If we did, the first
device that tried to sync with the new generation number would appear
to have an incorrect client state value, and would be rejected.
Instead, we take advantage of the fact that it's a timestamp, and write
it into the DB at one millisecond less than its current value. This
ensures that we lock out any devices with an older generation number
while avoiding errors with client state handling.
This does leave a tiny edge-case where we can fail to lock out older
devices, if the generation number changes twice in less than a
millisecond. This is acceptably unlikely in practice, and we'll recover
as soon as we see an updated generation number as part of a sync.
"""
user = database.get_user(email)
if user is not None:
database.update_user(user, generation - 1)
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the process_account_events() function.
"""
usage = "usage: %prog [options] queue_name"
parser = optparse.OptionParser(usage=usage)
parser.add_option("", "--aws-region",
help="aws region in which the queue can be found")
parser.add_option("", "--queue-wait-time", type="int", default=20,
help="Number of seconds to wait for jobs on the queue")
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 1:
parser.print_usage()
return 1
util.configure_script_logging(opts)
queue_name = args[0]
process_account_events(queue_name, opts.aws_region, opts.queue_wait_time)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,177 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to purge user records that have been replaced.
This script purges any obsolete user records from the database.
Obsolete records are those that have been replaced by a newer record for
the same user.
Note that this is a purely optional administrative task, since replaced records
are handled internally by the assignment backend. But it should help reduce
overheads, improve performance etc if run regularly.
"""
import binascii
import hawkauthlib
import logging
import optparse
import random
import requests
import time
import tokenlib
import util
from database import Database
from util import format_key_id
logger = logging.getLogger("tokenserver.scripts.purge_old_records")
PATTERN = "{node}/1.5/{uid}"
def purge_old_records(secret, grace_period=-1, max_per_loop=10, max_offset=0,
request_timeout=60):
"""Purge old records from the database.
This function queries all of the old user records in the database, deletes
the Tokenserver database record for each of the users, and issues a delete
request to each user's storage node. The result is a gradual pruning of
expired items from each database.
`max_offset` is used to select a random offset into the list of purgeable
records. With multiple tasks running concurrently, this will provide each
a (likely) different set of records to work on. A cheap, imperfect
randomization.
"""
logger.info("Purging old user records")
try:
database = Database()
# Process batches of <max_per_loop> items, until we run out.
while True:
offset = random.randint(0, max_offset)
kwds = {
"grace_period": grace_period,
"limit": max_per_loop,
"offset": offset,
}
rows = list(database.get_old_user_records(**kwds))
logger.info("Fetched %d rows at offset %d", len(rows), offset)
for row in rows:
# Don't attempt to purge data from downed nodes.
# Instead wait for them to either come back up or to be
# completely removed from service.
if row.node is None:
logger.info("Deleting user record for uid %s on %s",
row.uid, row.node)
database.delete_user_record(row.uid)
elif not row.downed:
logger.info("Purging uid %s on %s", row.uid, row.node)
delete_service_data(row, secret, timeout=request_timeout)
database.delete_user_record(row.uid)
if len(rows) < max_per_loop:
break
except Exception:
logger.exception("Error while purging old user records")
return False
else:
logger.info("Finished purging old user records")
return True
def delete_service_data(user, secret, timeout=60):
"""Send a data-deletion request to the user's service node.
This is a little bit of hackery to cause the user's service node to
remove any data it still has stored for the user. We simulate a DELETE
request from the user's own account.
"""
token = tokenlib.make_token({
"uid": user.uid,
"node": user.node,
"fxa_uid": user.email.split("@", 1)[0],
"fxa_kid": format_key_id(
user.keys_changed_at or user.generation,
binascii.unhexlify(user.client_state)
),
}, secret=secret)
secret = tokenlib.get_derived_secret(token, secret=secret)
endpoint = PATTERN.format(uid=user.uid, node=user.node)
auth = HawkAuth(token, secret)
resp = requests.delete(endpoint, auth=auth, timeout=timeout)
if resp.status_code >= 400 and resp.status_code != 404:
resp.raise_for_status()
class HawkAuth(requests.auth.AuthBase):
"""Hawk-signing auth helper class."""
def __init__(self, token, secret):
self.token = token
self.secret = secret
def __call__(self, req):
hawkauthlib.sign_request(req, self.token, self.secret)
return req
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the purge_old_records() function.
"""
usage = "usage: %prog [options] secret"
parser = optparse.OptionParser(usage=usage)
parser.add_option("", "--purge-interval", type="int", default=3600,
help="Interval to sleep between purging runs")
parser.add_option("", "--grace-period", type="int", default=86400,
help="Number of seconds grace to allow on replacement")
parser.add_option("", "--max-per-loop", type="int", default=10,
help="Maximum number of items to fetch in one go")
# N.B., if the number of purgeable rows is <<< max_offset then most
# selects will return zero rows. Choose this value accordingly.
parser.add_option("", "--max-offset", type="int", default=0,
help="Use random offset from 0 to max_offset")
parser.add_option("", "--request-timeout", type="int", default=60,
help="Timeout for service deletion requests")
parser.add_option("", "--oneshot", action="store_true",
help="Do a single purge run and then exit")
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 2:
parser.print_usage()
return 1
secret = args[1]
util.configure_script_logging(opts)
purge_old_records(secret,
grace_period=opts.grace_period,
max_per_loop=opts.max_per_loop,
max_offset=opts.max_offset,
request_timeout=opts.request_timeout)
if not opts.oneshot:
while True:
# Randomize sleep interval +/- thirty percent to desynchronize
# instances of this script running on multiple webheads.
sleep_time = opts.purge_interval
sleep_time += random.randint(-0.3 * sleep_time, 0.3 * sleep_time)
logger.debug("Sleeping for %d seconds", sleep_time)
time.sleep(sleep_time)
purge_old_records(grace_period=opts.grace_period,
max_per_loop=opts.max_per_loop,
max_offset=opts.max_offset,
request_timeout=opts.request_timeout)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,73 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to remove a node from the system.
This script nukes any references to the named node - it is removed from
the "nodes" table and any users currently assigned to that node have their
assignments cleared.
"""
import logging
import optparse
import util
from database import Database
logger = logging.getLogger("tokenserver.scripts.remove_node")
def remove_node(node):
"""Remove the named node from the system."""
logger.info("Removing node %s", node)
try:
database = Database()
found = False
try:
database.remove_node(node)
except ValueError:
logger.debug(" not found")
else:
found = True
logger.debug(" removed")
except Exception:
logger.exception("Error while removing node")
return False
else:
if not found:
logger.info("Node %s was not found", node)
else:
logger.info("Finished removing node %s", node)
return True
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the remove_node() function.
"""
usage = "usage: %prog [options] node_name"
descr = "Remove a node from the tokenserver database"
parser = optparse.OptionParser(usage=usage, description=descr)
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 1:
parser.print_usage()
return 1
util.configure_script_logging(opts)
node_name = args[0]
remove_node(node_name)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,6 @@
boto
hawkauthlib
mysqlclient
pyramid
sqlalchemy
testfixtures

View File

@ -0,0 +1,25 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import sys
import unittest
from test_database import TestDatabase
from test_process_account_events import TestProcessAccountEvents
from test_purge_old_records import TestPurgeOldRecords
from test_scripts import TestScripts
if __name__ == "__main__":
loader = unittest.TestLoader()
test_cases = [TestDatabase, TestPurgeOldRecords, TestProcessAccountEvents,
TestScripts]
res = 0
for test_case in test_cases:
suite = loader.loadTestsFromTestCase(test_case)
runner = unittest.TextTestRunner()
if not runner.run(suite).wasSuccessful():
res = 1
sys.exit(res)

View File

@ -0,0 +1,446 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import time
import unittest
from collections import defaultdict
from database import MAX_GENERATION, Database
from util import get_timestamp
class TestDatabase(unittest.TestCase):
def setUp(self):
super(TestDatabase, self).setUp()
self.database = Database()
# Start each test with a blank slate.
cursor = self.database._execute_sql(('DELETE FROM users'), ())
cursor.close()
cursor = self.database._execute_sql(('DELETE FROM nodes'), ())
cursor.close()
self.database.add_node('https://phx12', 100)
def tearDown(self):
super(TestDatabase, self).tearDown()
# And clean up at the end, for good measure.
cursor = self.database._execute_sql(('DELETE FROM users'), ())
cursor.close()
cursor = self.database._execute_sql(('DELETE FROM nodes'), ())
cursor.close()
self.database.close()
def test_node_allocation(self):
user = self.database.get_user('test1@example.com')
self.assertEquals(user, None)
user = self.database.allocate_user('test1@example.com')
wanted = 'https://phx12'
self.assertEqual(user['node'], wanted)
user = self.database.get_user('test1@example.com')
self.assertEqual(user['node'], wanted)
def test_allocation_to_least_loaded_node(self):
self.database.add_node('https://phx13', 100)
user1 = self.database.allocate_user('test1@mozilla.com')
user2 = self.database.allocate_user('test2@mozilla.com')
self.assertNotEqual(user1['node'], user2['node'])
def test_allocation_is_not_allowed_to_downed_nodes(self):
self.database.update_node('https://phx12',
downed=True)
with self.assertRaises(Exception):
self.database.allocate_user('test1@mozilla.com')
def test_allocation_is_not_allowed_to_backoff_nodes(self):
self.database.update_node('https://phx12',
backoff=True)
with self.assertRaises(Exception):
self.database.allocate_user('test1@mozilla.com')
def test_update_generation_number(self):
user = self.database.allocate_user('test1@example.com')
self.assertEqual(user['generation'], 0)
self.assertEqual(user['client_state'], '')
orig_uid = user['uid']
orig_node = user['node']
# Changing generation should leave other properties unchanged.
self.database.update_user(user, generation=42)
self.assertEqual(user['uid'], orig_uid)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 42)
self.assertEqual(user['client_state'], '')
user = self.database.get_user('test1@example.com')
self.assertEqual(user['uid'], orig_uid)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 42)
self.assertEqual(user['client_state'], '')
# It's not possible to move generation number backwards.
self.database.update_user(user, generation=17)
self.assertEqual(user['uid'], orig_uid)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 42)
self.assertEqual(user['client_state'], '')
user = self.database.get_user('test1@example.com')
self.assertEqual(user['uid'], orig_uid)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 42)
self.assertEqual(user['client_state'], '')
def test_update_client_state(self):
user = self.database.allocate_user('test1@example.com')
self.assertEqual(user['generation'], 0)
self.assertEqual(user['client_state'], '')
self.assertEqual(set(user['old_client_states']), set(()))
seen_uids = set((user['uid'],))
orig_node = user['node']
# Changing client-state allocates a new userid.
self.database.update_user(user, client_state='aaa')
self.assertTrue(user['uid'] not in seen_uids)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 0)
self.assertEqual(user['client_state'], 'aaa')
self.assertEqual(set(user['old_client_states']), set(('',)))
user = self.database.get_user('test1@example.com')
self.assertTrue(user['uid'] not in seen_uids)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 0)
self.assertEqual(user['client_state'], 'aaa')
self.assertEqual(set(user['old_client_states']), set(('',)))
seen_uids.add(user['uid'])
# It's possible to change client-state and generation at once.
self.database.update_user(user,
client_state='bbb', generation=12)
self.assertTrue(user['uid'] not in seen_uids)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 12)
self.assertEqual(user['client_state'], 'bbb')
self.assertEqual(set(user['old_client_states']), set(('', 'aaa')))
user = self.database.get_user('test1@example.com')
self.assertTrue(user['uid'] not in seen_uids)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 12)
self.assertEqual(user['client_state'], 'bbb')
self.assertEqual(set(user['old_client_states']), set(('', 'aaa')))
# You can't got back to an old client_state.
orig_uid = user['uid']
with self.assertRaises(Exception):
self.database.update_user(user,
client_state='aaa')
user = self.database.get_user('test1@example.com')
self.assertEqual(user['uid'], orig_uid)
self.assertEqual(user['node'], orig_node)
self.assertEqual(user['generation'], 12)
self.assertEqual(user['client_state'], 'bbb')
self.assertEqual(set(user['old_client_states']), set(('', 'aaa')))
def test_user_retirement(self):
self.database.allocate_user('test@mozilla.com')
user1 = self.database.get_user('test@mozilla.com')
self.database.retire_user('test@mozilla.com')
user2 = self.database.get_user('test@mozilla.com')
self.assertTrue(user2['generation'] > user1['generation'])
def test_cleanup_of_old_records(self):
# Create 6 user records for the first user.
# Do a sleep halfway through so we can test use of grace period.
email1 = 'test1@mozilla.com'
user1 = self.database.allocate_user(email1)
self.database.update_user(user1, client_state='a')
self.database.update_user(user1, client_state='b')
self.database.update_user(user1, client_state='c')
break_time = time.time()
time.sleep(0.1)
self.database.update_user(user1, client_state='d')
self.database.update_user(user1, client_state='e')
records = list(self.database.get_user_records(email1))
self.assertEqual(len(records), 6)
# Create 3 user records for the second user.
email2 = 'test2@mozilla.com'
user2 = self.database.allocate_user(email2)
self.database.update_user(user2, client_state='a')
self.database.update_user(user2, client_state='b')
records = list(self.database.get_user_records(email2))
self.assertEqual(len(records), 3)
# That should be a total of 7 old records.
old_records = list(self.database.get_old_user_records(0))
self.assertEqual(len(old_records), 7)
# And with max_offset of 3, the first record should be id 4
old_records = list(self.database.get_old_user_records(0,
100, 3))
# The 'limit' parameter should be respected.
old_records = list(self.database.get_old_user_records(0, 2))
self.assertEqual(len(old_records), 2)
# The default grace period is too big to pick them up.
old_records = list(self.database.get_old_user_records())
self.assertEqual(len(old_records), 0)
# The grace period can select a subset of the records.
grace = time.time() - break_time
old_records = list(self.database.get_old_user_records(grace))
self.assertEqual(len(old_records), 3)
# Old records can be successfully deleted:
for record in old_records:
self.database.delete_user_record(record.uid)
old_records = list(self.database.get_old_user_records(0))
self.assertEqual(len(old_records), 4)
def test_node_reassignment_when_records_are_replaced(self):
self.database.allocate_user('test@mozilla.com',
generation=42,
keys_changed_at=12,
client_state='aaa')
user1 = self.database.get_user('test@mozilla.com')
self.database.replace_user_records('test@mozilla.com')
user2 = self.database.get_user('test@mozilla.com')
# They should have got a new uid.
self.assertNotEqual(user2['uid'], user1['uid'])
# But their account metadata should have been preserved.
self.assertEqual(user2['generation'], user1['generation'])
self.assertEqual(user2['keys_changed_at'], user1['keys_changed_at'])
self.assertEqual(user2['client_state'], user1['client_state'])
def test_node_reassignment_not_done_for_retired_users(self):
self.database.allocate_user('test@mozilla.com',
generation=42, client_state='aaa')
user1 = self.database.get_user('test@mozilla.com')
self.database.retire_user('test@mozilla.com')
user2 = self.database.get_user('test@mozilla.com')
self.assertEqual(user2['uid'], user1['uid'])
self.assertEqual(user2['generation'], MAX_GENERATION)
self.assertEqual(user2['client_state'], user2['client_state'])
def test_recovery_from_racy_record_creation(self):
timestamp = get_timestamp()
# Simulate race for forcing creation of two rows with same timestamp.
user1 = self.database.allocate_user('test@mozilla.com',
timestamp=timestamp)
user2 = self.database.allocate_user('test@mozilla.com',
timestamp=timestamp)
self.assertNotEqual(user1['uid'], user2['uid'])
# Neither is marked replaced initially.
old_records = list(
self.database.get_old_user_records(0)
)
self.assertEqual(len(old_records), 0)
# Reading current details will detect the problem and fix it.
self.database.get_user('test@mozilla.com')
old_records = list(
self.database.get_old_user_records(0)
)
self.assertEqual(len(old_records), 1)
def test_that_race_recovery_respects_generation_number_monotonicity(self):
timestamp = get_timestamp()
# Simulate race between clients with different generation numbers,
# in which the out-of-date client gets a higher timestamp.
user1 = self.database.allocate_user('test@mozilla.com',
generation=1,
timestamp=timestamp)
user2 = self.database.allocate_user('test@mozilla.com',
generation=2,
timestamp=timestamp - 1)
self.assertNotEqual(user1['uid'], user2['uid'])
# Reading current details should promote the higher-generation one.
user = self.database.get_user('test@mozilla.com')
self.assertEqual(user['generation'], 2)
self.assertEqual(user['uid'], user2['uid'])
# And the other record should get marked as replaced.
old_records = list(
self.database.get_old_user_records(0)
)
self.assertEqual(len(old_records), 1)
def test_node_reassignment_and_removal(self):
NODE1 = 'https://phx12'
NODE2 = 'https://phx13'
# note that NODE1 is created by default for all tests.
self.database.add_node(NODE2, 100)
# Assign four users, we should get two on each node.
user1 = self.database.allocate_user('test1@mozilla.com')
user2 = self.database.allocate_user('test2@mozilla.com')
user3 = self.database.allocate_user('test3@mozilla.com')
user4 = self.database.allocate_user('test4@mozilla.com')
node_counts = defaultdict(lambda: 0)
for user in (user1, user2, user3, user4):
node_counts[user['node']] += 1
self.assertEqual(node_counts[NODE1], 2)
self.assertEqual(node_counts[NODE2], 2)
# Clear the assignments for NODE1, and re-assign.
# The users previously on NODE1 should balance across both nodes,
# giving 1 on NODE1 and 3 on NODE2.
self.database.unassign_node(NODE1)
node_counts = defaultdict(lambda: 0)
for user in (user1, user2, user3, user4):
new_user = self.database.get_user(user['email'])
if user['node'] == NODE2:
self.assertEqual(new_user['node'], NODE2)
node_counts[new_user['node']] += 1
self.assertEqual(node_counts[NODE1], 1)
self.assertEqual(node_counts[NODE2], 3)
# Remove NODE2. Everyone should wind up on NODE1.
self.database.remove_node(NODE2)
for user in (user1, user2, user3, user4):
new_user = self.database.get_user(user['email'])
self.assertEqual(new_user['node'], NODE1)
# The old users records pointing to NODE2 should have a NULL 'node'
# property since it has been removed from the db.
null_node_count = 0
for row in self.database.get_old_user_records(0):
if row.node is None:
null_node_count += 1
else:
self.assertEqual(row.node, NODE1)
self.assertEqual(null_node_count, 3)
def test_that_race_recovery_respects_generation_after_reassignment(self):
timestamp = get_timestamp()
# Simulate race between clients with different generation numbers,
# in which the out-of-date client gets a higher timestamp.
user1 = self.database.allocate_user('test@mozilla.com',
generation=1,
timestamp=timestamp)
user2 = self.database.allocate_user('test@mozilla.com',
generation=2,
timestamp=timestamp - 1)
self.assertNotEqual(user1['uid'], user2['uid'])
# Force node re-assignment by marking all records as replaced.
self.database.replace_user_records('test@mozilla.com',
timestamp=timestamp + 1)
# The next client to show up should get a new assignment, marked
# with the correct generation number.
user = self.database.get_user('test@mozilla.com')
self.assertEqual(user['generation'], 2)
self.assertNotEqual(user['uid'], user1['uid'])
self.assertNotEqual(user['uid'], user2['uid'])
def test_that_we_can_allocate_users_to_a_specific_node(self):
node = 'https://phx13'
self.database.add_node(node, 50)
# The new node is not selected by default, because of lower capacity.
user = self.database.allocate_user('test1@mozilla.com')
self.assertNotEqual(user['node'], node)
# But we can force it using keyword argument.
user = self.database.allocate_user('test2@mozilla.com',
node=node)
self.assertEqual(user['node'], node)
def test_that_we_can_move_users_to_a_specific_node(self):
node = 'https://phx13'
self.database.add_node(node, 50)
# The new node is not selected by default, because of lower capacity.
user = self.database.allocate_user('test@mozilla.com')
self.assertNotEqual(user['node'], node)
# But we can move them there explicitly using keyword argument.
self.database.update_user(user, node=node)
self.assertEqual(user['node'], node)
# Sanity-check by re-reading it from the db.
user = self.database.get_user('test@mozilla.com')
self.assertEqual(user['node'], node)
# Check that it properly respects client-state and generation.
self.database.update_user(user, generation=12)
self.database.update_user(user, client_state='XXX')
self.database.update_user(user, generation=42,
client_state='YYY', node='https://phx12')
self.assertEqual(user['node'], 'https://phx12')
self.assertEqual(user['generation'], 42)
self.assertEqual(user['client_state'], 'YYY')
self.assertEqual(sorted(user['old_client_states']), ['', 'XXX'])
# Sanity-check by re-reading it from the db.
user = self.database.get_user('test@mozilla.com')
self.assertEqual(user['node'], 'https://phx12')
self.assertEqual(user['generation'], 42)
self.assertEqual(user['client_state'], 'YYY')
self.assertEqual(sorted(user['old_client_states']), ['', 'XXX'])
def test_that_record_cleanup_frees_slots_on_the_node(self):
node = 'https://phx12'
self.database.update_node(node, capacity=10, available=1,
current_load=9)
# We should only be able to allocate one more user to that node.
user = self.database.allocate_user('test1@mozilla.com')
self.assertEqual(user['node'], node)
with self.assertRaises(Exception):
self.database.allocate_user('test2@mozilla.com')
# But when we clean up the user's record, it frees up the slot.
self.database.retire_user('test1@mozilla.com')
self.database.delete_user_record(user['uid'])
user = self.database.allocate_user('test2@mozilla.com')
self.assertEqual(user['node'], node)
def test_gradual_release_of_node_capacity(self):
node1 = 'https://phx12'
self.database.update_node(node1, capacity=8, available=1,
current_load=4)
node2 = 'https://phx13'
self.database.add_node(node2, capacity=6,
available=1, current_load=4)
# Two allocations should succeed without update, one on each node.
user = self.database.allocate_user('test1@mozilla.com')
self.assertEqual(user['node'], node1)
user = self.database.allocate_user('test2@mozilla.com')
self.assertEqual(user['node'], node2)
# The next allocation attempt will release 10% more capacity,
# which is one more slot for each node.
user = self.database.allocate_user('test3@mozilla.com')
self.assertEqual(user['node'], node1)
user = self.database.allocate_user('test4@mozilla.com')
self.assertEqual(user['node'], node2)
# Now node2 is full, so further allocations all go to node1.
user = self.database.allocate_user('test5@mozilla.com')
self.assertEqual(user['node'], node1)
user = self.database.allocate_user('test6@mozilla.com')
self.assertEqual(user['node'], node1)
# Until it finally reaches capacity.
with self.assertRaises(Exception):
self.database.allocate_user('test7@mozilla.com')
def test_count_users(self):
user = self.database.allocate_user('test1@example.com')
self.assertEqual(self.database.count_users(), 1)
old_timestamp = get_timestamp()
time.sleep(0.01)
# Adding users increases the count.
user = self.database.allocate_user('rfkelly@mozilla.com')
self.assertEqual(self.database.count_users(), 2)
# Updating a user doesn't change the count.
self.database.update_user(user, client_state='aaa')
self.assertEqual(self.database.count_users(), 2)
# Looking back in time doesn't count newer users.
self.assertEqual(self.database.count_users(old_timestamp), 1)
# Retiring a user decreases the count.
self.database.retire_user('test1@example.com')
self.assertEqual(self.database.count_users(), 1)
def test_first_seen_at(self):
EMAIL = 'test1@example.com'
user0 = self.database.allocate_user(EMAIL)
user1 = self.database.get_user(EMAIL)
self.assertEqual(user1['uid'], user0['uid'])
self.assertEqual(user1['first_seen_at'], user0['first_seen_at'])
# It should stay consistent if we re-allocate the user's node.
time.sleep(0.1)
self.database.update_user(user1, client_state='aaa')
user2 = self.database.get_user(EMAIL)
self.assertNotEqual(user2['uid'], user0['uid'])
self.assertEqual(user2['first_seen_at'], user0['first_seen_at'])
# Until we purge their old node-assignment records.
self.database.delete_user_record(user0['uid'])
user3 = self.database.get_user(EMAIL)
self.assertEqual(user3['uid'], user2['uid'])
self.assertNotEqual(user3['first_seen_at'], user2['first_seen_at'])

View File

@ -0,0 +1,244 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import json
import os
import unittest
from pyramid import testing
from testfixtures import LogCapture
from database import Database
from process_account_events import process_account_event
PATTERN = "{node}/1.5/{uid}"
EMAIL = "test@example.com"
UID = "test"
ISS = "example.com"
def message_body(**kwds):
return json.dumps({
"Message": json.dumps(kwds)
})
class TestProcessAccountEvents(unittest.TestCase):
def get_ini(self):
return os.path.join(os.path.dirname(__file__),
'test_sql.ini')
def setUp(self):
self.database = Database()
self.database.add_node("https://phx12", 100)
self.logs = LogCapture()
def tearDown(self):
self.logs.uninstall()
testing.tearDown()
cursor = self.database._execute_sql('DELETE FROM nodes')
cursor.close()
cursor = self.database._execute_sql('DELETE FROM users')
cursor.close
def assertMessageWasLogged(self, msg):
"""Check that a metric was logged during the request."""
for r in self.logs.records:
if msg in r.getMessage():
break
else:
assert False, "message %r was not logged" % (msg,)
def clearLogs(self):
del self.logs.records[:]
def test_delete_user(self):
self.database.allocate_user(EMAIL)
user = self.database.get_user(EMAIL)
self.database.update_user(user, client_state="abcdef")
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
self.assertTrue(records[0]["replaced_at"] is not None)
process_account_event(message_body(
event="delete",
uid=UID,
iss=ISS,
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
for row in records:
self.assertTrue(row["replaced_at"] is not None)
def test_delete_user_by_legacy_uid_format(self):
self.database.allocate_user(EMAIL)
user = self.database.get_user(EMAIL)
self.database.update_user(user, client_state="abcdef")
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
self.assertTrue(records[0]["replaced_at"] is not None)
process_account_event(message_body(
event="delete",
uid=EMAIL,
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
for row in records:
self.assertTrue(row["replaced_at"] is not None)
def test_delete_user_who_is_not_in_the_db(self):
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
process_account_event(message_body(
event="delete",
uid=UID,
iss=ISS
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
def test_reset_user(self):
self.database.allocate_user(EMAIL, generation=12)
process_account_event(message_body(
event="reset",
uid=UID,
iss=ISS,
generation=43,
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
def test_reset_user_by_legacy_uid_format(self):
self.database.allocate_user(EMAIL, generation=12)
process_account_event(message_body(
event="reset",
uid=EMAIL,
generation=43,
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
def test_reset_user_who_is_not_in_the_db(self):
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
process_account_event(message_body(
event="reset",
uid=UID,
iss=ISS,
generation=43,
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
def test_password_change(self):
self.database.allocate_user(EMAIL, generation=12)
process_account_event(message_body(
event="passwordChange",
uid=UID,
iss=ISS,
generation=43,
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
def test_password_change_user_not_in_db(self):
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
process_account_event(message_body(
event="passwordChange",
uid=UID,
iss=ISS,
generation=43,
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
def test_malformed_events(self):
# Unknown event type.
process_account_event(message_body(
event="party",
uid=UID,
iss=ISS,
generation=43,
))
self.assertMessageWasLogged("Dropping unknown event type")
self.clearLogs()
# Missing event type.
process_account_event(message_body(
uid=UID,
iss=ISS,
generation=43,
))
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Missing uid.
process_account_event(message_body(
event="delete",
iss=ISS,
))
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Missing generation for reset events.
process_account_event(message_body(
event="reset",
uid=UID,
iss=ISS,
))
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Missing generation for passwordChange events.
process_account_event(message_body(
event="passwordChange",
uid=UID,
iss=ISS,
))
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Missing issuer with nonemail uid
process_account_event(message_body(
event="delete",
uid=UID,
))
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Non-JSON garbage.
process_account_event("wat")
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Non-JSON garbage in Message field.
process_account_event('{ "Message": "wat" }')
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()
# Badly-typed JSON value in Message field.
process_account_event('{ "Message": "[1, 2, 3"] }')
self.assertMessageWasLogged("Invalid account message")
self.clearLogs()

View File

@ -0,0 +1,136 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import hawkauthlib
import re
import threading
import tokenlib
import unittest
from wsgiref.simple_server import make_server
from database import Database
from purge_old_records import purge_old_records
class TestPurgeOldRecords(unittest.TestCase):
"""A testcase for proper functioning of the purge_old_records.py script.
This is a tricky one, because we have to actually run the script and
test that it does the right thing. We also run a mock downstream service
so we can test that data-deletion requests go through ok.
"""
@classmethod
def setUpClass(cls):
cls.service_requests = []
cls.service_node = "http://localhost:8002"
cls.service = make_server("localhost", 8002, cls._service_app)
target = cls.service.serve_forever
cls.service_thread = threading.Thread(target=target)
# Note: If the following `start` causes the test thread to hang,
# you may need to specify
# `[app::pyramid.app] pyramid.worker_class = sync` in the test_*.ini
# files
cls.service_thread.start()
# This silences nuisance on-by-default logging output.
cls.service.RequestHandlerClass.log_request = lambda *a: None
def setUp(self):
super(TestPurgeOldRecords, self).setUp()
# Configure the node-assignment backend to talk to our test service.
self.database = Database()
self.database.add_node(self.service_node, 100)
def tearDown(self):
cursor = self.database._execute_sql('DELETE FROM nodes')
cursor.close()
cursor = self.database._execute_sql('DELETE FROM users')
cursor.close()
del self.service_requests[:]
@classmethod
def tearDownClass(cls):
cls.service.shutdown()
cls.service_thread.join()
@classmethod
def _service_app(cls, environ, start_response):
cls.service_requests.append(environ)
start_response("200 OK", [])
return ""
def test_purging_of_old_user_records(self):
# Make some old user records.
email = "test@mozilla.com"
user = self.database.allocate_user(email, client_state="aa",
generation=123)
self.database.update_user(user, client_state="bb",
generation=456, keys_changed_at=450)
self.database.update_user(user, client_state="cc",
generation=789)
user_records = list(self.database.get_user_records(email))
self.assertEqual(len(user_records), 3)
user = self.database.get_user(email)
self.assertEquals(user["client_state"], "cc")
self.assertEquals(len(user["old_client_states"]), 2)
# The default grace-period should prevent any cleanup.
node_secret = "SECRET"
self.assertTrue(purge_old_records(node_secret))
user_records = list(self.database.get_user_records(email))
self.assertEqual(len(user_records), 3)
self.assertEqual(len(self.service_requests), 0)
# With no grace period, we should cleanup two old records.
self.assertTrue(purge_old_records(node_secret, grace_period=0))
user_records = list(self.database.get_user_records(email))
self.assertEqual(len(user_records), 1)
self.assertEqual(len(self.service_requests), 2)
# Check that the proper delete requests were made to the service.
expected_kids = ["0000000000450-uw", "0000000000123-qg"]
for i, environ in enumerate(self.service_requests):
# They must be to the correct path.
self.assertEquals(environ["REQUEST_METHOD"], "DELETE")
self.assertTrue(re.match("/1.5/[0-9]+", environ["PATH_INFO"]))
# They must have a correct request signature.
token = hawkauthlib.get_id(environ)
secret = tokenlib.get_derived_secret(token, secret=node_secret)
self.assertTrue(hawkauthlib.check_signature(environ, secret))
userdata = tokenlib.parse_token(token, secret=node_secret)
self.assertTrue("uid" in userdata)
self.assertTrue("node" in userdata)
self.assertEqual(userdata["fxa_uid"], "test")
self.assertEqual(userdata["fxa_kid"], expected_kids[i])
# Check that the user's current state is unaffected
user = self.database.get_user(email)
self.assertEquals(user["client_state"], "cc")
self.assertEquals(len(user["old_client_states"]), 0)
def test_purging_is_not_done_on_downed_nodes(self):
# Make some old user records.
node_secret = "SECRET"
email = "test@mozilla.com"
user = self.database.allocate_user(email, client_state="aa")
self.database.update_user(user, client_state="bb")
user_records = list(self.database.get_user_records(email))
self.assertEqual(len(user_records), 2)
# With the node down, we should not purge any records.
self.database.update_node(self.service_node, downed=1)
self.assertTrue(purge_old_records(node_secret, grace_period=0))
user_records = list(self.database.get_user_records(email))
self.assertEqual(len(user_records), 2)
self.assertEqual(len(self.service_requests), 0)
# With the node back up, we should purge correctly.
self.database.update_node(self.service_node, downed=0)
self.assertTrue(purge_old_records(node_secret, grace_period=0))
user_records = list(self.database.get_user_records(email))
self.assertEqual(len(user_records), 1)
self.assertEqual(len(self.service_requests), 1)

View File

@ -0,0 +1,223 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import json
import os
import unittest
import uuid
from add_node import main as add_node_script
from allocate_user import main as allocate_user_script
from count_users import main as count_users_script
from database import Database
from remove_node import main as remove_node_script
from unassign_node import main as unassign_node_script
from update_node import main as update_node_script
from util import get_timestamp
class TestScripts(unittest.TestCase):
NODE_ID = 800
NODE_URL = 'https://node1'
def setUp(self):
self.database = Database()
# Start each test with a blank slate.
cursor = self.database._execute_sql('DELETE FROM users')
cursor.close()
cursor = self.database._execute_sql('DELETE FROM nodes')
cursor.close()
# Ensure we have a node with enough capacity to run the tests.
self.database.add_node(self.NODE_URL, 100, id=self.NODE_ID)
def tearDown(self):
# And clean up at the end, for good measure.
cursor = self.database._execute_sql('DELETE FROM users')
cursor.close()
cursor = self.database._execute_sql('DELETE FROM nodes')
cursor.close()
self.database.close()
def test_add_node(self):
add_node_script(
args=['--current-load', '9', 'test_node', '100']
)
res = self.database.get_node('test_node')
# The node should have the expected attributes
self.assertEqual(res.capacity, 100)
self.assertEqual(res.available, 10)
self.assertEqual(res.current_load, 9)
self.assertEqual(res.downed, 0)
self.assertEqual(res.backoff, 0)
self.assertEqual(res.service, self.database.service_id)
def test_add_node_with_explicit_available(self):
args = ['--current-load', '9', '--available', '5', 'test_node', '100']
add_node_script(args=args)
res = self.database.get_node('test_node')
# The node should have the expected attributes
self.assertEqual(res.capacity, 100)
self.assertEqual(res.available, 5)
self.assertEqual(res.current_load, 9)
self.assertEqual(res.downed, 0)
self.assertEqual(res.backoff, 0)
self.assertEqual(res.service, self.database.service_id)
def test_add_downed_node(self):
add_node_script(
args=['--downed', 'test_node', '100']
)
res = self.database.get_node('test_node')
# The node should have the expected attributes
self.assertEqual(res.capacity, 100)
self.assertEqual(res.available, 10)
self.assertEqual(res.current_load, 0)
self.assertEqual(res.downed, 1)
self.assertEqual(res.backoff, 0)
self.assertEqual(res.service, self.database.service_id)
def test_add_backoff_node(self):
add_node_script(
args=['--backoff', 'test_node', '100']
)
res = self.database.get_node('test_node')
# The node should have the expected attributes
self.assertEqual(res.capacity, 100)
self.assertEqual(res.available, 10)
self.assertEqual(res.current_load, 0)
self.assertEqual(res.downed, 0)
self.assertEqual(res.backoff, 1)
self.assertEqual(res.service, self.database.service_id)
def test_allocate_user_user_already_exists(self):
email = 'test@test.com'
self.database.allocate_user(email)
node = 'https://node2'
self.database.add_node(node, 100)
allocate_user_script(args=[email, node])
user = self.database.get_user(email)
# The user should be assigned to the given node
self.assertEqual(user['node'], node)
# Another user should not have been created
count = self.database.count_users()
self.assertEqual(count, 1)
def test_allocate_user_given_node(self):
email = 'test@test.com'
node = 'https://node2'
self.database.add_node(node, 100)
allocate_user_script(args=[email, node])
user = self.database.get_user(email)
# A new user should be created and assigned to the given node
self.assertEqual(user['node'], node)
def test_allocate_user_not_given_node(self):
email = 'test@test.com'
self.database.add_node('https://node2', 100,
current_load=10)
self.database.add_node('https://node3', 100,
current_load=20)
self.database.add_node('https://node4', 100,
current_load=30)
allocate_user_script(args=[email])
user = self.database.get_user(email)
# The user should be assigned to the least-loaded node
self.assertEqual(user['node'], 'https://node1')
def test_count_users(self):
self.database.allocate_user('test1@test.com')
self.database.allocate_user('test2@test.com')
self.database.allocate_user('test3@test.com')
timestamp = get_timestamp()
filename = '/tmp/' + str(uuid.uuid4())
try:
count_users_script(
args=['--output', filename, '--timestamp', str(timestamp)]
)
with open(filename) as f:
info = json.loads(f.readline())
self.assertEqual(info['total_users'], 3)
self.assertEqual(info['op'], 'sync_count_users')
finally:
os.remove(filename)
filename = '/tmp/' + str(uuid.uuid4())
try:
args = ['--output', filename, '--timestamp',
str(timestamp - 10000)]
count_users_script(args=args)
with open(filename) as f:
info = json.loads(f.readline())
self.assertEqual(info['total_users'], 0)
self.assertEqual(info['op'], 'sync_count_users')
finally:
os.remove(filename)
def test_remove_node(self):
self.database.add_node('https://node2', 100)
self.database.allocate_user('test1@test.com',
node='https://node2')
self.database.allocate_user('test2@test.com',
node=self.NODE_URL)
self.database.allocate_user('test3@test.com',
node=self.NODE_URL)
remove_node_script(args=['https://node2'])
# The node should have been removed from the database
args = ['https://node2']
self.assertRaises(ValueError, self.database.get_node_id, *args)
# The first user should have been assigned to a new node
user = self.database.get_user('test1@test.com')
self.assertEqual(user['node'], self.NODE_URL)
# The second and third users should still be on the first node
user = self.database.get_user('test2@test.com')
self.assertEqual(user['node'], self.NODE_URL)
user = self.database.get_user('test3@test.com')
self.assertEqual(user['node'], self.NODE_URL)
def test_unassign_node(self):
self.database.add_node('https://node2', 100)
self.database.allocate_user('test1@test.com',
node='https://node2')
self.database.allocate_user('test2@test.com',
node='https://node2')
self.database.allocate_user('test3@test.com',
node=self.NODE_URL)
unassign_node_script(args=['https://node2'])
self.database.remove_node('https://node2')
# All of the users should now be assigned to the first node
user = self.database.get_user('test1@test.com')
self.assertEqual(user['node'], self.NODE_URL)
user = self.database.get_user('test2@test.com')
self.assertEqual(user['node'], self.NODE_URL)
user = self.database.get_user('test3@test.com')
self.assertEqual(user['node'], self.NODE_URL)
def test_update_node(self):
self.database.add_node('https://node2', 100)
update_node_script(args=[
'--capacity', '150',
'--available', '125',
'--current-load', '25',
'--downed',
'--backoff',
'https://node2'
])
node = self.database.get_node('https://node2')
# Ensure the node has the expected attributes
self.assertEqual(node['capacity'], 150)
self.assertEqual(node['available'], 125)
self.assertEqual(node['current_load'], 25)
self.assertEqual(node['downed'], 1)
self.assertEqual(node['backoff'], 1)

View File

@ -0,0 +1,72 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to remove a node from the system.
This script clears any assignments to the named node.
"""
import logging
import optparse
from database import Database
import util
logger = logging.getLogger("tokenserver.scripts.unassign_node")
def unassign_node(node):
"""Clear any assignments to the named node."""
logger.info("Unassignment node %s", node)
try:
database = Database()
found = False
try:
database.unassign_node(node)
except ValueError:
logger.debug(" not found")
else:
found = True
logger.debug(" unassigned")
except Exception:
logger.exception("Error while unassigning node")
return False
else:
if not found:
logger.info("Node %s was not found", node)
else:
logger.info("Finished unassigning node %s", node)
return True
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the unassign_node() function.
"""
usage = "usage: %prog [options] node_name"
descr = "Clear all assignments to node in the tokenserver database"
parser = optparse.OptionParser(usage=usage, description=descr)
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 1:
parser.print_usage()
return 1
util.configure_script_logging(opts)
node_name = args[0]
unassign_node(node_name)
return 0
if __name__ == "__main__":
util.run_script(main)

View File

@ -0,0 +1,83 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Script to update node status in the db.
"""
import logging
import optparse
from database import Database
import util
logger = logging.getLogger("tokenserver.scripts.update_node")
def update_node(node, **kwds):
"""Update details of a node."""
logger.info("Updating node %s for service %s", node)
logger.debug("Value: %r", kwds)
try:
database = Database()
database.update_node(node, **kwds)
except Exception:
logger.exception("Error while updating node")
return False
else:
logger.info("Finished updating node %s", node)
return True
def main(args=None):
"""Main entry-point for running this script.
This function parses command-line arguments and passes them on
to the update_node() function.
"""
usage = "usage: %prog [options] node_name"
descr = "Update node details in the tokenserver database"
parser = optparse.OptionParser(usage=usage, description=descr)
parser.add_option("", "--capacity", type="int",
help="How many user slots the node has overall")
parser.add_option("", "--available", type="int",
help="How many user slots the node has available")
parser.add_option("", "--current-load", type="int",
help="How many user slots the node has occupied")
parser.add_option("", "--downed", action="store_true",
help="Mark the node as down in the db")
parser.add_option("", "--backoff", action="store_true",
help="Mark the node as backed-off in the db")
parser.add_option("-v", "--verbose", action="count", dest="verbosity",
help="Control verbosity of log messages")
opts, args = parser.parse_args(args)
if len(args) != 1:
parser.print_usage()
return 1
util.configure_script_logging(opts)
node_name = args[0]
kwds = {}
if opts.capacity is not None:
kwds["capacity"] = opts.capacity
if opts.available is not None:
kwds["available"] = opts.available
if opts.current_load is not None:
kwds["current_load"] = opts.current_load
if opts.backoff is not None:
kwds["backoff"] = opts.backoff
if opts.downed is not None:
kwds["downed"] = opts.downed
update_node(node_name, **kwds)
return 0
if __name__ == "__main__":
util.run_script(main)

58
tools/tokenserver/util.py Normal file
View File

@ -0,0 +1,58 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
"""
Admin/managment scripts for TokenServer.
"""
import sys
import time
import logging
from browserid.utils import encode_bytes as encode_bytes_b64
def run_script(main):
"""Simple wrapper for running scripts in __main__ section."""
try:
exitcode = main()
except KeyboardInterrupt:
exitcode = 1
sys.exit(exitcode)
def configure_script_logging(opts=None):
"""Configure stdlib logging to produce output from the script.
This basically configures logging to send messages to stderr, with
formatting that's more for human readability than machine parsing.
It also takes care of the --verbosity command-line option.
"""
if not opts or not opts.verbosity:
loglevel = logging.WARNING
elif opts.verbosity == 1:
loglevel = logging.INFO
else:
loglevel = logging.DEBUG
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
handler.setLevel(loglevel)
logger = logging.getLogger("")
logger.addHandler(handler)
logger.setLevel(loglevel)
def format_key_id(keys_changed_at, key_hash):
"""Format an FxA key ID from a timestamp and key hash."""
return "{:013d}-{}".format(
keys_changed_at,
encode_bytes_b64(key_hash),
)
def get_timestamp():
"""Get current timestamp in milliseconds."""
return int(time.time() * 1000)