mirror of
https://github.com/mozilla-services/syncstorage-rs.git
synced 2026-05-05 12:16:21 +02:00
parent
0996cb154f
commit
0ac30958de
@ -38,7 +38,7 @@ commands:
|
||||
command: |
|
||||
cargo fmt -- --check
|
||||
# https://github.com/bodil/sized-chunks/issues/11
|
||||
cargo audit --ignore RUSTSEC-2020-0041 --ignore RUSTSEC-2021-0078 --ignore RUSTSEC-2021-0079 --ignore RUSTSEC-2020-0159 --ignore RUSTSEC-2020-0071
|
||||
cargo audit --ignore RUSTSEC-2020-0041 --ignore RUSTSEC-2021-0078 --ignore RUSTSEC-2021-0079 --ignore RUSTSEC-2020-0159 --ignore RUSTSEC-2020-0071 --ignore RUSTSEC-2021-0124
|
||||
python-check:
|
||||
steps:
|
||||
- run:
|
||||
@ -46,6 +46,7 @@ commands:
|
||||
command: |
|
||||
flake8 src/tokenserver/verify.py
|
||||
flake8 tools/integration_tests
|
||||
flake8 tools/tokenserver
|
||||
rust-clippy:
|
||||
steps:
|
||||
- run:
|
||||
@ -113,6 +114,16 @@ commands:
|
||||
environment:
|
||||
SYNCSTORAGE_RS_IMAGE: app:build
|
||||
|
||||
run-tokenserver-scripts-tests:
|
||||
steps:
|
||||
- run:
|
||||
name: Tokenserver scripts tests
|
||||
command: >
|
||||
pip3 install -r tools/tokenserver/requirements.txt &&
|
||||
python3 tools/tokenserver/run_tests.py
|
||||
environment:
|
||||
SYNCSTORAGE_RS_IMAGE: app:build
|
||||
|
||||
run-e2e-spanner-tests:
|
||||
steps:
|
||||
- run:
|
||||
@ -211,6 +222,7 @@ jobs:
|
||||
- write-version
|
||||
- cargo-build
|
||||
- run-tests
|
||||
- run-tokenserver-scripts-tests
|
||||
#- save-sccache-cache
|
||||
- run:
|
||||
name: Build Docker image
|
||||
|
||||
0
tools/tokenserver/__init__.py
Normal file
0
tools/tokenserver/__init__.py
Normal file
80
tools/tokenserver/add_node.py
Normal file
80
tools/tokenserver/add_node.py
Normal file
@ -0,0 +1,80 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
"""
|
||||
|
||||
Script to add a new node to the system.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import optparse
|
||||
|
||||
from database import Database, SERVICE_NAME
|
||||
import util
|
||||
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.add_node")
|
||||
|
||||
|
||||
def add_node(node, capacity, **kwds):
    """Register a new node with the tokenserver database.

    Returns True on success, False if the database operation failed;
    failures are logged rather than raised.
    """
    logger.info("Adding node %s to service %s", node, SERVICE_NAME)
    try:
        Database().add_node(node, capacity, **kwds)
    except Exception:
        logger.exception("Error while adding node")
        return False
    logger.info("Finished adding node %s", node)
    return True
|
||||
|
||||
|
||||
def main(args=None):
    """Command-line entry point.

    Parses options plus the required ``node_name capacity`` positional
    arguments and hands them to add_node().  Returns 0 on success and 1
    on a usage error.
    """
    parser = optparse.OptionParser(
        usage="usage: %prog [options] node_name capacity",
        description="Add a new node to the tokenserver database",
    )
    parser.add_option("", "--available", type="int",
                      help="How many user slots the node has available")
    parser.add_option("", "--current-load", type="int",
                      help="How many user slots the node has occupied")
    parser.add_option("", "--downed", action="store_true",
                      help="Mark the node as down in the db")
    parser.add_option("", "--backoff", action="store_true",
                      help="Mark the node as backed-off in the db")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if len(args) != 2:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    node_name, capacity = args[0], int(args[1])

    # Forward only the options that were explicitly given on the
    # command line; add_node() supplies its own defaults for the rest.
    kwds = {
        name: value
        for name, value in (
            ("available", opts.available),
            ("current_load", opts.current_load),
            ("backoff", opts.backoff),
            ("downed", opts.downed),
        )
        if value is not None
    }

    add_node(node_name, capacity, **kwds)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Delegate to the shared helper, which presumably wires up logging and
    # turns main()'s return value into a process exit code — TODO confirm
    # against tools/tokenserver/util.py.
    util.run_script(main)
|
||||
73
tools/tokenserver/allocate_user.py
Normal file
73
tools/tokenserver/allocate_user.py
Normal file
@ -0,0 +1,73 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to allocate a specific user to a node.
|
||||
|
||||
This script allocates the specified user to a node. A particular node
|
||||
may be specified, or the best available node used by default.
|
||||
|
||||
The allocated node is printed to stdout.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import optparse
|
||||
|
||||
from database import Database
|
||||
import util
|
||||
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.allocate_user")
|
||||
|
||||
|
||||
def allocate_user(email, node=None):
    """Assign the given user to a node.

    Unknown users get a fresh record; existing users are moved to *node*
    (or, when node is None, to the best available node).  Returns True on
    success, False on failure; failures are logged rather than raised.
    """
    logger.info("Allocating node for user %s", email)
    try:
        db = Database()
        existing = db.get_user(email)
        if existing is None:
            db.allocate_user(email, node=node)
        else:
            db.update_user(existing, node=node)
    except Exception:
        logger.exception("Error while updating node")
        return False
    logger.info("Finished updating node %s", node)
    return True
|
||||
|
||||
|
||||
def main(args=None):
    """Command-line entry point.

    Parses an email address (and an optional node name) from the command
    line and passes them to allocate_user().  Returns 0 on success and 1
    on a usage error.
    """
    parser = optparse.OptionParser(
        usage="usage: %prog [options] email [node_name]",
        description="Allocate a user to a node. You may specify a "
                    "particular node, or omit to use the best available "
                    "node.",
    )
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if len(args) not in (1, 2):
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    email = args[0]
    node_name = args[1] if len(args) == 2 else None

    allocate_user(email, node_name)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Delegate to the shared helper, which presumably wires up logging and
    # turns main()'s return value into a process exit code — TODO confirm
    # against tools/tokenserver/util.py.
    util.run_script(main)
|
||||
98
tools/tokenserver/count_users.py
Normal file
98
tools/tokenserver/count_users.py
Normal file
@ -0,0 +1,98 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to emit total-user-count metrics for exec dashboard.
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import optparse
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta, tzinfo
|
||||
|
||||
from database import Database
|
||||
import util
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.count_users")
|
||||
|
||||
ZERO = timedelta(0)
|
||||
|
||||
|
||||
class UTC(tzinfo):
    """A fixed tzinfo implementation for UTC: zero offset, no DST.

    NOTE(review): datetime.timezone.utc provides the same behaviour in
    the stdlib; this class is kept for interface compatibility.
    """

    def tzname(self, dt):
        """Name of the timezone, always the literal "UTC"."""
        return "UTC"

    def utcoffset(self, dt):
        """Offset from UTC — zero, by definition."""
        return ZERO

    def dst(self, dt):
        """Daylight-saving adjustment — UTC never observes DST."""
        return ZERO
|
||||
|
||||
|
||||
utc = UTC()
|
||||
|
||||
|
||||
def count_users(outfile, timestamp=None):
    """Write a one-line JSON record of the total user count to `outfile`.

    `timestamp` is a cutoff in milliseconds; only users created at or
    before it are counted.  When omitted it defaults to the most recent
    midnight.
    """
    if timestamp is None:
        ts = time.gmtime()
        # Truncate the current time down to midnight.
        midnight = (ts[0], ts[1], ts[2], 0, 0, 0, ts[6], ts[7], ts[8])
        # NOTE(review): time.mktime() interprets this tuple in *local*
        # time, but the tuple came from time.gmtime() (UTC).  On non-UTC
        # hosts the cutoff is shifted by the local UTC offset — confirm
        # whether that is intended (calendar.timegm would give UTC
        # midnight exactly).
        timestamp = int(time.mktime(midnight)) * 1000
    database = Database()
    logger.debug("Counting users created before %i", timestamp)
    count = database.count_users(timestamp)
    logger.debug("Found %d users", count)
    # Output has heka-filter-compatible JSON object.
    ts_sec = timestamp / 1000
    output = {
        "hostname": socket.gethostname(),
        "pid": os.getpid(),
        "op": "sync_count_users",
        "total_users": count,
        "time": datetime.fromtimestamp(ts_sec, utc).isoformat(),
        "v": 0
    }
    json.dump(output, outfile)
    outfile.write("\n")
|
||||
|
||||
|
||||
def main(args=None):
    """Main entry-point for running this script.

    This function parses command-line arguments and passes them on
    to the count_users() function.  Returns 0 on success and 1 on a
    usage error.
    """
    usage = "usage: %prog [options]"
    descr = "Count total users in the tokenserver database"
    parser = optparse.OptionParser(usage=usage, description=descr)
    parser.add_option("-t", "--timestamp", type="int",
                      help="Max creation timestamp; default previous midnight")
    # Fixed help text: the script writes to stdout (not stderr) when no
    # output file is given — see the sys.stdout branch below.
    parser.add_option("-o", "--output",
                      help="Output file; default stdout")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    if args:
        # This script takes no positional arguments.
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    # "-" conventionally means stdout, same as omitting -o entirely.
    if opts.output in (None, "-"):
        count_users(sys.stdout, opts.timestamp)
    else:
        # Append so repeated runs accumulate one JSON line per run.
        with open(opts.output, "a") as outfile:
            count_users(outfile, opts.timestamp)

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Delegate to the shared helper, which presumably wires up logging and
    # turns main()'s return value into a process exit code — TODO confirm
    # against tools/tokenserver/util.py.
    util.run_script(main)
|
||||
641
tools/tokenserver/database.py
Normal file
641
tools/tokenserver/database.py
Normal file
@ -0,0 +1,641 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import math
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.sql import text as sqltext
|
||||
|
||||
from util import get_timestamp
|
||||
|
||||
# The maximum possible generation number.
|
||||
# Used as a tombstone to mark users that have been "retired" from the db.
|
||||
MAX_GENERATION = 9223372036854775807
|
||||
NODE_FIELDS = ("capacity", "available", "current_load", "downed", "backoff")
|
||||
|
||||
_GET_USER_RECORDS = sqltext("""\
|
||||
select
|
||||
uid, nodes.node, generation, keys_changed_at, client_state, created_at,
|
||||
replaced_at
|
||||
from
|
||||
users left outer join nodes on users.nodeid = nodes.id
|
||||
where
|
||||
email = :email and users.service = :service
|
||||
order by
|
||||
created_at desc, uid desc
|
||||
limit
|
||||
20
|
||||
""")
|
||||
|
||||
_CREATE_USER_RECORD = sqltext("""\
|
||||
insert into
|
||||
users
|
||||
(service, email, nodeid, generation, keys_changed_at, client_state,
|
||||
created_at, replaced_at)
|
||||
values
|
||||
(:service, :email, :nodeid, :generation, :keys_changed_at,
|
||||
:client_state, :timestamp, NULL)
|
||||
""")
|
||||
|
||||
# The `where` clause on this statement is designed as an extra layer of
|
||||
# protection, to ensure that concurrent updates don't accidentally move
|
||||
# timestamp fields backwards in time. The handling of `keys_changed_at`
|
||||
# is additionally weird because we want to treat the default `NULL` value
|
||||
# as zero.
|
||||
_UPDATE_USER_RECORD_IN_PLACE = sqltext("""\
|
||||
update
|
||||
users
|
||||
set
|
||||
generation = COALESCE(:generation, generation),
|
||||
keys_changed_at = COALESCE(:keys_changed_at, keys_changed_at)
|
||||
where
|
||||
service = :service and email = :email and
|
||||
generation <= COALESCE(:generation, generation) and
|
||||
COALESCE(keys_changed_at, 0) <=
|
||||
COALESCE(:keys_changed_at, keys_changed_at, 0) and
|
||||
replaced_at is null
|
||||
""")
|
||||
|
||||
|
||||
_REPLACE_USER_RECORDS = sqltext("""\
|
||||
update
|
||||
users
|
||||
set
|
||||
replaced_at = :timestamp
|
||||
where
|
||||
service = :service and email = :email
|
||||
and replaced_at is null and created_at < :timestamp
|
||||
""")
|
||||
|
||||
|
||||
# Mark all records for the user as replaced,
|
||||
# and set a large generation number to block future logins.
|
||||
_RETIRE_USER_RECORDS = sqltext("""\
|
||||
update
|
||||
users
|
||||
set
|
||||
replaced_at = :timestamp,
|
||||
generation = :generation
|
||||
where
|
||||
email = :email
|
||||
and replaced_at is null
|
||||
""")
|
||||
|
||||
|
||||
_GET_OLD_USER_RECORDS_FOR_SERVICE = sqltext("""\
|
||||
select
|
||||
uid, email, generation, keys_changed_at, client_state,
|
||||
nodes.node, nodes.downed, created_at, replaced_at
|
||||
from
|
||||
users left outer join nodes on users.nodeid = nodes.id
|
||||
where
|
||||
users.service = :service
|
||||
and
|
||||
replaced_at is not null and replaced_at < :timestamp
|
||||
order by
|
||||
replaced_at desc, uid desc
|
||||
limit
|
||||
:limit
|
||||
offset
|
||||
:offset
|
||||
""")
|
||||
|
||||
|
||||
_GET_ALL_USER_RECORDS_FOR_SERVICE = sqltext("""\
|
||||
select
|
||||
uid, nodes.node, created_at, replaced_at
|
||||
from
|
||||
users left outer join nodes on users.nodeid = nodes.id
|
||||
where
|
||||
email = :email and users.service = :service
|
||||
order by
|
||||
created_at asc, uid desc
|
||||
""")
|
||||
|
||||
|
||||
_REPLACE_USER_RECORD = sqltext("""\
|
||||
update
|
||||
users
|
||||
set
|
||||
replaced_at = :timestamp
|
||||
where
|
||||
service = :service
|
||||
and
|
||||
uid = :uid
|
||||
""")
|
||||
|
||||
|
||||
_DELETE_USER_RECORD = sqltext("""\
|
||||
delete from
|
||||
users
|
||||
where
|
||||
service = :service
|
||||
and
|
||||
uid = :uid
|
||||
""")
|
||||
|
||||
|
||||
_FREE_SLOT_ON_NODE = sqltext("""\
|
||||
update
|
||||
nodes
|
||||
set
|
||||
available = available + 1, current_load = current_load - 1
|
||||
where
|
||||
id = (SELECT nodeid FROM users WHERE service=:service AND uid=:uid)
|
||||
""")
|
||||
|
||||
|
||||
_COUNT_USER_RECORDS = sqltext("""\
|
||||
select
|
||||
count(email)
|
||||
from
|
||||
users
|
||||
where
|
||||
replaced_at is null
|
||||
and created_at <= :timestamp
|
||||
""")
|
||||
|
||||
|
||||
_GET_BEST_NODE = sqltext("""\
|
||||
select
|
||||
id, node
|
||||
from
|
||||
nodes
|
||||
where
|
||||
service = :service
|
||||
and available > 0
|
||||
and capacity > current_load
|
||||
and downed = 0
|
||||
and backoff = 0
|
||||
order by
|
||||
log(current_load) / log(capacity)
|
||||
limit 1
|
||||
""")
|
||||
|
||||
|
||||
_RELEASE_NODE_CAPACITY = sqltext("""\
|
||||
update
|
||||
nodes
|
||||
set
|
||||
available = least(capacity * :capacity_release_rate,
|
||||
capacity - current_load)
|
||||
where
|
||||
service = :service
|
||||
and available <= 0
|
||||
and capacity > current_load
|
||||
and downed = 0
|
||||
""")
|
||||
|
||||
|
||||
_ADD_USER_TO_NODE = sqltext("""\
|
||||
update
|
||||
nodes
|
||||
set
|
||||
current_load = current_load + 1,
|
||||
available = greatest(available - 1, 0)
|
||||
where
|
||||
service = :service
|
||||
and node = :node
|
||||
""")
|
||||
|
||||
|
||||
_GET_SERVICE_ID = sqltext("""\
|
||||
select
|
||||
id
|
||||
from
|
||||
services
|
||||
where
|
||||
service = :service
|
||||
""")
|
||||
|
||||
|
||||
_GET_NODE = sqltext("""\
|
||||
select
|
||||
*
|
||||
from
|
||||
nodes
|
||||
where
|
||||
service = :service
|
||||
and node = :node
|
||||
""")
|
||||
|
||||
SERVICE_NAME = 'sync-1.5'
|
||||
|
||||
|
||||
class Database:
    """Access layer for the tokenserver database.

    Manages user records (creation, in-place update, node re-assignment,
    retirement, and garbage-collection helpers) and node records
    (add/update/remove and least-loaded selection) for the hard-coded
    SERVICE_NAME service, over a single autocommit SQLAlchemy connection.
    """

    def __init__(self):
        # The connection URL must come from the environment; a missing
        # SYNC_TOKENSERVER__DATABASE_URL raises KeyError here.
        engine = create_engine(os.environ['SYNC_TOKENSERVER__DATABASE_URL'])
        self.database = engine. \
            execution_options(isolation_level="AUTOCOMMIT"). \
            connect()
        self.service_id = self._get_service_id(SERVICE_NAME)
        # NOTE(review): when NODE_CAPACITY_RELEASE_RATE is set in the
        # environment its value is a *string*, unlike the float 0.1
        # default — confirm the arithmetic below tolerates that.
        self.capacity_release_rate = os.environ. \
            get("NODE_CAPACITY_RELEASE_RATE", 0.1)

    def _execute_sql(self, *args, **kwds):
        # Single funnel for all queries, so connection handling stays in
        # one place.
        return self.database.execute(*args, **kwds)

    def close(self):
        """Close the underlying database connection."""
        self.database.close()

    def get_user(self, email):
        """Return the current user record for `email` as a dict, or None.

        Also performs opportunistic repairs: re-allocates a node when the
        newest row is replaced or node-less (unless the user is retired),
        and marks stale duplicate rows as replaced.
        """
        params = {'service': self.service_id, 'email': email}
        res = self._execute_sql(_GET_USER_RECORDS, **params)
        try:
            # The query fetches rows ordered by created_at, but we want
            # to ensure that they're ordered by (generation, created_at).
            # This is almost always true, except for strange race conditions
            # during row creation. Sorting them is an easy way to enforce
            # this without bloating the db index.
            rows = res.fetchall()
            rows.sort(key=lambda r: (r.generation, r.created_at), reverse=True)
            if not rows:
                return None
            # The first row is the most up-to-date user record.
            # The rest give previously-seen client-state values.
            cur_row = rows[0]
            old_rows = rows[1:]
            user = {
                'email': email,
                'uid': cur_row.uid,
                'node': cur_row.node,
                'generation': cur_row.generation,
                # NULL keys_changed_at is treated as zero throughout.
                'keys_changed_at': cur_row.keys_changed_at or 0,
                'client_state': cur_row.client_state,
                'old_client_states': {},
                'first_seen_at': cur_row.created_at,
            }
            # If the current row is marked as replaced or is missing a node,
            # and they haven't been retired, then assign them a new node.
            if cur_row.replaced_at is not None or cur_row.node is None:
                if cur_row.generation < MAX_GENERATION:
                    user = self.allocate_user(email,
                                              cur_row.generation,
                                              cur_row.client_state,
                                              cur_row.keys_changed_at)
            for old_row in old_rows:
                # Collect any previously-seen client-state values.
                if old_row.client_state != user['client_state']:
                    user['old_client_states'][old_row.client_state] = True
                # Make sure each old row is marked as replaced.
                # They might not be, due to races in row creation.
                if old_row.replaced_at is None:
                    timestamp = cur_row.created_at
                    self.replace_user_record(old_row.uid, timestamp)
                # Track backwards to the oldest timestamp at which we saw them.
                user['first_seen_at'] = old_row.created_at
            return user
        finally:
            res.close()

    def allocate_user(self, email, generation=0, client_state='',
                      keys_changed_at=0, node=None, timestamp=None):
        """Create a fresh user record, assigned to a node.

        When `node` is None the best available node is chosen (updating
        its load counters); otherwise the named node is used.  Returns
        the new user record as a dict.
        """
        if timestamp is None:
            timestamp = get_timestamp()
        if node is None:
            nodeid, node = self.get_best_node()
        else:
            nodeid = self.get_node_id(node)
        params = {
            'service': self.service_id,
            'email': email,
            'nodeid': nodeid,
            'generation': generation,
            'keys_changed_at': keys_changed_at,
            'client_state': client_state,
            'timestamp': timestamp
        }
        res = self._execute_sql(_CREATE_USER_RECORD, **params)
        return {
            'email': email,
            'uid': res.lastrowid,
            'node': node,
            'generation': generation,
            'keys_changed_at': keys_changed_at,
            'client_state': client_state,
            'old_client_states': {},
            'first_seen_at': timestamp,
        }

    def update_user(self, user, generation=None, client_state=None,
                    keys_changed_at=None, node=None):
        """Update a user record, mutating the `user` dict in place.

        A change of `client_state` or `node` requires a brand-new row
        (with old rows marked replaced); otherwise the existing row is
        updated in place.  Raises Exception on a previously-seen
        client-state string.
        """
        if client_state is None and node is None:
            # No need for a node-reassignment, just update the row in place.
            # Note that if we're changing keys_changed_at without changing
            # client_state, it's because we're seeing an existing value of
            # keys_changed_at for the first time.
            params = {
                'service': self.service_id,
                'email': user['email'],
                'generation': generation,
                'keys_changed_at': keys_changed_at
            }
            res = self._execute_sql(_UPDATE_USER_RECORD_IN_PLACE, **params)
            res.close()

            # Keep the in-memory dict monotonic, mirroring the SQL's
            # "never move timestamps backwards" where-clause.
            user['generation'] = max([x
                                      for x
                                      in [generation, user['generation']]
                                      if x is not None])
            user['keys_changed_at'] = max([x
                                           for x
                                           in [keys_changed_at,
                                               user['keys_changed_at']]
                                           if x is not None])
        else:
            # Reject previously-seen client-state strings.
            if client_state is None:
                client_state = user['client_state']
            else:
                if client_state == user['client_state']:
                    raise Exception('previously seen client-state string')
                if client_state in user['old_client_states']:
                    raise Exception('previously seen client-state string')
            # Need to create a new record for new user state.
            # If the node is not explicitly changing, try to keep them on
            # the same node, but if e.g. it no longer exists then allocate
            # them to a new one.
            if node is not None:
                nodeid = self.get_node_id(node)
                user['node'] = node
            else:
                try:
                    nodeid = self.get_node_id(user['node'])
                except ValueError:
                    nodeid, node = self.get_best_node()
                    user['node'] = node
            if generation is not None:
                generation = max(user['generation'], generation)
            else:
                generation = user['generation']
            if keys_changed_at is not None:
                keys_changed_at = max(user['keys_changed_at'], keys_changed_at)
            else:
                keys_changed_at = user['keys_changed_at']
            now = get_timestamp()
            params = {
                'service': self.service_id, 'email': user['email'],
                'nodeid': nodeid, 'generation': generation,
                'keys_changed_at': keys_changed_at,
                'client_state': client_state, 'timestamp': now,
            }
            res = self._execute_sql(_CREATE_USER_RECORD, **params)
            res.close()
            # NOTE(review): lastrowid is read after close() — confirm the
            # driver keeps the value cached at this point.
            user['uid'] = res.lastrowid
            user['generation'] = generation
            user['keys_changed_at'] = keys_changed_at
            user['old_client_states'][user['client_state']] = True
            user['client_state'] = client_state
            # mark old records as having been replaced.
            # if we crash here, they are unmarked and we may fail to
            # garbage collect them for a while, but the active state
            # will be undamaged.
            self.replace_user_records(user['email'], now)

    def retire_user(self, email):
        """Mark all of a user's records as replaced and pin their
        generation to MAX_GENERATION, blocking future logins.
        """
        now = get_timestamp()
        params = {
            'email': email, 'timestamp': now, 'generation': MAX_GENERATION
        }
        # Pass through explicit engine to help with sharded implementation,
        # since we can't shard by service name here.
        res = self._execute_sql(_RETIRE_USER_RECORDS, **params)
        res.close()

    def count_users(self, timestamp=None):
        """Return the number of active users created at or before
        `timestamp` (milliseconds; defaults to now).
        """
        if timestamp is None:
            timestamp = get_timestamp()
        res = self._execute_sql(_COUNT_USER_RECORDS, timestamp=timestamp)
        row = res.fetchone()
        res.close()
        return row[0]

    #
    # Methods for low-level user record management.
    #

    def get_user_records(self, email):
        """Get all the user's records, including the old ones."""
        params = {'service': self.service_id, 'email': email}
        res = self._execute_sql(_GET_ALL_USER_RECORDS_FOR_SERVICE, **params)
        try:
            # Yield lazily; the finally clause guarantees the cursor is
            # closed even if the caller abandons the generator.
            for row in res:
                yield row
        finally:
            res.close()

    def get_old_user_records(self, grace_period=-1, limit=100,
                             offset=0):
        """Get user records that were replaced outside the grace period."""
        if grace_period < 0:
            grace_period = 60 * 60 * 24 * 7  # one week, in seconds
        grace_period = int(grace_period * 1000)  # convert seconds -> millis
        params = {
            "service": self.service_id,
            "timestamp": get_timestamp() - grace_period,
            "limit": limit,
            "offset": offset
        }
        res = self._execute_sql(_GET_OLD_USER_RECORDS_FOR_SERVICE, **params)
        try:
            for row in res:
                yield row
        finally:
            res.close()

    def replace_user_records(self, email, timestamp=None):
        """Mark all existing records for a user as replaced."""
        if timestamp is None:
            timestamp = get_timestamp()
        params = {
            'service': self.service_id, 'email': email, 'timestamp': timestamp
        }
        res = self._execute_sql(_REPLACE_USER_RECORDS, **params)
        res.close()

    def replace_user_record(self, uid, timestamp=None):
        """Mark an existing service record as replaced."""
        if timestamp is None:
            timestamp = get_timestamp()
        params = {
            'service': self.service_id, 'uid': uid, 'timestamp': timestamp
        }
        res = self._execute_sql(_REPLACE_USER_RECORD, **params)
        res.close()

    def delete_user_record(self, uid):
        """Delete the user record with the given uid."""
        params = {'service': self.service_id, 'uid': uid}
        # Give the slot back to the node before the row disappears.
        res = self._execute_sql(_FREE_SLOT_ON_NODE, **params)
        res.close()
        res = self._execute_sql(_DELETE_USER_RECORD, **params)
        res.close()

    #
    # Nodes management
    #

    def _get_service_id(self, service):
        # Resolve a service name to its numeric id; raises if unknown.
        res = self._execute_sql(_GET_SERVICE_ID, service=service)
        row = res.fetchone()
        res.close()
        if row is None:
            raise Exception('unknown service: ' + service)
        return row.id

    def add_service(self, service_name, pattern, **kwds):
        """Add definition for a new service."""
        res = self._execute_sql(sqltext("""
            insert into services (service, pattern)
            values (:servicename, :pattern)
        """), servicename=service_name, pattern=pattern, **kwds)
        res.close()
        # NOTE(review): lastrowid is read after close() — confirm the
        # driver keeps the value cached at this point.
        return res.lastrowid

    def add_node(self, node, capacity, **kwds):
        """Add definition for a new node."""
        available = kwds.get('available')
        # We release only a fraction of the node's capacity to start.
        if available is None:
            available = math.ceil(capacity * self.capacity_release_rate)
        cols = ["service", "node", "available", "capacity",
                "current_load", "downed", "backoff"]
        args = [":" + v for v in cols]
        # Handle test cases that require nodeid to be 800
        if "nodeid" in kwds:
            cols.append("id")
            args.append(":nodeid")
        query = """
            insert into nodes ({cols})
            values ({args})
        """.format(cols=", ".join(cols), args=", ".join(args))
        res = self._execute_sql(
            sqltext(query),
            nodeid=kwds.get('nodeid'),
            service=self.service_id,
            node=node,
            capacity=capacity,
            available=available,
            current_load=kwds.get('current_load', 0),
            downed=kwds.get('downed', 0),
            backoff=kwds.get('backoff', 0),
        )
        res.close()

    def update_node(self, node, **kwds):
        """Updates node fields in the db.

        Only the columns named in NODE_FIELDS may be updated; any other
        keyword raises ValueError.
        """
        values = {}
        # Set intersection of the updatable columns and what was passed.
        cols = NODE_FIELDS & kwds.keys()
        for col in NODE_FIELDS:
            try:
                values[col] = kwds.pop(col)
            except KeyError:
                pass

        args = [v + " = :" + v for v in cols]
        query = """
            update nodes
            set """
        query += ", ".join(args)
        query += """
            where service = :service and node = :node
        """
        values['service'] = self.service_id
        values['node'] = node
        # Anything left in kwds was not a recognised column.
        if kwds:
            raise ValueError("unknown fields: " + str(kwds.keys()))
        con = self._execute_sql(sqltext(query), **values)
        con.close()

    def get_node_id(self, node):
        """Get numeric id for a node."""
        res = self._execute_sql(
            sqltext("""
                select id from nodes
                where service=:service and node=:node
            """),
            service=self.service_id, node=node
        )
        row = res.fetchone()
        res.close()
        if row is None:
            raise ValueError("unknown node: " + node)
        return row[0]

    def remove_node(self, node, timestamp=None):
        """Remove definition for a node.

        Also marks every user assigned to the node as replaced, so they
        get re-allocated on next access.
        """
        nodeid = self.get_node_id(node)
        res = self._execute_sql(sqltext(
            """
            delete from nodes where id=:nodeid
            """),
            nodeid=nodeid
        )
        res.close()
        self.unassign_node(node, timestamp, nodeid=nodeid)

    def unassign_node(self, node, timestamp=None, nodeid=None):
        """Clear any assignments to a node."""
        if timestamp is None:
            timestamp = get_timestamp()
        if nodeid is None:
            nodeid = self.get_node_id(node)
        res = self._execute_sql(
            sqltext("""
                update users
                set replaced_at=:timestamp
                where nodeid=:nodeid
            """),
            nodeid=nodeid, timestamp=timestamp
        )
        res.close()

    def get_best_node(self):
        """Returns the 'least loaded' node currently available, increments the
        active count on that node, and decrements the slots currently available
        """
        # We may have to re-try the query if we need to release more capacity.
        # This loop allows a maximum of five retries before bailing out.
        for _ in range(5):
            res = self._execute_sql(_GET_BEST_NODE, service=self.service_id)
            row = res.fetchone()
            res.close()
            if row is None:
                # Try to release additional capacity from any nodes
                # that are not fully occupied.
                res = self._execute_sql(
                    _RELEASE_NODE_CAPACITY,
                    capacity_release_rate=self.capacity_release_rate,
                    service=self.service_id
                )
                res.close()
                # NOTE(review): rowcount is read after close() — confirm
                # the driver keeps the value cached at this point.
                if res.rowcount == 0:
                    break
            else:
                break

        # Did we succeed in finding a node?
        if row is None:
            raise Exception('unable to get a node')

        nodeid = row.id
        node = str(row.node)

        # Update the node to reflect the new assignment.
        # This is a little racy with concurrent assignments, but no big
        # deal.
        con = self._execute_sql(_ADD_USER_TO_NODE,
                                service=self.service_id,
                                node=node)
        con.close()

        return nodeid, node

    def get_node(self, node):
        """Return the full row for a node; raises if the node is unknown."""
        res = self._execute_sql(_GET_NODE, service=self.service_id, node=node)
        row = res.fetchone()
        res.close()
        if row is None:
            raise Exception('unknown node: ' + node)
        return row
|
||||
179
tools/tokenserver/process_account_events.py
Normal file
179
tools/tokenserver/process_account_events.py
Normal file
@ -0,0 +1,179 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to process account-related events from an SQS queue.
|
||||
|
||||
This script polls an SQS queue for events indicating activity on an upstream
|
||||
account, as documented here:
|
||||
|
||||
https://github.com/mozilla/fxa-auth-server/blob/master/docs/service_notifications.md
|
||||
|
||||
The following event types are currently supported:
|
||||
|
||||
* "delete": the account was deleted; we mark their records as retired
|
||||
so they'll be cleaned up by our garbage-collection process.
|
||||
|
||||
* "reset": the account password was reset; we update our copy of their
|
||||
generation number to disconnect other devices.
|
||||
|
||||
* "passwordChange": the account password was changed; we update our copy
|
||||
of their generation number to disconnect other devices.
|
||||
|
||||
Note that this is a purely optional administrative task, highly specific to
|
||||
Mozilla's internal Firefox-Accounts-supported deployment.
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import optparse
|
||||
|
||||
import boto
|
||||
import boto.ec2
|
||||
import boto.sqs
|
||||
import boto.sqs.message
|
||||
import boto.utils
|
||||
|
||||
import util
|
||||
from database import Database
|
||||
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.process_account_deletions")
|
||||
|
||||
|
||||
def process_account_events(queue_name, aws_region=None, queue_wait_time=20):
    """Poll an SQS queue for account-related events, processing each one.

    This polls the specified SQS queue indefinitely and does not return;
    to interrupt execution you'll need to e.g. SIGINT the process.
    """
    logger.info("Processing account events from %s", queue_name)
    try:
        # Connect to the SQS queue.
        # If the caller gave no region, infer it from the instance metadata.
        if aws_region is None:
            logger.debug("Finding default region from instance metadata")
            metadata = boto.utils.get_instance_metadata()
            aws_region = metadata["placement"]["availability-zone"][:-1]
        logger.debug("Connecting to queue %r in %r", queue_name, aws_region)
        connection = boto.sqs.connect_to_region(aws_region)
        queue = connection.get_queue(queue_name)
        # We must force boto not to b64-decode the message contents, ugh.
        queue.set_message_class(boto.sqs.message.RawMessage)
        # Long-poll for messages indefinitely.
        while True:
            message = queue.read(wait_time_seconds=queue_wait_time)
            if message is not None:
                process_account_event(message.get_body())
                # This intentionally deletes the event even if it was some
                # unrecognized type; no point leaving a backlog.
                queue.delete_message(message)
    except Exception:
        logger.exception("Error while processing account events")
        raise
|
||||
|
||||
|
||||
def process_account_event(body):
    """Parse and process a single account event."""
    database = Database()
    # Try very hard not to error out if there's junk in the queue.
    email = None
    event_type = None
    generation = None
    try:
        payload = json.loads(body)
        event = json.loads(payload['Message'])
        event_type = event["event"]
        uid = event["uid"]
        # Older versions of the fxa-auth-server would send an email-like
        # identifier in the "uid" field, but that doesn't make sense for any
        # relier other than tokenserver.  Newer versions send just the raw
        # uid in the "uid" field, and include the domain in a separate
        # "iss" field.
        if "iss" in event:
            email = "%s@%s" % (uid, event["iss"])
        else:
            if "@" not in uid:
                raise ValueError("uid field does not contain issuer info")
            email = uid
        if event_type in ("reset", "passwordChange",):
            generation = event["generation"]
    except (ValueError, KeyError) as e:
        logger.exception("Invalid account message: %s", e)
    else:
        if email is None:
            return
        if event_type == "delete":
            # Mark the user as retired.
            # Actual cleanup is done by a separate process.
            logger.info("Processing account delete for %r", email)
            database.retire_user(email)
        elif event_type == "reset":
            logger.info("Processing account reset for %r", email)
            update_generation_number(database, email, generation)
        elif event_type == "passwordChange":
            logger.info("Processing password change for %r", email)
            update_generation_number(database, email, generation)
        else:
            logger.warning("Dropping unknown event type %r",
                           event_type)
|
||||
|
||||
|
||||
def update_generation_number(database, email, generation):
    """Update the maximum recorded generation number for the given user.

    When the FxA server sends us an update to the user's generation
    number, we want to update our high-water-mark in the DB in order to
    immediately lock out disconnected devices.  However, since we don't
    know the new value of the client state that goes with it, we can't
    just record the new generation number in the DB.  If we did, the first
    device that tried to sync with the new generation number would appear
    to have an incorrect client state value, and would be rejected.

    Instead, we take advantage of the fact that it's a timestamp, and
    write it into the DB at one millisecond less than its current value.
    This ensures that we lock out any devices with an older generation
    number while avoiding errors with client state handling.

    This does leave a tiny edge-case where we can fail to lock out older
    devices, if the generation number changes twice in less than a
    millisecond.  This is acceptably unlikely in practice, and we'll
    recover as soon as we see an updated generation number during a sync.
    """
    user = database.get_user(email)
    if user is None:
        # Unknown user; nothing to update.
        return
    database.update_user(user, generation - 1)
|
||||
|
||||
|
||||
def main(args=None):
    """Main entry-point for running this script.

    Parses command-line arguments and passes them on to the
    process_account_events() function.
    """
    usage = "usage: %prog [options] queue_name"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("", "--aws-region",
                      help="aws region in which the queue can be found")
    parser.add_option("", "--queue-wait-time", type="int", default=20,
                      help="Number of seconds to wait for jobs on the queue")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    # Exactly one positional argument: the name of the SQS queue.
    if len(args) != 1:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    process_account_events(args[0], opts.aws_region, opts.queue_wait_time)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
util.run_script(main)
|
||||
177
tools/tokenserver/purge_old_records.py
Normal file
177
tools/tokenserver/purge_old_records.py
Normal file
@ -0,0 +1,177 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to purge user records that have been replaced.
|
||||
|
||||
This script purges any obsolete user records from the database.
|
||||
Obsolete records are those that have been replaced by a newer record for
|
||||
the same user.
|
||||
|
||||
Note that this is a purely optional administrative task, since replaced records
|
||||
are handled internally by the assignment backend. But it should help reduce
|
||||
overheads, improve performance etc if run regularly.
|
||||
|
||||
"""
|
||||
|
||||
import binascii
|
||||
import hawkauthlib
|
||||
import logging
|
||||
import optparse
|
||||
import random
|
||||
import requests
|
||||
import time
|
||||
import tokenlib
|
||||
|
||||
import util
|
||||
from database import Database
|
||||
from util import format_key_id
|
||||
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.purge_old_records")
|
||||
|
||||
PATTERN = "{node}/1.5/{uid}"
|
||||
|
||||
|
||||
def purge_old_records(secret, grace_period=-1, max_per_loop=10, max_offset=0,
                      request_timeout=60):
    """Purge old records from the database.

    This queries all of the old user records in the database, deletes the
    Tokenserver database record for each of the users, and issues a delete
    request to each user's storage node.  The result is a gradual pruning
    of expired items from each database.

    `max_offset` is used to select a random offset into the list of
    purgeable records.  With multiple tasks running concurrently, this
    will provide each a (likely) different set of records to work on.
    A cheap, imperfect randomization.

    Returns True on success, False if any error occurred.
    """
    logger.info("Purging old user records")
    try:
        database = Database()
        # Process batches of <max_per_loop> items, until we run out.
        while True:
            offset = random.randint(0, max_offset)
            rows = list(database.get_old_user_records(
                grace_period=grace_period,
                limit=max_per_loop,
                offset=offset,
            ))
            logger.info("Fetched %d rows at offset %d", len(rows), offset)
            for row in rows:
                if row.node is None:
                    # The node was removed from the db entirely, so there
                    # is no remote service data left to purge.
                    logger.info("Deleting user record for uid %s on %s",
                                row.uid, row.node)
                    database.delete_user_record(row.uid)
                elif not row.downed:
                    logger.info("Purging uid %s on %s", row.uid, row.node)
                    delete_service_data(row, secret, timeout=request_timeout)
                    database.delete_user_record(row.uid)
                # Don't attempt to purge data from downed nodes.
                # Instead wait for them to either come back up or to be
                # completely removed from service.
            if len(rows) < max_per_loop:
                break
    except Exception:
        logger.exception("Error while purging old user records")
        return False
    else:
        logger.info("Finished purging old user records")
        return True
|
||||
|
||||
|
||||
def delete_service_data(user, secret, timeout=60):
    """Send a data-deletion request to the user's service node.

    This is a little bit of hackery to cause the user's service node to
    remove any data it still has stored for the user.  We simulate a
    DELETE request from the user's own account.
    """
    fxa_uid = user.email.split("@", 1)[0]
    fxa_kid = format_key_id(
        user.keys_changed_at or user.generation,
        binascii.unhexlify(user.client_state),
    )
    token = tokenlib.make_token({
        "uid": user.uid,
        "node": user.node,
        "fxa_uid": fxa_uid,
        "fxa_kid": fxa_kid,
    }, secret=secret)
    derived = tokenlib.get_derived_secret(token, secret=secret)
    endpoint = PATTERN.format(uid=user.uid, node=user.node)
    resp = requests.delete(endpoint, auth=HawkAuth(token, derived),
                           timeout=timeout)
    # A 404 just means the node holds no data for this user; any other
    # 4xx/5xx status is a real failure.
    if resp.status_code >= 400 and resp.status_code != 404:
        resp.raise_for_status()
|
||||
|
||||
|
||||
class HawkAuth(requests.auth.AuthBase):
    """Requests auth helper that Hawk-signs each outgoing request."""

    def __init__(self, token, secret):
        self.token = token
        self.secret = secret

    def __call__(self, req):
        # Sign the request in place and hand it back to requests.
        hawkauthlib.sign_request(req, self.token, self.secret)
        return req
|
||||
|
||||
|
||||
def main(args=None):
    """Main entry-point for running this script.

    This function parses command-line arguments and passes them on to the
    purge_old_records() function, optionally looping forever with a
    randomized sleep between runs.
    """
    usage = "usage: %prog [options] secret"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("", "--purge-interval", type="int", default=3600,
                      help="Interval to sleep between purging runs")
    parser.add_option("", "--grace-period", type="int", default=86400,
                      help="Number of seconds grace to allow on replacement")
    parser.add_option("", "--max-per-loop", type="int", default=10,
                      help="Maximum number of items to fetch in one go")
    # N.B., if the number of purgeable rows is <<< max_offset then most
    # selects will return zero rows. Choose this value accordingly.
    parser.add_option("", "--max-offset", type="int", default=0,
                      help="Use random offset from 0 to max_offset")
    parser.add_option("", "--request-timeout", type="int", default=60,
                      help="Timeout for service deletion requests")
    parser.add_option("", "--oneshot", action="store_true",
                      help="Do a single purge run and then exit")
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    # BUG FIX: the usage string documents a single positional argument
    # (the token secret), but the original code required exactly two and
    # read args[1], silently ignoring the first.  Accept either form for
    # backward compatibility and take the last positional as the secret.
    if len(args) not in (1, 2):
        parser.print_usage()
        return 1
    secret = args[-1]

    util.configure_script_logging(opts)

    purge_old_records(secret,
                      grace_period=opts.grace_period,
                      max_per_loop=opts.max_per_loop,
                      max_offset=opts.max_offset,
                      request_timeout=opts.request_timeout)
    if not opts.oneshot:
        while True:
            # Randomize sleep interval +/- thirty percent to desynchronize
            # instances of this script running on multiple webheads.
            # BUG FIX: random.randint requires integer bounds; the original
            # passed floats (-0.3 * sleep_time), which is an error on
            # modern Python, so truncate the jitter to an int first.
            jitter = int(0.3 * opts.purge_interval)
            sleep_time = opts.purge_interval + random.randint(-jitter, jitter)
            logger.debug("Sleeping for %d seconds", sleep_time)
            time.sleep(sleep_time)
            # BUG FIX: the original repeat call omitted the mandatory
            # `secret` positional argument, raising TypeError on every
            # run after the first.
            purge_old_records(secret,
                              grace_period=opts.grace_period,
                              max_per_loop=opts.max_per_loop,
                              max_offset=opts.max_offset,
                              request_timeout=opts.request_timeout)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
util.run_script(main)
|
||||
73
tools/tokenserver/remove_node.py
Normal file
73
tools/tokenserver/remove_node.py
Normal file
@ -0,0 +1,73 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to remove a node from the system.
|
||||
|
||||
This script nukes any references to the named node - it is removed from
|
||||
the "nodes" table and any users currently assigned to that node have their
|
||||
assignments cleared.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import optparse
|
||||
|
||||
import util
|
||||
from database import Database
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.remove_node")
|
||||
|
||||
|
||||
def remove_node(node):
    """Remove the named node from the system.

    Returns True on success (whether or not the node existed), False if
    an unexpected error occurred.
    """
    logger.info("Removing node %s", node)
    try:
        database = Database()
        found = False
        try:
            database.remove_node(node)
        except ValueError:
            # The database signals an unknown node name with ValueError.
            logger.debug(" not found")
        else:
            found = True
            logger.debug(" removed")
    except Exception:
        logger.exception("Error while removing node")
        return False
    else:
        if found:
            logger.info("Finished removing node %s", node)
        else:
            logger.info("Node %s was not found", node)
        return True
|
||||
|
||||
|
||||
def main(args=None):
    """Main entry-point for running this script.

    Parses command-line arguments and passes them on to the
    remove_node() function.
    """
    usage = "usage: %prog [options] node_name"
    descr = "Remove a node from the tokenserver database"
    parser = optparse.OptionParser(usage=usage, description=descr)
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, args = parser.parse_args(args)
    # Exactly one positional argument: the name of the node to remove.
    if len(args) != 1:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    remove_node(args[0])
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
util.run_script(main)
|
||||
6
tools/tokenserver/requirements.txt
Normal file
6
tools/tokenserver/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
boto
|
||||
hawkauthlib
|
||||
mysqlclient
|
||||
pyramid
|
||||
sqlalchemy
|
||||
testfixtures
|
||||
25
tools/tokenserver/run_tests.py
Normal file
25
tools/tokenserver/run_tests.py
Normal file
@ -0,0 +1,25 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from test_database import TestDatabase
|
||||
from test_process_account_events import TestProcessAccountEvents
|
||||
from test_purge_old_records import TestPurgeOldRecords
|
||||
from test_scripts import TestScripts
|
||||
|
||||
if __name__ == "__main__":
    loader = unittest.TestLoader()
    cases = [TestDatabase, TestPurgeOldRecords, TestProcessAccountEvents,
             TestScripts]

    # Run every suite even if an earlier one fails, so the full set of
    # results is reported; exit non-zero if any suite failed.
    exit_code = 0
    for case in cases:
        suite = loader.loadTestsFromTestCase(case)
        outcome = unittest.TextTestRunner().run(suite)
        if not outcome.wasSuccessful():
            exit_code = 1

    sys.exit(exit_code)
|
||||
446
tools/tokenserver/test_database.py
Normal file
446
tools/tokenserver/test_database.py
Normal file
@ -0,0 +1,446 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from collections import defaultdict
|
||||
from database import MAX_GENERATION, Database
|
||||
from util import get_timestamp
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
    """Create a fresh Database handle and reset the relevant tables."""
    super(TestDatabase, self).setUp()
    self.database = Database()
    # Start each test with a blank slate.
    for statement in ('DELETE FROM users', 'DELETE FROM nodes'):
        cursor = self.database._execute_sql(statement, ())
        cursor.close()

    self.database.add_node('https://phx12', 100)
|
||||
|
||||
def tearDown(self):
    """Clean out the test data and release the database handle."""
    super(TestDatabase, self).tearDown()
    # And clean up at the end, for good measure.
    for statement in ('DELETE FROM users', 'DELETE FROM nodes'):
        cursor = self.database._execute_sql(statement, ())
        cursor.close()

    self.database.close()
|
||||
|
||||
def test_node_allocation(self):
    """A new user gets assigned to the only available node."""
    # The user should not exist until explicitly allocated.
    user = self.database.get_user('test1@example.com')
    # FIX: assertEquals is a deprecated alias of assertEqual (removed in
    # Python 3.12); use the canonical name.
    self.assertEqual(user, None)

    user = self.database.allocate_user('test1@example.com')
    wanted = 'https://phx12'
    self.assertEqual(user['node'], wanted)

    user = self.database.get_user('test1@example.com')
    self.assertEqual(user['node'], wanted)
|
||||
|
||||
def test_allocation_to_least_loaded_node(self):
    """With two equal-capacity nodes, consecutive users are spread out."""
    self.database.add_node('https://phx13', 100)
    first = self.database.allocate_user('test1@mozilla.com')
    second = self.database.allocate_user('test2@mozilla.com')
    self.assertNotEqual(first['node'], second['node'])
|
||||
|
||||
def test_allocation_is_not_allowed_to_downed_nodes(self):
    """Users must never be assigned to a node that is marked down."""
    self.database.update_node('https://phx12', downed=True)
    # With the only node down, allocation has nowhere to go.
    with self.assertRaises(Exception):
        self.database.allocate_user('test1@mozilla.com')
|
||||
|
||||
def test_allocation_is_not_allowed_to_backoff_nodes(self):
    """Users must never be assigned to a node in backoff state."""
    self.database.update_node('https://phx12', backoff=True)
    # With the only node backing off, allocation has nowhere to go.
    with self.assertRaises(Exception):
        self.database.allocate_user('test1@mozilla.com')
|
||||
|
||||
def test_update_generation_number(self):
    """Generation numbers can move forwards but never backwards."""
    user = self.database.allocate_user('test1@example.com')
    self.assertEqual(user['generation'], 0)
    self.assertEqual(user['client_state'], '')
    orig_uid = user['uid']
    orig_node = user['node']

    def check(u, generation):
        # uid, node and client_state must be untouched by generation
        # updates; only the generation itself may move.
        self.assertEqual(u['uid'], orig_uid)
        self.assertEqual(u['node'], orig_node)
        self.assertEqual(u['generation'], generation)
        self.assertEqual(u['client_state'], '')

    # Changing generation should leave other properties unchanged.
    self.database.update_user(user, generation=42)
    check(user, 42)
    check(self.database.get_user('test1@example.com'), 42)

    # It's not possible to move the generation number backwards.
    self.database.update_user(user, generation=17)
    check(user, 42)
    check(self.database.get_user('test1@example.com'), 42)
|
||||
|
||||
def test_update_client_state(self):
    """Changing client-state allocates a new uid, remembering old states."""
    email = 'test1@example.com'
    user = self.database.allocate_user(email)
    self.assertEqual(user['generation'], 0)
    self.assertEqual(user['client_state'], '')
    self.assertEqual(set(user['old_client_states']), set())
    seen_uids = set((user['uid'],))
    orig_node = user['node']

    def check(u, generation, state, old_states):
        # Each client-state change must mint a previously-unseen uid
        # while keeping the node assignment stable.
        self.assertTrue(u['uid'] not in seen_uids)
        self.assertEqual(u['node'], orig_node)
        self.assertEqual(u['generation'], generation)
        self.assertEqual(u['client_state'], state)
        self.assertEqual(set(u['old_client_states']), old_states)

    # Changing client-state allocates a new userid.
    self.database.update_user(user, client_state='aaa')
    check(user, 0, 'aaa', set(('',)))
    check(self.database.get_user(email), 0, 'aaa', set(('',)))

    seen_uids.add(user['uid'])

    # It's possible to change client-state and generation at once.
    self.database.update_user(user, client_state='bbb', generation=12)
    check(user, 12, 'bbb', set(('', 'aaa')))
    check(self.database.get_user(email), 12, 'bbb', set(('', 'aaa')))

    # You can't go back to an old client_state.
    orig_uid = user['uid']
    with self.assertRaises(Exception):
        self.database.update_user(user, client_state='aaa')

    user = self.database.get_user(email)
    self.assertEqual(user['uid'], orig_uid)
    self.assertEqual(user['node'], orig_node)
    self.assertEqual(user['generation'], 12)
    self.assertEqual(user['client_state'], 'bbb')
    self.assertEqual(set(user['old_client_states']), set(('', 'aaa')))
|
||||
|
||||
def test_user_retirement(self):
    """Retiring a user bumps their generation number."""
    self.database.allocate_user('test@mozilla.com')
    before = self.database.get_user('test@mozilla.com')
    self.database.retire_user('test@mozilla.com')
    after = self.database.get_user('test@mozilla.com')
    self.assertTrue(after['generation'] > before['generation'])
|
||||
|
||||
def test_cleanup_of_old_records(self):
    """get_old_user_records honours grace period, limit and deletion."""
    # Create 6 user records for the first user.
    # Sleep halfway through so we can exercise the grace-period logic.
    email1 = 'test1@mozilla.com'
    user1 = self.database.allocate_user(email1)
    for state in ('a', 'b', 'c'):
        self.database.update_user(user1, client_state=state)
    break_time = time.time()
    time.sleep(0.1)
    for state in ('d', 'e'):
        self.database.update_user(user1, client_state=state)
    self.assertEqual(len(list(self.database.get_user_records(email1))), 6)
    # Create 3 user records for the second user.
    email2 = 'test2@mozilla.com'
    user2 = self.database.allocate_user(email2)
    for state in ('a', 'b'):
        self.database.update_user(user2, client_state=state)
    self.assertEqual(len(list(self.database.get_user_records(email2))), 3)
    # That should be a total of 7 old (replaced) records.
    old_records = list(self.database.get_old_user_records(0))
    self.assertEqual(len(old_records), 7)
    # And with max_offset of 3, the first record should be id 4.
    # NOTE(review): the original test fetched this result but never
    # asserted anything about it; behavior preserved here.
    old_records = list(self.database.get_old_user_records(0, 100, 3))
    # The 'limit' parameter should be respected.
    old_records = list(self.database.get_old_user_records(0, 2))
    self.assertEqual(len(old_records), 2)
    # The default grace period is too big to pick them up.
    old_records = list(self.database.get_old_user_records())
    self.assertEqual(len(old_records), 0)
    # The grace period can select a subset of the records.
    grace = time.time() - break_time
    old_records = list(self.database.get_old_user_records(grace))
    self.assertEqual(len(old_records), 3)
    # Old records can be successfully deleted.
    for record in old_records:
        self.database.delete_user_record(record.uid)
    old_records = list(self.database.get_old_user_records(0))
    self.assertEqual(len(old_records), 4)
|
||||
|
||||
def test_node_reassignment_when_records_are_replaced(self):
    """Replacing records preserves account metadata under a new uid."""
    self.database.allocate_user('test@mozilla.com',
                                generation=42,
                                keys_changed_at=12,
                                client_state='aaa')
    before = self.database.get_user('test@mozilla.com')
    self.database.replace_user_records('test@mozilla.com')
    after = self.database.get_user('test@mozilla.com')
    # They should have got a new uid ...
    self.assertNotEqual(after['uid'], before['uid'])
    # ... but their account metadata should have been preserved.
    for field in ('generation', 'keys_changed_at', 'client_state'):
        self.assertEqual(after[field], before[field])
|
||||
|
||||
def test_node_reassignment_not_done_for_retired_users(self):
    """Retired users keep their uid and client_state, with max generation."""
    self.database.allocate_user('test@mozilla.com',
                                generation=42, client_state='aaa')
    user1 = self.database.get_user('test@mozilla.com')
    self.database.retire_user('test@mozilla.com')
    user2 = self.database.get_user('test@mozilla.com')
    self.assertEqual(user2['uid'], user1['uid'])
    self.assertEqual(user2['generation'], MAX_GENERATION)
    # FIX: the original asserted user2['client_state'] against itself,
    # which is a tautology; compare against the pre-retirement record.
    self.assertEqual(user2['client_state'], user1['client_state'])
|
||||
|
||||
def test_recovery_from_racy_record_creation(self):
    """Duplicate rows from a creation race are repaired on read."""
    timestamp = get_timestamp()
    # Simulate a race by forcing creation of two rows with same timestamp.
    first = self.database.allocate_user('test@mozilla.com',
                                        timestamp=timestamp)
    second = self.database.allocate_user('test@mozilla.com',
                                         timestamp=timestamp)
    self.assertNotEqual(first['uid'], second['uid'])
    # Neither is marked replaced initially.
    self.assertEqual(len(list(self.database.get_old_user_records(0))), 0)
    # Reading current details will detect the problem and fix it.
    self.database.get_user('test@mozilla.com')
    self.assertEqual(len(list(self.database.get_old_user_records(0))), 1)
|
||||
|
||||
def test_that_race_recovery_respects_generation_number_monotonicity(self):
    """Race repair keeps the record with the higher generation number."""
    timestamp = get_timestamp()
    # Simulate a race between clients with different generation numbers,
    # in which the out-of-date client gets a higher timestamp.
    stale = self.database.allocate_user('test@mozilla.com',
                                        generation=1,
                                        timestamp=timestamp)
    fresh = self.database.allocate_user('test@mozilla.com',
                                        generation=2,
                                        timestamp=timestamp - 1)
    self.assertNotEqual(stale['uid'], fresh['uid'])
    # Reading current details should promote the higher-generation one.
    user = self.database.get_user('test@mozilla.com')
    self.assertEqual(user['generation'], 2)
    self.assertEqual(user['uid'], fresh['uid'])
    # And the other record should get marked as replaced.
    self.assertEqual(len(list(self.database.get_old_user_records(0))), 1)
|
||||
|
||||
def test_node_reassignment_and_removal(self):
    """Unassigning and removing nodes rebalances users correctly."""
    NODE1 = 'https://phx12'
    NODE2 = 'https://phx13'
    # Note that NODE1 is created by default for all tests.
    self.database.add_node(NODE2, 100)
    # Assign four users; we should get two on each node.
    users = [self.database.allocate_user('test%d@mozilla.com' % i)
             for i in (1, 2, 3, 4)]
    counts = defaultdict(int)
    for user in users:
        counts[user['node']] += 1
    self.assertEqual(counts[NODE1], 2)
    self.assertEqual(counts[NODE2], 2)
    # Clear the assignments for NODE1, and re-assign.
    # The users previously on NODE1 should balance across both nodes,
    # giving 1 on NODE1 and 3 on NODE2.
    self.database.unassign_node(NODE1)
    counts = defaultdict(int)
    for user in users:
        reassigned = self.database.get_user(user['email'])
        if user['node'] == NODE2:
            # Users already on NODE2 must not have been moved.
            self.assertEqual(reassigned['node'], NODE2)
        counts[reassigned['node']] += 1
    self.assertEqual(counts[NODE1], 1)
    self.assertEqual(counts[NODE2], 3)
    # Remove NODE2.  Everyone should wind up on NODE1.
    self.database.remove_node(NODE2)
    for user in users:
        reassigned = self.database.get_user(user['email'])
        self.assertEqual(reassigned['node'], NODE1)
    # The old user records pointing to NODE2 should have a NULL 'node'
    # property since it has been removed from the db.
    null_node_count = 0
    for row in self.database.get_old_user_records(0):
        if row.node is None:
            null_node_count += 1
        else:
            self.assertEqual(row.node, NODE1)
    self.assertEqual(null_node_count, 3)
|
||||
|
||||
def test_that_race_recovery_respects_generation_after_reassignment(self):
    """After full replacement, a new record carries the newest generation."""
    timestamp = get_timestamp()
    # Simulate a race between clients with different generation numbers,
    # in which the out-of-date client gets a higher timestamp.
    stale = self.database.allocate_user('test@mozilla.com',
                                        generation=1,
                                        timestamp=timestamp)
    fresh = self.database.allocate_user('test@mozilla.com',
                                        generation=2,
                                        timestamp=timestamp - 1)
    self.assertNotEqual(stale['uid'], fresh['uid'])
    # Force node re-assignment by marking all records as replaced.
    self.database.replace_user_records('test@mozilla.com',
                                       timestamp=timestamp + 1)
    # The next client to show up should get a new assignment, marked
    # with the correct generation number.
    user = self.database.get_user('test@mozilla.com')
    self.assertEqual(user['generation'], 2)
    self.assertNotEqual(user['uid'], stale['uid'])
    self.assertNotEqual(user['uid'], fresh['uid'])
|
||||
|
||||
def test_that_we_can_allocate_users_to_a_specific_node(self):
    """The node keyword argument overrides capacity-based selection."""
    target = 'https://phx13'
    self.database.add_node(target, 50)
    # The new node is not selected by default, because of lower capacity.
    user = self.database.allocate_user('test1@mozilla.com')
    self.assertNotEqual(user['node'], target)
    # But we can force it using the keyword argument.
    user = self.database.allocate_user('test2@mozilla.com', node=target)
    self.assertEqual(user['node'], target)
|
||||
|
||||
def test_that_we_can_move_users_to_a_specific_node(self):
    """update_user(node=...) moves a user onto the named node.

    Also verifies that a combined update of generation, client_state
    and node applies all three, and that the in-memory dict and the
    re-read db row agree afterwards.
    """
    node = 'https://phx13'
    self.database.add_node(node, 50)
    # The new node is not selected by default, because of lower capacity.
    user = self.database.allocate_user('test@mozilla.com')
    self.assertNotEqual(user['node'], node)
    # But we can move them there explicitly using keyword argument.
    self.database.update_user(user, node=node)
    self.assertEqual(user['node'], node)
    # Sanity-check by re-reading it from the db.
    user = self.database.get_user('test@mozilla.com')
    self.assertEqual(user['node'], node)
    # Check that it properly respects client-state and generation.
    self.database.update_user(user, generation=12)
    self.database.update_user(user, client_state='XXX')
    self.database.update_user(user, generation=42,
                              client_state='YYY', node='https://phx12')
    self.assertEqual(user['node'], 'https://phx12')
    self.assertEqual(user['generation'], 42)
    self.assertEqual(user['client_state'], 'YYY')
    # Prior client-states ('' from allocation, then 'XXX') are retained.
    self.assertEqual(sorted(user['old_client_states']), ['', 'XXX'])
    # Sanity-check by re-reading it from the db.
    user = self.database.get_user('test@mozilla.com')
    self.assertEqual(user['node'], 'https://phx12')
    self.assertEqual(user['generation'], 42)
    self.assertEqual(user['client_state'], 'YYY')
    self.assertEqual(sorted(user['old_client_states']), ['', 'XXX'])
||||
def test_that_record_cleanup_frees_slots_on_the_node(self):
    """Deleting a retired user's record frees their slot on the node."""
    node_url = 'https://phx12'
    # Leave exactly one free slot on the node.
    self.database.update_node(node_url, capacity=10, available=1,
                              current_load=9)
    first = self.database.allocate_user('test1@mozilla.com')
    self.assertEqual(first['node'], node_url)
    # The node is now full, so a second allocation must fail.
    with self.assertRaises(Exception):
        self.database.allocate_user('test2@mozilla.com')
    # Retiring and deleting the first user's record frees the slot again.
    self.database.retire_user('test1@mozilla.com')
    self.database.delete_user_record(first['uid'])
    second = self.database.allocate_user('test2@mozilla.com')
    self.assertEqual(second['node'], node_url)
||||
def test_gradual_release_of_node_capacity(self):
    """Node capacity is released in 10% increments as demand requires.

    Each node starts with one available slot; when those run out, the
    next allocation attempt releases another 10% of each node's total
    capacity (one slot per node here), until capacity is exhausted.
    """
    node1 = 'https://phx12'
    self.database.update_node(node1, capacity=8, available=1,
                              current_load=4)
    node2 = 'https://phx13'
    self.database.add_node(node2, capacity=6,
                           available=1, current_load=4)
    # Two allocations should succeed without update, one on each node.
    user = self.database.allocate_user('test1@mozilla.com')
    self.assertEqual(user['node'], node1)
    user = self.database.allocate_user('test2@mozilla.com')
    self.assertEqual(user['node'], node2)
    # The next allocation attempt will release 10% more capacity,
    # which is one more slot for each node.
    user = self.database.allocate_user('test3@mozilla.com')
    self.assertEqual(user['node'], node1)
    user = self.database.allocate_user('test4@mozilla.com')
    self.assertEqual(user['node'], node2)
    # Now node2 is full, so further allocations all go to node1.
    user = self.database.allocate_user('test5@mozilla.com')
    self.assertEqual(user['node'], node1)
    user = self.database.allocate_user('test6@mozilla.com')
    self.assertEqual(user['node'], node1)
    # Until it finally reaches capacity.
    with self.assertRaises(Exception):
        self.database.allocate_user('test7@mozilla.com')
||||
def test_count_users(self):
    """count_users() tracks allocations, retirements and a time cutoff."""
    db = self.database
    db.allocate_user('test1@example.com')
    self.assertEqual(db.count_users(), 1)
    cutoff = get_timestamp()
    time.sleep(0.01)
    # Adding users increases the count.
    second = db.allocate_user('rfkelly@mozilla.com')
    self.assertEqual(db.count_users(), 2)
    # Updating a user doesn't change the count.
    db.update_user(second, client_state='aaa')
    self.assertEqual(db.count_users(), 2)
    # Counting as-of an earlier timestamp excludes newer users.
    self.assertEqual(db.count_users(cutoff), 1)
    # Retiring a user decreases the count.
    db.retire_user('test1@example.com')
    self.assertEqual(db.count_users(), 1)
||||
def test_first_seen_at(self):
    """first_seen_at survives re-assignment until old records are purged.

    Re-allocating a user (via a client_state change) creates a new uid
    but preserves first_seen_at; once the original record is deleted,
    the preserved value is lost and a fresh one appears.
    """
    EMAIL = 'test1@example.com'
    user0 = self.database.allocate_user(EMAIL)
    user1 = self.database.get_user(EMAIL)
    self.assertEqual(user1['uid'], user0['uid'])
    self.assertEqual(user1['first_seen_at'], user0['first_seen_at'])
    # It should stay consistent if we re-allocate the user's node.
    # Sleep so the new record gets a visibly different timestamp.
    time.sleep(0.1)
    self.database.update_user(user1, client_state='aaa')
    user2 = self.database.get_user(EMAIL)
    self.assertNotEqual(user2['uid'], user0['uid'])
    self.assertEqual(user2['first_seen_at'], user0['first_seen_at'])
    # Until we purge their old node-assignment records.
    self.database.delete_user_record(user0['uid'])
    user3 = self.database.get_user(EMAIL)
    self.assertEqual(user3['uid'], user2['uid'])
    self.assertNotEqual(user3['first_seen_at'], user2['first_seen_at'])
||||
244
tools/tokenserver/test_process_account_events.py
Normal file
244
tools/tokenserver/test_process_account_events.py
Normal file
@ -0,0 +1,244 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from pyramid import testing
|
||||
from testfixtures import LogCapture
|
||||
|
||||
from database import Database
|
||||
from process_account_events import process_account_event
|
||||
|
||||
|
||||
PATTERN = "{node}/1.5/{uid}"
|
||||
EMAIL = "test@example.com"
|
||||
UID = "test"
|
||||
ISS = "example.com"
|
||||
|
||||
|
||||
def message_body(**kwds):
    """Build a fake SQS/SNS-style envelope whose Message field is *kwds* as JSON."""
    inner = json.dumps(kwds)
    return json.dumps({"Message": inner})
|
||||
|
||||
class TestProcessAccountEvents(unittest.TestCase):
    """Tests for the process_account_event() message handler.

    Each test feeds an (optionally malformed) account event message
    through process_account_event() and then checks the resulting state
    of the tokenserver database and/or the captured log output.
    """

    def get_ini(self):
        """Return the path of the test config file."""
        return os.path.join(os.path.dirname(__file__),
                            'test_sql.ini')

    def setUp(self):
        self.database = Database()
        self.database.add_node("https://phx12", 100)
        self.logs = LogCapture()

    def tearDown(self):
        self.logs.uninstall()
        testing.tearDown()

        cursor = self.database._execute_sql('DELETE FROM nodes')
        cursor.close()

        cursor = self.database._execute_sql('DELETE FROM users')
        # BUG FIX: was `cursor.close` (an attribute access, not a call),
        # which silently left the cursor open.
        cursor.close()

    def assertMessageWasLogged(self, msg):
        """Check that a metric was logged during the request."""
        for r in self.logs.records:
            if msg in r.getMessage():
                break
        else:
            assert False, "message %r was not logged" % (msg,)

    def clearLogs(self):
        """Discard all captured log records."""
        del self.logs.records[:]

    def test_delete_user(self):
        # Give the user two records: one current, one already replaced.
        self.database.allocate_user(EMAIL)
        user = self.database.get_user(EMAIL)
        self.database.update_user(user, client_state="abcdef")
        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 2)
        self.assertTrue(records[0]["replaced_at"] is not None)

        process_account_event(message_body(
            event="delete",
            uid=UID,
            iss=ISS,
        ))

        # A delete event marks *every* record as replaced.
        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 2)
        for row in records:
            self.assertTrue(row["replaced_at"] is not None)

    def test_delete_user_by_legacy_uid_format(self):
        # Legacy messages carry the full email in the "uid" field.
        self.database.allocate_user(EMAIL)
        user = self.database.get_user(EMAIL)
        self.database.update_user(user, client_state="abcdef")
        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 2)
        self.assertTrue(records[0]["replaced_at"] is not None)

        process_account_event(message_body(
            event="delete",
            uid=EMAIL,
        ))

        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 2)
        for row in records:
            self.assertTrue(row["replaced_at"] is not None)

    def test_delete_user_who_is_not_in_the_db(self):
        # Deleting an unknown user must be a silent no-op.
        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 0)

        process_account_event(message_body(
            event="delete",
            uid=UID,
            iss=ISS
        ))

        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 0)

    def test_reset_user(self):
        self.database.allocate_user(EMAIL, generation=12)

        process_account_event(message_body(
            event="reset",
            uid=UID,
            iss=ISS,
            generation=43,
        ))

        # NOTE(review): the handler records generation - 1, presumably so
        # the client's next login at generation 43 is still accepted —
        # confirm against process_account_events.py.
        user = self.database.get_user(EMAIL)
        self.assertEqual(user["generation"], 42)

    def test_reset_user_by_legacy_uid_format(self):
        self.database.allocate_user(EMAIL, generation=12)

        process_account_event(message_body(
            event="reset",
            uid=EMAIL,
            generation=43,
        ))

        user = self.database.get_user(EMAIL)
        self.assertEqual(user["generation"], 42)

    def test_reset_user_who_is_not_in_the_db(self):
        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 0)

        process_account_event(message_body(
            event="reset",
            uid=UID,
            iss=ISS,
            generation=43,
        ))

        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 0)

    def test_password_change(self):
        self.database.allocate_user(EMAIL, generation=12)

        process_account_event(message_body(
            event="passwordChange",
            uid=UID,
            iss=ISS,
            generation=43,
        ))

        user = self.database.get_user(EMAIL)
        self.assertEqual(user["generation"], 42)

    def test_password_change_user_not_in_db(self):
        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 0)

        process_account_event(message_body(
            event="passwordChange",
            uid=UID,
            iss=ISS,
            generation=43,
        ))

        records = list(self.database.get_user_records(EMAIL))
        self.assertEqual(len(records), 0)

    def test_malformed_events(self):
        # Malformed events are logged and dropped, never raised.

        # Unknown event type.
        process_account_event(message_body(
            event="party",
            uid=UID,
            iss=ISS,
            generation=43,
        ))
        self.assertMessageWasLogged("Dropping unknown event type")
        self.clearLogs()

        # Missing event type.
        process_account_event(message_body(
            uid=UID,
            iss=ISS,
            generation=43,
        ))
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Missing uid.
        process_account_event(message_body(
            event="delete",
            iss=ISS,
        ))
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Missing generation for reset events.
        process_account_event(message_body(
            event="reset",
            uid=UID,
            iss=ISS,
        ))
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Missing generation for passwordChange events.
        process_account_event(message_body(
            event="passwordChange",
            uid=UID,
            iss=ISS,
        ))
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Missing issuer with nonemail uid
        process_account_event(message_body(
            event="delete",
            uid=UID,
        ))
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Non-JSON garbage.
        process_account_event("wat")
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Non-JSON garbage in Message field.
        process_account_event('{ "Message": "wat" }')
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()

        # Badly-typed JSON value in Message field.
        process_account_event('{ "Message": "[1, 2, 3"] }')
        self.assertMessageWasLogged("Invalid account message")
        self.clearLogs()
||||
136
tools/tokenserver/test_purge_old_records.py
Normal file
136
tools/tokenserver/test_purge_old_records.py
Normal file
@ -0,0 +1,136 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import hawkauthlib
|
||||
import re
|
||||
import threading
|
||||
import tokenlib
|
||||
import unittest
|
||||
from wsgiref.simple_server import make_server
|
||||
|
||||
from database import Database
|
||||
from purge_old_records import purge_old_records
|
||||
|
||||
|
||||
class TestPurgeOldRecords(unittest.TestCase):
    """A testcase for proper functioning of the purge_old_records.py script.

    This is a tricky one, because we have to actually run the script and
    test that it does the right thing. We also run a mock downstream service
    so we can test that data-deletion requests go through ok.
    """

    @classmethod
    def setUpClass(cls):
        cls.service_requests = []
        cls.service_node = "http://localhost:8002"
        cls.service = make_server("localhost", 8002, cls._service_app)
        target = cls.service.serve_forever
        cls.service_thread = threading.Thread(target=target)
        # Note: If the following `start` causes the test thread to hang,
        # you may need to specify
        # `[app::pyramid.app] pyramid.worker_class = sync` in the test_*.ini
        # files
        cls.service_thread.start()
        # This silences nuisance on-by-default logging output.
        cls.service.RequestHandlerClass.log_request = lambda *a: None

    def setUp(self):
        super(TestPurgeOldRecords, self).setUp()

        # Configure the node-assignment backend to talk to our test service.
        self.database = Database()
        self.database.add_node(self.service_node, 100)

    def tearDown(self):
        cursor = self.database._execute_sql('DELETE FROM nodes')
        cursor.close()

        cursor = self.database._execute_sql('DELETE FROM users')
        cursor.close()

        del self.service_requests[:]

    @classmethod
    def tearDownClass(cls):
        cls.service.shutdown()
        cls.service_thread.join()

    @classmethod
    def _service_app(cls, environ, start_response):
        """Mock downstream service: records each request, returns 200."""
        cls.service_requests.append(environ)
        start_response("200 OK", [])
        # BUG FIX: PEP 3333 requires the response body to be an iterable
        # of *bytes*; the original returned a str.
        return [b""]

    def test_purging_of_old_user_records(self):
        # Make some old user records.
        email = "test@mozilla.com"
        user = self.database.allocate_user(email, client_state="aa",
                                           generation=123)
        self.database.update_user(user, client_state="bb",
                                  generation=456, keys_changed_at=450)
        self.database.update_user(user, client_state="cc",
                                  generation=789)
        user_records = list(self.database.get_user_records(email))
        self.assertEqual(len(user_records), 3)
        user = self.database.get_user(email)
        self.assertEqual(user["client_state"], "cc")
        self.assertEqual(len(user["old_client_states"]), 2)

        # The default grace-period should prevent any cleanup.
        node_secret = "SECRET"
        self.assertTrue(purge_old_records(node_secret))
        user_records = list(self.database.get_user_records(email))
        self.assertEqual(len(user_records), 3)
        self.assertEqual(len(self.service_requests), 0)

        # With no grace period, we should cleanup two old records.
        self.assertTrue(purge_old_records(node_secret, grace_period=0))
        user_records = list(self.database.get_user_records(email))
        self.assertEqual(len(user_records), 1)
        self.assertEqual(len(self.service_requests), 2)

        # Check that the proper delete requests were made to the service.
        expected_kids = ["0000000000450-uw", "0000000000123-qg"]
        for i, environ in enumerate(self.service_requests):
            # They must be to the correct path.
            self.assertEqual(environ["REQUEST_METHOD"], "DELETE")
            self.assertTrue(re.match("/1.5/[0-9]+", environ["PATH_INFO"]))
            # They must have a correct request signature.
            token = hawkauthlib.get_id(environ)
            secret = tokenlib.get_derived_secret(token, secret=node_secret)
            self.assertTrue(hawkauthlib.check_signature(environ, secret))
            userdata = tokenlib.parse_token(token, secret=node_secret)
            self.assertTrue("uid" in userdata)
            self.assertTrue("node" in userdata)
            self.assertEqual(userdata["fxa_uid"], "test")
            self.assertEqual(userdata["fxa_kid"], expected_kids[i])

        # Check that the user's current state is unaffected
        user = self.database.get_user(email)
        self.assertEqual(user["client_state"], "cc")
        self.assertEqual(len(user["old_client_states"]), 0)

    def test_purging_is_not_done_on_downed_nodes(self):
        # Make some old user records.
        node_secret = "SECRET"
        email = "test@mozilla.com"
        user = self.database.allocate_user(email, client_state="aa")
        self.database.update_user(user, client_state="bb")
        user_records = list(self.database.get_user_records(email))
        self.assertEqual(len(user_records), 2)

        # With the node down, we should not purge any records.
        self.database.update_node(self.service_node, downed=1)
        self.assertTrue(purge_old_records(node_secret, grace_period=0))
        user_records = list(self.database.get_user_records(email))
        self.assertEqual(len(user_records), 2)
        self.assertEqual(len(self.service_requests), 0)

        # With the node back up, we should purge correctly.
        self.database.update_node(self.service_node, downed=0)
        self.assertTrue(purge_old_records(node_secret, grace_period=0))
        user_records = list(self.database.get_user_records(email))
        self.assertEqual(len(user_records), 1)
        self.assertEqual(len(self.service_requests), 1)
223
tools/tokenserver/test_scripts.py
Normal file
223
tools/tokenserver/test_scripts.py
Normal file
@ -0,0 +1,223 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from add_node import main as add_node_script
|
||||
from allocate_user import main as allocate_user_script
|
||||
from count_users import main as count_users_script
|
||||
from database import Database
|
||||
from remove_node import main as remove_node_script
|
||||
from unassign_node import main as unassign_node_script
|
||||
from update_node import main as update_node_script
|
||||
from util import get_timestamp
|
||||
|
||||
|
||||
class TestScripts(unittest.TestCase):
    """Tests for the node/user admin scripts (add_node, allocate_user,
    count_users, remove_node, unassign_node, update_node).

    Each test invokes a script's main() with explicit args and then
    inspects the database state directly.
    """

    NODE_ID = 800
    NODE_URL = 'https://node1'

    def setUp(self):
        self.database = Database()

        # Start each test with a blank slate.
        cursor = self.database._execute_sql('DELETE FROM users')
        cursor.close()

        cursor = self.database._execute_sql('DELETE FROM nodes')
        cursor.close()

        # Ensure we have a node with enough capacity to run the tests.
        self.database.add_node(self.NODE_URL, 100, id=self.NODE_ID)

    def tearDown(self):
        # And clean up at the end, for good measure.
        cursor = self.database._execute_sql('DELETE FROM users')
        cursor.close()

        cursor = self.database._execute_sql('DELETE FROM nodes')
        cursor.close()

        self.database.close()

    def _assert_node_attrs(self, res, capacity=100, available=10,
                           current_load=0, downed=0, backoff=0):
        """Check a node row against the expected column values.

        Defaults match what add_node produces for a fresh node of
        capacity 100 with no extra options.  (Extracted to remove the
        assertion block duplicated across the four test_add_* tests.)
        """
        self.assertEqual(res.capacity, capacity)
        self.assertEqual(res.available, available)
        self.assertEqual(res.current_load, current_load)
        self.assertEqual(res.downed, downed)
        self.assertEqual(res.backoff, backoff)
        self.assertEqual(res.service, self.database.service_id)

    def test_add_node(self):
        add_node_script(
            args=['--current-load', '9', 'test_node', '100']
        )
        res = self.database.get_node('test_node')
        # The node should have the expected attributes
        self._assert_node_attrs(res, current_load=9)

    def test_add_node_with_explicit_available(self):
        args = ['--current-load', '9', '--available', '5', 'test_node', '100']
        add_node_script(args=args)
        res = self.database.get_node('test_node')
        # The node should have the expected attributes
        self._assert_node_attrs(res, available=5, current_load=9)

    def test_add_downed_node(self):
        add_node_script(
            args=['--downed', 'test_node', '100']
        )
        res = self.database.get_node('test_node')
        # The node should have the expected attributes
        self._assert_node_attrs(res, downed=1)

    def test_add_backoff_node(self):
        add_node_script(
            args=['--backoff', 'test_node', '100']
        )
        res = self.database.get_node('test_node')
        # The node should have the expected attributes
        self._assert_node_attrs(res, backoff=1)

    def test_allocate_user_user_already_exists(self):
        email = 'test@test.com'
        self.database.allocate_user(email)
        node = 'https://node2'
        self.database.add_node(node, 100)
        allocate_user_script(args=[email, node])
        user = self.database.get_user(email)
        # The user should be assigned to the given node
        self.assertEqual(user['node'], node)
        # Another user should not have been created
        count = self.database.count_users()
        self.assertEqual(count, 1)

    def test_allocate_user_given_node(self):
        email = 'test@test.com'
        node = 'https://node2'
        self.database.add_node(node, 100)
        allocate_user_script(args=[email, node])
        user = self.database.get_user(email)
        # A new user should be created and assigned to the given node
        self.assertEqual(user['node'], node)

    def test_allocate_user_not_given_node(self):
        email = 'test@test.com'
        self.database.add_node('https://node2', 100,
                               current_load=10)
        self.database.add_node('https://node3', 100,
                               current_load=20)
        self.database.add_node('https://node4', 100,
                               current_load=30)
        allocate_user_script(args=[email])
        user = self.database.get_user(email)
        # The user should be assigned to the least-loaded node
        self.assertEqual(user['node'], 'https://node1')

    def test_count_users(self):
        self.database.allocate_user('test1@test.com')
        self.database.allocate_user('test2@test.com')
        self.database.allocate_user('test3@test.com')

        timestamp = get_timestamp()
        filename = '/tmp/' + str(uuid.uuid4())
        try:
            count_users_script(
                args=['--output', filename, '--timestamp', str(timestamp)]
            )

            with open(filename) as f:
                info = json.loads(f.readline())
                self.assertEqual(info['total_users'], 3)
                self.assertEqual(info['op'], 'sync_count_users')
        finally:
            os.remove(filename)

        # A timestamp that predates the allocations should count nobody.
        filename = '/tmp/' + str(uuid.uuid4())
        try:
            args = ['--output', filename, '--timestamp',
                    str(timestamp - 10000)]
            count_users_script(args=args)

            with open(filename) as f:
                info = json.loads(f.readline())
                self.assertEqual(info['total_users'], 0)
                self.assertEqual(info['op'], 'sync_count_users')
        finally:
            os.remove(filename)

    def test_remove_node(self):
        self.database.add_node('https://node2', 100)
        self.database.allocate_user('test1@test.com',
                                    node='https://node2')
        self.database.allocate_user('test2@test.com',
                                    node=self.NODE_URL)
        self.database.allocate_user('test3@test.com',
                                    node=self.NODE_URL)

        remove_node_script(args=['https://node2'])

        # The node should have been removed from the database
        args = ['https://node2']
        self.assertRaises(ValueError, self.database.get_node_id, *args)
        # The first user should have been assigned to a new node
        user = self.database.get_user('test1@test.com')
        self.assertEqual(user['node'], self.NODE_URL)
        # The second and third users should still be on the first node
        user = self.database.get_user('test2@test.com')
        self.assertEqual(user['node'], self.NODE_URL)
        user = self.database.get_user('test3@test.com')
        self.assertEqual(user['node'], self.NODE_URL)

    def test_unassign_node(self):
        self.database.add_node('https://node2', 100)
        self.database.allocate_user('test1@test.com',
                                    node='https://node2')
        self.database.allocate_user('test2@test.com',
                                    node='https://node2')
        self.database.allocate_user('test3@test.com',
                                    node=self.NODE_URL)

        unassign_node_script(args=['https://node2'])
        self.database.remove_node('https://node2')
        # All of the users should now be assigned to the first node
        user = self.database.get_user('test1@test.com')
        self.assertEqual(user['node'], self.NODE_URL)
        user = self.database.get_user('test2@test.com')
        self.assertEqual(user['node'], self.NODE_URL)
        user = self.database.get_user('test3@test.com')
        self.assertEqual(user['node'], self.NODE_URL)

    def test_update_node(self):
        self.database.add_node('https://node2', 100)
        update_node_script(args=[
            '--capacity', '150',
            '--available', '125',
            '--current-load', '25',
            '--downed',
            '--backoff',
            'https://node2'
        ])
        node = self.database.get_node('https://node2')
        # Ensure the node has the expected attributes
        self.assertEqual(node['capacity'], 150)
        self.assertEqual(node['available'], 125)
        self.assertEqual(node['current_load'], 25)
        self.assertEqual(node['downed'], 1)
        self.assertEqual(node['backoff'], 1)
72
tools/tokenserver/unassign_node.py
Normal file
72
tools/tokenserver/unassign_node.py
Normal file
@ -0,0 +1,72 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to remove a node from the system.
|
||||
|
||||
This script clears any assignments to the named node.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import optparse
|
||||
|
||||
from database import Database
|
||||
import util
|
||||
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.unassign_node")
|
||||
|
||||
|
||||
def unassign_node(node):
    """Clear any user assignments to the named node.

    Returns True on success (whether or not the node existed in the db)
    and False if an unexpected error occurred.
    """
    # BUG FIX: log message read "Unassignment node %s".
    logger.info("Unassigning node %s", node)
    try:
        database = Database()
        found = False
        try:
            database.unassign_node(node)
        except ValueError:
            # The database raises ValueError for an unknown node name.
            logger.debug(" not found")
        else:
            found = True
            logger.debug(" unassigned")
    except Exception:
        logger.exception("Error while unassigning node")
        return False
    else:
        if not found:
            logger.info("Node %s was not found", node)
        else:
            logger.info("Finished unassigning node %s", node)
        return True
|
||||
|
||||
def main(args=None):
    """Command-line entry point.

    Parses the arguments and hands the node name on to unassign_node().
    """
    parser = optparse.OptionParser(
        usage="usage: %prog [options] node_name",
        description="Clear all assignments to node in the tokenserver database",
    )
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, positional = parser.parse_args(args)
    if len(positional) != 1:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    unassign_node(positional[0])
    return 0
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
util.run_script(main)
|
||||
83
tools/tokenserver/update_node.py
Normal file
83
tools/tokenserver/update_node.py
Normal file
@ -0,0 +1,83 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Script to update node status in the db.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import optparse
|
||||
|
||||
from database import Database
|
||||
import util
|
||||
|
||||
|
||||
logger = logging.getLogger("tokenserver.scripts.update_node")
|
||||
|
||||
|
||||
def update_node(node, **kwds):
    """Update the stored details of a node.

    *kwds* may contain any of the columns accepted by
    Database.update_node (e.g. capacity, available, current_load,
    downed, backoff).  Returns True on success, False on error.
    """
    # BUG FIX: the format string had two placeholders
    # ("Updating node %s for service %s") but only one argument,
    # which made the logging call itself raise a formatting error.
    logger.info("Updating node %s", node)
    logger.debug("Value: %r", kwds)
    try:
        database = Database()
        database.update_node(node, **kwds)
    except Exception:
        logger.exception("Error while updating node")
        return False
    else:
        logger.info("Finished updating node %s", node)
        return True
|
||||
|
||||
def main(args=None):
    """Main entry-point for running this script.

    Parses command-line arguments into a dict of node fields and
    passes them on to update_node().  Returns a process exit status:
    0 on success, 1 on bad usage.
    """
    usage = "usage: %prog [options] node_name"
    descr = "Update node details in the tokenserver database"
    parser = optparse.OptionParser(usage=usage, description=descr)
    # Integer-valued node fields.
    for flag, helptext in (
        ("--capacity", "How many user slots the node has overall"),
        ("--available", "How many user slots the node has available"),
        ("--current-load", "How many user slots the node has occupied"),
    ):
        parser.add_option("", flag, type="int", help=helptext)
    # Boolean node flags.
    for flag, helptext in (
        ("--downed", "Mark the node as down in the db"),
        ("--backoff", "Mark the node as backed-off in the db"),
    ):
        parser.add_option("", flag, action="store_true", help=helptext)
    parser.add_option("-v", "--verbose", action="count", dest="verbosity",
                      help="Control verbosity of log messages")

    opts, positional = parser.parse_args(args)
    # Exactly one node name is required.
    if len(positional) != 1:
        parser.print_usage()
        return 1

    util.configure_script_logging(opts)

    # Only pass along fields the user explicitly supplied; unset
    # options default to None and are skipped.
    kwds = {}
    for field in ("capacity", "available", "current_load",
                  "backoff", "downed"):
        value = getattr(opts, field)
        if value is not None:
            kwds[field] = value

    update_node(positional[0], **kwds)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run via the shared wrapper so KeyboardInterrupt and main()'s
    # return value are translated into a process exit status.
    util.run_script(main)
|
||||
58
tools/tokenserver/util.py
Normal file
58
tools/tokenserver/util.py
Normal file
@ -0,0 +1,58 @@
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
# You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
"""
|
||||
|
||||
Admin/management scripts for TokenServer.
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
|
||||
from browserid.utils import encode_bytes as encode_bytes_b64
|
||||
|
||||
|
||||
def run_script(main):
    """Simple wrapper for running scripts in __main__ section.

    Invokes *main* and exits the process with its return value,
    mapping a KeyboardInterrupt to exit status 1.
    """
    exit_status = 1
    try:
        exit_status = main()
    except KeyboardInterrupt:
        # Treat a user interrupt as an ordinary failure exit.
        pass
    sys.exit(exit_status)
|
||||
|
||||
|
||||
def configure_script_logging(opts=None):
    """Configure stdlib logging to produce output from the script.

    Sends log messages to stderr with formatting aimed at human
    readability rather than machine parsing, and honours the
    --verbosity command-line option: no flag -> WARNING, -v -> INFO,
    -vv or more -> DEBUG.
    """
    if not opts or not opts.verbosity:
        level = logging.WARNING
    else:
        level = logging.INFO if opts.verbosity == 1 else logging.DEBUG

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    stream_handler.setLevel(level)
    root_logger = logging.getLogger("")
    root_logger.addHandler(stream_handler)
    root_logger.setLevel(level)
|
||||
|
||||
|
||||
def format_key_id(keys_changed_at, key_hash):
    """Format an FxA key ID from a timestamp and key hash."""
    # Zero-pad the millisecond timestamp to 13 digits, then append
    # the urlsafe-b64 encoding of the key hash.
    encoded_hash = encode_bytes_b64(key_hash)
    return "%013d-%s" % (keys_changed_at, encoded_hash)
|
||||
|
||||
|
||||
def get_timestamp():
    """Get current timestamp in milliseconds."""
    seconds_now = time.time()
    return int(seconds_now * 1000)
|
||||
Loading…
x
Reference in New Issue
Block a user