From 4a145dd18bc13345179dbaedf6a0ae2d31ad4281 Mon Sep 17 00:00:00 2001 From: Philip Jenvey Date: Thu, 2 May 2024 09:09:53 -0700 Subject: [PATCH] feat: optionally force the spanner node via get_best_node (#1553) Issue SYNC-4181 --- tools/tokenserver/database.py | 54 +++++++++++----- .../test_process_account_events.py | 62 ++++++++++++++----- 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/tools/tokenserver/database.py b/tools/tokenserver/database.py index 1e8d7be2..58bfbd31 100644 --- a/tools/tokenserver/database.py +++ b/tools/tokenserver/database.py @@ -240,6 +240,18 @@ where and node = :node """) + +_GET_SPANNER_NODE = sqltext("""\ +select + id, node +from + nodes +where + id = :id +limit + 1 +""") + SERVICE_NAME = 'sync-1.5' @@ -251,6 +263,8 @@ class Database: connect() self.capacity_release_rate = os.environ. \ get("NODE_CAPACITY_RELEASE_RATE", 0.1) + self.spanner_node_id = os.environ.get( + "SYNC_TOKENSERVER__SPANNER_NODE_ID") def _execute_sql(self, *args, **kwds): return self.database.execute(*args, **kwds) @@ -641,26 +655,36 @@ class Database: """Returns the 'least loaded' node currently available, increments the active count on that node, and decrements the slots currently available """ - # We may have to re-try the query if we need to release more capacity. - # This loop allows a maximum of five retries before bailing out. - for _ in range(5): - res = self._execute_sql(_GET_BEST_NODE, - service=self._get_service_id(SERVICE_NAME)) + if self.spanner_node_id: + res = self._execute_sql( + _GET_SPANNER_NODE, + id=self.spanner_node_id + ) row = res.fetchone() res.close() - if row is None: - # Try to release additional capacity from any nodes - # that are not fully occupied. + else: + # We may have to re-try the query if we need to release more + # capacity. This loop allows a maximum of five retries before + # bailing out. + for _ in range(5): res = self._execute_sql( - _RELEASE_NODE_CAPACITY, - capacity_release_rate=self.capacity_release_rate, - service=self._get_service_id(SERVICE_NAME) - ) + _GET_BEST_NODE, + service=self._get_service_id(SERVICE_NAME)) + row = res.fetchone() res.close() - if res.rowcount == 0: + if row is None: + # Try to release additional capacity from any nodes + # that are not fully occupied. + res = self._execute_sql( + _RELEASE_NODE_CAPACITY, + capacity_release_rate=self.capacity_release_rate, + service=self._get_service_id(SERVICE_NAME) + ) + res.close() + if res.rowcount == 0: + break + else: break - else: - break # Did we succeed in finding a node? if row is None: diff --git a/tools/tokenserver/test_process_account_events.py b/tools/tokenserver/test_process_account_events.py index 52941b25..ca02ef17 100644 --- a/tools/tokenserver/test_process_account_events.py +++ b/tools/tokenserver/test_process_account_events.py @@ -25,7 +25,7 @@ def message_body(**kwds): }) -class TestProcessAccountEvents(unittest.TestCase): +class ProcessAccountEventsTestCase(unittest.TestCase): def get_ini(self): return os.path.join(os.path.dirname(__file__), @@ -64,12 +64,15 @@ class TestProcessAccountEvents(unittest.TestCase): def process_account_event(self, body): process_account_event(self.database, body) + +class TestProcessAccountEvents(ProcessAccountEventsTestCase): + def test_delete_user(self): self.database.allocate_user(EMAIL) user = self.database.get_user(EMAIL) self.database.update_user(user, client_state="abcdef") records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 2) + self.assertEqual(len(records), 2) self.assertTrue(records[0]["replaced_at"] is not None) self.process_account_event(message_body( @@ -79,7 +82,7 @@ class TestProcessAccountEvents(unittest.TestCase): )) records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 2) + self.assertEqual(len(records), 2) for row in records: self.assertTrue(row["replaced_at"] is not None) @@ -88,7 +91,7 @@ class TestProcessAccountEvents(unittest.TestCase): user = self.database.get_user(EMAIL) self.database.update_user(user, client_state="abcdef") records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 2) + self.assertEqual(len(records), 2) self.assertTrue(records[0]["replaced_at"] is not None) self.process_account_event(message_body( @@ -97,13 +100,13 @@ class TestProcessAccountEvents(unittest.TestCase): )) records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 2) + self.assertEqual(len(records), 2) for row in records: self.assertTrue(row["replaced_at"] is not None) def test_delete_user_who_is_not_in_the_db(self): records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 0) + self.assertEqual(len(records), 0) self.process_account_event(message_body( event="delete", @@ -112,7 +115,7 @@ class TestProcessAccountEvents(unittest.TestCase): )) records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 0) + self.assertEqual(len(records), 0) def test_reset_user(self): self.database.allocate_user(EMAIL, generation=12) @@ -125,7 +128,7 @@ class TestProcessAccountEvents(unittest.TestCase): )) user = self.database.get_user(EMAIL) - self.assertEquals(user["generation"], 42) + self.assertEqual(user["generation"], 42) def test_reset_user_by_legacy_uid_format(self): self.database.allocate_user(EMAIL, generation=12) @@ -137,11 +140,11 @@ class TestProcessAccountEvents(unittest.TestCase): )) user = self.database.get_user(EMAIL) - self.assertEquals(user["generation"], 42) + self.assertEqual(user["generation"], 42) def test_reset_user_who_is_not_in_the_db(self): records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 0) + self.assertEqual(len(records), 0) self.process_account_event(message_body( event="reset", @@ -151,7 +154,7 @@ class TestProcessAccountEvents(unittest.TestCase): )) records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 0) + self.assertEqual(len(records), 0) def test_password_change(self): self.database.allocate_user(EMAIL, generation=12) @@ -164,11 +167,11 @@ class TestProcessAccountEvents(unittest.TestCase): )) user = self.database.get_user(EMAIL) - self.assertEquals(user["generation"], 42) + self.assertEqual(user["generation"], 42) def test_password_change_user_not_in_db(self): records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 0) + self.assertEqual(len(records), 0) self.process_account_event(message_body( event="passwordChange", @@ -178,7 +181,7 @@ class TestProcessAccountEvents(unittest.TestCase): )) records = list(self.database.get_user_records(EMAIL)) - self.assertEquals(len(records), 0) + self.assertEqual(len(records), 0) def test_malformed_events(self): @@ -274,7 +277,7 @@ class TestProcessAccountEvents(unittest.TestCase): )) user = self.database.get_user(EMAIL) - self.assertEquals(user["generation"], 42) + self.assertEqual(user["generation"], 42) def test_update_with_no_keys_changed_at2(self): user = self.database.allocate_user( @@ -294,4 +297,31 @@ class TestProcessAccountEvents(unittest.TestCase): )) user = self.database.get_user(EMAIL) - self.assertEquals(user["generation"], 42) + self.assertEqual(user["generation"], 42) + + +class TestProcessAccountEventsForceSpanner(ProcessAccountEventsTestCase): + + def setUp(self): + super().setUp() + self.database.spanner_node_id = self.database.get_node_id( + "https://phx12") + + def test_delete_user_force_spanner(self): + self.database.allocate_user(EMAIL) + user = self.database.get_user(EMAIL) + self.database.update_user(user, client_state="abcdef") + records = list(self.database.get_user_records(EMAIL)) + self.assertEqual(len(records), 2) + self.assertTrue(records[0]["replaced_at"] is not None) + + self.process_account_event(message_body( + event="delete", + uid=UID, + iss=ISS, + )) + + records = list(self.database.get_user_records(EMAIL)) + self.assertEqual(len(records), 2) + for row in records: + self.assertTrue(row["replaced_at"] is not None)