feat: optionally force the spanner node via get_best_node (#1553)

Issue SYNC-4181
This commit is contained in:
Philip Jenvey 2024-05-02 09:09:53 -07:00 committed by GitHub
parent aa9ef4649d
commit 4a145dd18b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 85 additions and 31 deletions

View File

@ -240,6 +240,18 @@ where
and node = :node
""")
_GET_SPANNER_NODE = sqltext("""\
select
id, node
from
nodes
where
id = :id
limit
1
""")
SERVICE_NAME = 'sync-1.5'
@ -251,6 +263,8 @@ class Database:
connect()
self.capacity_release_rate = os.environ. \
get("NODE_CAPACITY_RELEASE_RATE", 0.1)
self.spanner_node_id = os.environ.get(
"SYNC_TOKENSERVER__SPANNER_NODE_ID")
def _execute_sql(self, *args, **kwds):
return self.database.execute(*args, **kwds)
@ -641,26 +655,36 @@ class Database:
"""Returns the 'least loaded' node currently available, increments the
active count on that node, and decrements the slots currently available
"""
# We may have to re-try the query if we need to release more capacity.
# This loop allows a maximum of five retries before bailing out.
for _ in range(5):
res = self._execute_sql(_GET_BEST_NODE,
service=self._get_service_id(SERVICE_NAME))
if self.spanner_node_id:
res = self._execute_sql(
_GET_SPANNER_NODE,
id=self.spanner_node_id
)
row = res.fetchone()
res.close()
if row is None:
# Try to release additional capacity from any nodes
# that are not fully occupied.
else:
# We may have to re-try the query if we need to release more
# capacity. This loop allows a maximum of five retries before
# bailing out.
for _ in range(5):
res = self._execute_sql(
_RELEASE_NODE_CAPACITY,
capacity_release_rate=self.capacity_release_rate,
service=self._get_service_id(SERVICE_NAME)
)
_GET_BEST_NODE,
service=self._get_service_id(SERVICE_NAME))
row = res.fetchone()
res.close()
if res.rowcount == 0:
if row is None:
# Try to release additional capacity from any nodes
# that are not fully occupied.
res = self._execute_sql(
_RELEASE_NODE_CAPACITY,
capacity_release_rate=self.capacity_release_rate,
service=self._get_service_id(SERVICE_NAME)
)
res.close()
if res.rowcount == 0:
break
else:
break
else:
break
# Did we succeed in finding a node?
if row is None:

View File

@ -25,7 +25,7 @@ def message_body(**kwds):
})
class TestProcessAccountEvents(unittest.TestCase):
class ProcessAccountEventsTestCase(unittest.TestCase):
def get_ini(self):
return os.path.join(os.path.dirname(__file__),
@ -64,12 +64,15 @@ class TestProcessAccountEvents(unittest.TestCase):
def process_account_event(self, body):
process_account_event(self.database, body)
class TestProcessAccountEvents(ProcessAccountEventsTestCase):
def test_delete_user(self):
self.database.allocate_user(EMAIL)
user = self.database.get_user(EMAIL)
self.database.update_user(user, client_state="abcdef")
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
self.assertEqual(len(records), 2)
self.assertTrue(records[0]["replaced_at"] is not None)
self.process_account_event(message_body(
@ -79,7 +82,7 @@ class TestProcessAccountEvents(unittest.TestCase):
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
self.assertEqual(len(records), 2)
for row in records:
self.assertTrue(row["replaced_at"] is not None)
@ -88,7 +91,7 @@ class TestProcessAccountEvents(unittest.TestCase):
user = self.database.get_user(EMAIL)
self.database.update_user(user, client_state="abcdef")
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
self.assertEqual(len(records), 2)
self.assertTrue(records[0]["replaced_at"] is not None)
self.process_account_event(message_body(
@ -97,13 +100,13 @@ class TestProcessAccountEvents(unittest.TestCase):
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 2)
self.assertEqual(len(records), 2)
for row in records:
self.assertTrue(row["replaced_at"] is not None)
def test_delete_user_who_is_not_in_the_db(self):
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
self.assertEqual(len(records), 0)
self.process_account_event(message_body(
event="delete",
@ -112,7 +115,7 @@ class TestProcessAccountEvents(unittest.TestCase):
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
self.assertEqual(len(records), 0)
def test_reset_user(self):
self.database.allocate_user(EMAIL, generation=12)
@ -125,7 +128,7 @@ class TestProcessAccountEvents(unittest.TestCase):
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
self.assertEqual(user["generation"], 42)
def test_reset_user_by_legacy_uid_format(self):
self.database.allocate_user(EMAIL, generation=12)
@ -137,11 +140,11 @@ class TestProcessAccountEvents(unittest.TestCase):
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
self.assertEqual(user["generation"], 42)
def test_reset_user_who_is_not_in_the_db(self):
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
self.assertEqual(len(records), 0)
self.process_account_event(message_body(
event="reset",
@ -151,7 +154,7 @@ class TestProcessAccountEvents(unittest.TestCase):
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
self.assertEqual(len(records), 0)
def test_password_change(self):
self.database.allocate_user(EMAIL, generation=12)
@ -164,11 +167,11 @@ class TestProcessAccountEvents(unittest.TestCase):
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
self.assertEqual(user["generation"], 42)
def test_password_change_user_not_in_db(self):
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
self.assertEqual(len(records), 0)
self.process_account_event(message_body(
event="passwordChange",
@ -178,7 +181,7 @@ class TestProcessAccountEvents(unittest.TestCase):
))
records = list(self.database.get_user_records(EMAIL))
self.assertEquals(len(records), 0)
self.assertEqual(len(records), 0)
def test_malformed_events(self):
@ -274,7 +277,7 @@ class TestProcessAccountEvents(unittest.TestCase):
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
self.assertEqual(user["generation"], 42)
def test_update_with_no_keys_changed_at2(self):
user = self.database.allocate_user(
@ -294,4 +297,31 @@ class TestProcessAccountEvents(unittest.TestCase):
))
user = self.database.get_user(EMAIL)
self.assertEquals(user["generation"], 42)
self.assertEqual(user["generation"], 42)
class TestProcessAccountEventsForceSpanner(ProcessAccountEventsTestCase):
def setUp(self):
super().setUp()
self.database.spanner_node_id = self.database.get_node_id(
"https://phx12")
def test_delete_user_force_spanner(self):
self.database.allocate_user(EMAIL)
user = self.database.get_user(EMAIL)
self.database.update_user(user, client_state="abcdef")
records = list(self.database.get_user_records(EMAIL))
self.assertEqual(len(records), 2)
self.assertTrue(records[0]["replaced_at"] is not None)
self.process_account_event(message_body(
event="delete",
uid=UID,
iss=ISS,
))
records = list(self.database.get_user_records(EMAIL))
self.assertEqual(len(records), 2)
for row in records:
self.assertTrue(row["replaced_at"] is not None)