From cc160822419cd56646d15d425812cf36a19d89a2 Mon Sep 17 00:00:00 2001
From: JR Conlin
Date: Tue, 30 Apr 2024 08:53:27 -0700
Subject: [PATCH] feat: Allow uid range for purge function (SYNC-4246) (#1547)

* feat: Allow uid range for purge function

In an attempt to parallelize this script after a very long delay,
specify a range so that multiple scripts can try to process different
ranges of the database.

Closes #1548
---
 .../integration_tests/tokenserver/test_e2e.py |  2 +
 tools/tokenserver/database.py                 | 48 ++++++++++++++++++-
 tools/tokenserver/purge_old_records.py        | 30 +++++++++++-
 tools/tokenserver/test_database.py            | 24 ++++++++++
 4 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/tools/integration_tests/tokenserver/test_e2e.py b/tools/integration_tests/tokenserver/test_e2e.py
index 222d19ff..692fd6ef 100644
--- a/tools/integration_tests/tokenserver/test_e2e.py
+++ b/tools/integration_tests/tokenserver/test_e2e.py
@@ -34,6 +34,8 @@ PASSWORD_LENGTH = 32
 SCOPE = 'https://identity.mozilla.com/apps/oldsync'
 
 
+@unittest.skip("Pending PyFxA oauth fix: "
+               "https://github.com/mozilla/PyFxA/issues/101")
 class TestE2e(TestCase, unittest.TestCase):
 
     def setUp(self):
diff --git a/tools/tokenserver/database.py b/tools/tokenserver/database.py
index d1aa711b..79a2e6b5 100644
--- a/tools/tokenserver/database.py
+++ b/tools/tokenserver/database.py
@@ -101,6 +101,26 @@ offset
     :offset
 """)
 
+_GET_OLD_USER_RECORDS_FOR_SERVICE_RANGE = """\
+select
+    uid, email, generation, keys_changed_at, client_state,
+    nodes.node, nodes.downed, created_at, replaced_at
+from
+    users left outer join nodes on users.nodeid = nodes.id
+where
+    users.service = :service
+and
+    ::RANGE::
+and
+    replaced_at is not null and replaced_at < :timestamp
+order by
+    replaced_at desc, uid desc
+limit
+    :limit
+offset
+    :offset
+"""
+
 
 _GET_ALL_USER_RECORDS_FOR_SERVICE = sqltext("""\
 select
@@ -420,8 +440,29 @@ class Database:
         finally:
             res.close()
 
+    def _build_old_user_query(self, range, params, **kwargs):
+        """Pick the old-records query, bounded by uid when `range` is set.
+
+        `range` is a (start, end) pair; a falsy or missing bound leaves
+        that side open. Bind values are added to `params` in place.
+        """
+        rstr = []
+        bounds = (tuple(range or ()) + (None, None))[:2]
+        if bounds[0]:
+            rstr.append("uid > :start")
+            params["start"] = bounds[0]
+        if bounds[1]:
+            rstr.append("uid < :end")
+            params["end"] = bounds[1]
+        if not rstr:
+            # An empty "::RANGE::" would leave a dangling "and".
+            return _GET_OLD_USER_RECORDS_FOR_SERVICE
+        return sqltext(
+            _GET_OLD_USER_RECORDS_FOR_SERVICE_RANGE.replace(
+                "::RANGE::", " and ".join(rstr)))
+
     def get_old_user_records(self, grace_period=-1, limit=100,
-                             offset=0):
+                             offset=0, range=None):
         """Get user records that were replaced outside the grace period."""
         if grace_period < 0:
             grace_period = 60 * 60 * 24 * 7  # one week, in seconds
@@ -432,7 +473,10 @@ class Database:
             "limit": limit,
             "offset": offset
         }
-        res = self._execute_sql(_GET_OLD_USER_RECORDS_FOR_SERVICE, **params)
+
+        sql = self._build_old_user_query(range, params)
+
+        res = self._execute_sql(sql, **params)
         try:
             for row in res:
                 yield row
diff --git a/tools/tokenserver/purge_old_records.py b/tools/tokenserver/purge_old_records.py
index b0901b0f..6a53ccdf 100644
--- a/tools/tokenserver/purge_old_records.py
+++ b/tools/tokenserver/purge_old_records.py
@@ -45,6 +45,7 @@ def purge_old_records(
     dryrun=False,
     force=False,
     override_node=None,
+    uid_range=None,
 ):
     """Purge old records from the database.
 
@@ -69,6 +70,7 @@ def purge_old_records(
                 "grace_period": grace_period,
                 "limit": max_per_loop,
                 "offset": offset,
+                "range": uid_range,
             }
             rows = list(database.get_old_user_records(**kwds))
             if not rows:
@@ -77,7 +79,14 @@ def purge_old_records(
             if rows == previous_list:
                 raise Exception("Loop detected")
             previous_list = rows
-            logger.info("Fetched %d rows at offset %d", len(rows), offset)
+            range_msg = ""
+            if uid_range:
+                range_msg = (
+                    f" within range {uid_range[0] or 'Start'}"
+                    f" to {uid_range[1] or 'End'}"
+                )
+            logger.info(
+                f"Fetched {len(rows)} rows at offset {offset}{range_msg}")
             counter = 0
             for row in rows:
                 # Don't attempt to purge data from downed nodes.
@@ -313,6 +322,18 @@ def main(args=None):
         "", "--override_node",
         help="Use this node when deleting (if data was copied)"
     )
+    parser.add_option(
+        "",
+        "--range_start",
+        default=None,
+        help="Start of UID range to check"
+    )
+    parser.add_option(
+        "",
+        "--range_end",
+        default=None,
+        help="End of UID range to check"
+    )
 
     opts, args = parser.parse_args(args)
     if len(args) != 2:
@@ -323,6 +344,10 @@ def main(args=None):
 
     util.configure_script_logging(opts)
 
+    uid_range = None
+    if opts.range_start or opts.range_end:
+        uid_range = (opts.range_start, opts.range_end)
+
     purge_old_records(
         secret,
         grace_period=opts.grace_period,
@@ -333,6 +358,7 @@ def main(args=None):
         dryrun=opts.dryrun,
         force=opts.force,
         override_node=opts.override_node,
+        uid_range=uid_range,
     )
     if not opts.oneshot:
         while True:
@@ -343,6 +369,7 @@ def main(args=None):
             logger.debug("Sleeping for %d seconds", sleep_time)
             time.sleep(sleep_time)
             purge_old_records(
+                secret,
                 grace_period=opts.grace_period,
                 max_per_loop=opts.max_per_loop,
                 max_offset=opts.max_offset,
@@ -351,6 +378,7 @@ def main(args=None):
                 dryrun=opts.dryrun,
                 force=opts.force,
                 override_node=opts.override_node,
+                uid_range=uid_range,
             )
     return 0
 
diff --git a/tools/tokenserver/test_database.py b/tools/tokenserver/test_database.py
index 065b7a8d..fac5980a 100644
--- a/tools/tokenserver/test_database.py
+++ b/tools/tokenserver/test_database.py
@@ -463,3 +463,27 @@ class TestDatabase(unittest.TestCase):
         user3 = self.database.get_user(EMAIL)
         self.assertEqual(user3['uid'], user2['uid'])
         self.assertNotEqual(user3['first_seen_at'], user2['first_seen_at'])
+
+    def test_build_old_range(self):
+        params = dict()
+        sql = self.database._build_old_user_query(None, params)
+        self.assertNotIn("uid > :start", sql.text)
+        self.assertNotIn("uid < :end", sql.text)
+        self.assertIsNone(params.get("start"))
+        self.assertIsNone(params.get("end"))
+
+        params = dict()
+        rrange = (None, "abcd")
+        sql = self.database._build_old_user_query(rrange, params)
+        self.assertNotIn("uid > :start", sql.text)
+        self.assertIn("uid < :end", sql.text)
+        self.assertIsNone(params.get("start"))
+        self.assertEqual(params.get("end"), rrange[1])
+
+        params = dict()
+        rrange = ("1234", "abcd")
+        sql = self.database._build_old_user_query(rrange, params)
+        self.assertIn("uid > :start", sql.text)
+        self.assertIn("uid < :end", sql.text)
+        self.assertEqual(params.get("start"), rrange[0])
+        self.assertEqual(params.get("end"), rrange[1])