diff --git a/tools/spanner/count_expired_rows.py b/tools/spanner/count_expired_rows.py
index 9389381f..824a983c 100644
--- a/tools/spanner/count_expired_rows.py
+++ b/tools/spanner/count_expired_rows.py
@@ -13,6 +13,7 @@ from statsd.defaults.env import statsd
 from urllib import parse
 
 from google.cloud import spanner
+from utils import ids_from_env
 
 # set up logger
 logging.basicConfig(
@@ -23,33 +24,24 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
+def spanner_read_data(query: str, table: str) -> None:
+    """
+    Executes a query on the specified Spanner table to count expired rows,
+    logs the result, and sends metrics to statsd.
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
-def spanner_read_data(query, table):
-    (instance_id, database_id) = from_env()
+    Args:
+        query (str): The SQL query to execute.
+        table (str): The name of the table being queried.
+    Returns:
+        None
+    """
+    (instance_id, database_id, project_id) = ids_from_env()
     instance = client.instance(instance_id)
     database = instance.database(database_id)
 
-    logging.info("For {}:{}".format(instance_id, database_id))
+    logging.info(f"For {instance_id}:{database_id} {project_id}")
 
-    # Count bsos expired rows
+    # Count expired rows in the specified table
     with statsd.timer(f"syncstorage.count_expired_{table}_rows.duration"):
         with database.snapshot() as snapshot:
             result = snapshot.execute_sql(query)
diff --git a/tools/spanner/count_users.py b/tools/spanner/count_users.py
index 40fb0b16..db8771b7 100644
--- a/tools/spanner/count_users.py
+++ b/tools/spanner/count_users.py
@@ -13,6 +13,7 @@ from statsd.defaults.env import statsd
 from urllib import parse
 
 from google.cloud import spanner
+from utils import ids_from_env
 
 # set up logger
 logging.basicConfig(
@@ -23,31 +24,25 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
+def spanner_read_data() -> None:
+    """
+    Reads data from a Google Cloud Spanner database to count the number of distinct users.
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
+    This function connects to a Spanner instance and database using environment variables,
+    executes a SQL query to count the number of distinct `fxa_uid` entries in the `user_collections` table,
+    and logs the result. It also records the duration of the operation and the user count using statsd metrics.
+
+    Args:
+        None
+
-def spanner_read_data(request=None):
-    (instance_id, database_id) = from_env()
+    Returns:
+        None
+    """
+    (instance_id, database_id, project_id) = ids_from_env()
     instance = client.instance(instance_id)
     database = instance.database(database_id)
 
-    logging.info("For {}:{}".format(instance_id, database_id))
+    logging.info(f"For {instance_id}:{database_id} {project_id}")
 
     # Count users
     with statsd.timer("syncstorage.count_users.duration"):
@@ -56,7 +51,7 @@ def spanner_read_data(request=None):
         result = snapshot.execute_sql(query)
         user_count = result.one()[0]
         statsd.gauge("syncstorage.distinct_fxa_uid", user_count)
-        logging.info("Count found {} distinct users".format(user_count))
+        logging.info(f"Count found {user_count} distinct users")
 
 
 if __name__ == "__main__":
diff --git a/tools/spanner/purge_ttl.py b/tools/spanner/purge_ttl.py
index b3d0cf01..19d34f76 100644
--- a/tools/spanner/purge_ttl.py
+++ b/tools/spanner/purge_ttl.py
@@ -17,6 +17,8 @@ from google.cloud.spanner_v1.database import Database
 from google.cloud.spanner_v1 import param_types
 from statsd.defaults.env import statsd
 
+from utils import ids_from_env, Mode
+
 # set up logger
 logging.basicConfig(
@@ -26,23 +28,6 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
-
-def use_dsn(args):
-    try:
-        if not args.sync_database_url:
-            raise Exception("no url")
-        url = args.sync_database_url
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            args.instance_id = path[-3]
-            args.database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-    return args
-
-
 def deleter(database: Database,
             name: str,
             query: str,
@@ -50,17 +35,15 @@ def deleter(database: Database,
             params: Optional[dict]=None,
             param_types: Optional[dict]=None,
             dryrun: Optional[bool]=False):
-    with statsd.timer("syncstorage.purge_ttl.{}_duration".format(name)):
-        logging.info("Running: {} :: {}".format(query, params))
+    with statsd.timer(f"syncstorage.purge_ttl.{name}_duration"):
+        logging.info(f"Running: {query} :: {params}")
         start = datetime.now()
         result = 0
         if not dryrun:
             result = database.execute_partitioned_dml(query, params=params, param_types=param_types)
         end = datetime.now()
     logging.info(
-        "{name}: removed {result} rows, {name}_duration: {time}, prefix: {prefix}".format(
-            name=name, result=result, time=end - start, prefix=prefix))
-
+        f"{name}: removed {result} rows, {name}_duration: {end - start}, prefix: {prefix}")
 
 def add_conditions(args, query: str, prefix: Optional[str]):
     """
@@ -82,7 +65,7 @@ def add_conditions(args, query: str, prefix: Optional[str]):
         types['collection_id'] = param_types.INT64
     else:
         for count,id in enumerate(ids):
-            name = 'collection_id_{}'.format(count)
+            name = f'collection_id_{count}'
            params[name] = id
            types[name] = param_types.INT64
         query += " in (@{})".format(
@@ -105,28 +88,43 @@ def get_expiry_condition(args):
     elif args.expiry_mode == "midnight":
         return 'expiry < TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, "UTC")'
     else:
-        raise Exception("Invalid expiry mode: {}".format(args.expiry_mode))
+        raise Exception(f"Invalid expiry mode: {args.expiry_mode}")
 
 
-def spanner_purge(args):
+def spanner_purge(args) -> None:
+    """
+    Purges expired TTL records from Spanner based on the provided arguments.
+
+    This function connects to the specified Spanner instance and database,
+    determines the expiry condition, and deletes expired records from the
+    'batches' and/or 'bsos' tables according to the purge mode. Supports
+    filtering by collection IDs and UID prefixes, and can operate in dry-run mode.
+
+    Args:
+        args (argparse.Namespace): Parsed command-line arguments containing
+            Spanner connection info, purge options, and filters.
+
+    Returns:
+        None
+    """
     instance = client.instance(args.instance_id)
     database = instance.database(args.database_id)
     expiry_condition = get_expiry_condition(args)
     if args.auto_split:
         args.uid_prefixes = [
-            hex(i).lstrip("0x").zfill(args.auto_split) for i in range(
-            0, 16 ** args.auto_split)]
+            hex(i).lstrip("0x").zfill(args.auto_split) for i in range(
+                0, 16 ** args.auto_split)]
 
     prefixes = args.uid_prefixes if args.uid_prefixes else [None]
 
     for prefix in prefixes:
-        logging.info("For {}:{}, prefix = {}".format(args.instance_id, args.database_id, prefix))
+        logging.info(f"For {args.instance_id}:{args.database_id}, prefix = {prefix}")
 
         if args.mode in ["batches", "both"]:
             # Delete Batches. Also deletes child batch_bsos rows (INTERLEAVE
             # IN PARENT batches ON DELETE CASCADE)
             (batch_query, params, types) = add_conditions(
                 args,
-                'DELETE FROM batches WHERE {}'.format(expiry_condition),
+                f'DELETE FROM batches WHERE {expiry_condition}',
                 prefix,
             )
             deleter(
@@ -143,7 +141,7 @@ def spanner_purge(args):
             # Delete BSOs
             (bso_query, params, types) = add_conditions(
                 args,
-                'DELETE FROM bsos WHERE {}'.format(expiry_condition),
+                f'DELETE FROM bsos WHERE {expiry_condition}',
                 prefix
             )
             deleter(
@@ -158,6 +156,23 @@ def get_args():
+    """
+    Parses and returns command-line arguments for the Spanner TTL purge tool.
+    If a DSN URL is provided (via --sync_database_url or the SYNC_SYNCSTORAGE__DATABASE_URL environment variable), its components override the corresponding arguments.
+
+    Returns:
+        argparse.Namespace: Parsed command-line arguments with the following attributes:
+        - instance_id (str): Spanner instance ID (default from INSTANCE_ID env or 'spanner-test').
+        - database_id (str): Spanner database ID (default from DATABASE_ID env or 'sync_schema3').
+        - project_id (str): Google Cloud project ID (default from GOOGLE_CLOUD_PROJECT env or 'spanner-test').
+        - sync_database_url (str): Spanner database DSN (default from SYNC_SYNCSTORAGE__DATABASE_URL env).
+        - collection_ids (list): List of collection IDs to purge (default from COLLECTION_IDS env or empty list).
+        - uid_prefixes (list): List of UID prefixes to limit purges (default from PURGE_UID_PREFIXES env or empty list).
+        - auto_split (int): Number of hex digits used to auto-generate UID prefixes (default from PURGE_AUTO_SPLIT env).
+        - mode (str): Purge mode, one of 'batches', 'bsos', or 'both' (default from PURGE_MODE env or 'both').
+        - expiry_mode (str): Expiry mode, either 'now' or 'midnight' (default from PURGE_EXPIRY_MODE env or 'midnight').
+        - dryrun (bool): If True, do not actually purge records from Spanner.
+ """ parser = argparse.ArgumentParser( description="Purge old TTLs" ) @@ -173,6 +188,12 @@ def get_args(): default=os.environ.get("DATABASE_ID", "sync_schema3"), help="Spanner Database ID" ) + parser.add_argument( + "-p", + "--project_id", + default=os.environ.get("GOOGLE_CLOUD_PROJECT", "spanner-test"), + help="Spanner Project ID" + ) parser.add_argument( "-u", "--sync_database_url", @@ -225,17 +246,23 @@ def get_args(): # override using the DSN URL: if args.sync_database_url: - args = use_dsn(args) - + (instance_id, database_id, project_id) = ids_from_env(args.sync_database_url, mode=Mode.URL) + args.instance_id = instance_id + args.database_id = database_id + args.project_id = project_id return args def parse_args_list(args_list: str) -> List[str]: + """ - Parse a list of items (or a single string) into a list of strings. - Example input: [item1,item2,item3] - :param args_list: The list/string - :return: A list of strings + Parses a string representing a list of items into a list of strings. + + Args: + args_list (str): String to parse, e.g., "[item1,item2,item3]" or "item1". + + Returns: + List[str]: List of parsed string items. """ if args_list[0] != "[" or args_list[-1] != "]": # Assume it's a single item @@ -248,11 +275,10 @@ if __name__ == "__main__": args = get_args() with statsd.timer("syncstorage.purge_ttl.total_duration"): start_time = datetime.now() - logging.info('Starting purge_ttl.py') + logging.info("Starting purge_ttl.py") spanner_purge(args) end_time = datetime.now() duration = end_time - start_time - logging.info( - 'Completed purge_ttl.py, total_duration: {}'.format(duration)) + logging.info(f"Completed purge_ttl.py, total_duration: {duration}") diff --git a/tools/spanner/test_count_expired_rows.py b/tools/spanner/test_count_expired_rows.py new file mode 100644 index 00000000..bfde5f35 --- /dev/null +++ b/tools/spanner/test_count_expired_rows.py @@ -0,0 +1,47 @@ +import os +import types +from unittest.mock import MagicMock +import pytest +import logging + +import count_expired_rows +from utils import ids_from_env + +def test_spanner_read_data_counts_and_logs(monkeypatch, caplog): + # Prepare mocks + mock_instance = MagicMock() + mock_database = MagicMock() + mock_snapshot_ctx = MagicMock() + mock_snapshot = MagicMock() + mock_result = MagicMock() + mock_result.one.return_value = [42] + mock_snapshot.execute_sql.return_value = mock_result + mock_snapshot_ctx.__enter__.return_value = mock_snapshot + mock_database.snapshot.return_value = mock_snapshot_ctx + + # Patch spanner client and statsd + monkeypatch.setattr(count_expired_rows, "client", MagicMock()) + count_expired_rows.client.instance.return_value = mock_instance + mock_instance.database.return_value = mock_database + + mock_statsd = MagicMock() + monkeypatch.setattr(count_expired_rows, "statsd", mock_statsd) + mock_statsd.timer.return_value.__enter__.return_value = None + mock_statsd.timer.return_value.__exit__.return_value = None + + # Patch from_env to return fixed values + monkeypatch.setattr(count_expired_rows, "ids_from_env", lambda: ("inst", "db", "proj")) + + # Run function + with caplog.at_level(logging.INFO): + count_expired_rows.spanner_read_data("SELECT COUNT(*) FROM foo", "foo") + + # Check logs + assert any("For inst:db proj" in m for m in caplog.messages) + assert any("Found 42 expired rows in foo" in m for m in caplog.messages) + + # Check statsd calls + mock_statsd.gauge.assert_called_with("syncstorage.expired_foo_rows", 42) + 
+    mock_statsd.timer.assert_called_with("syncstorage.count_expired_foo_rows.duration")
+    mock_database.snapshot.assert_called_once()
+    mock_snapshot.execute_sql.assert_called_with("SELECT COUNT(*) FROM foo")
\ No newline at end of file
diff --git a/tools/spanner/test_purge_ttl.py b/tools/spanner/test_purge_ttl.py
new file mode 100644
index 00000000..d6f0bbb5
--- /dev/null
+++ b/tools/spanner/test_purge_ttl.py
@@ -0,0 +1,145 @@
+import pytest
+from unittest import mock
+from types import SimpleNamespace
+
+# Import the functions to test from purge_ttl.py
+import purge_ttl
+
+def test_parse_args_list_single_item():
+    assert purge_ttl.parse_args_list("foo") == ["foo"]
+
+def test_parse_args_list_multiple_items():
+    assert purge_ttl.parse_args_list("[a,b,c]") == ["a", "b", "c"]
+
+def test_get_expiry_condition_now():
+    args = SimpleNamespace(expiry_mode="now")
+    assert purge_ttl.get_expiry_condition(args) == 'expiry < CURRENT_TIMESTAMP()'
+
+def test_get_expiry_condition_midnight():
+    args = SimpleNamespace(expiry_mode="midnight")
+    assert purge_ttl.get_expiry_condition(args) == 'expiry < TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, "UTC")'
+
+def test_get_expiry_condition_invalid():
+    args = SimpleNamespace(expiry_mode="invalid")
+    with pytest.raises(Exception):
+        purge_ttl.get_expiry_condition(args)
+
+def test_add_conditions_no_collections_no_prefix():
+    args = SimpleNamespace(collection_ids=[], uid_prefixes=None)
+    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
+    assert query == "SELECT * FROM foo WHERE 1=1"
+    assert params == {}
+    assert types == {}
+
+def test_add_conditions_with_collections_single():
+    args = SimpleNamespace(collection_ids=["123"])
+    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
+    assert "collection_id = @collection_id" in query
+    assert params["collection_id"] == "123"
+    assert types["collection_id"] == purge_ttl.param_types.INT64
+
+def test_add_conditions_with_collections_multiple():
+    args = SimpleNamespace(collection_ids=["1", "2"])
+    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
+    assert "collection_id in" in query
+    assert params["collection_id_0"] == "1"
+    assert params["collection_id_1"] == "2"
+    assert types["collection_id_0"] == purge_ttl.param_types.INT64
+
+def test_add_conditions_with_prefix():
+    args = SimpleNamespace(collection_ids=[])
+    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", "abc")
+    assert "STARTS_WITH(fxa_uid, @prefix)" in query
+    assert params["prefix"] == "abc"
+    assert types["prefix"] == purge_ttl.param_types.STRING
+
+@mock.patch("purge_ttl.statsd")
+def test_deleter_dryrun(statsd_mock):
+    database = mock.Mock()
+    statsd_mock.timer.return_value.__enter__.return_value = None
+    statsd_mock.timer.return_value.__exit__.return_value = None
+    purge_ttl.deleter(database, "batches", "DELETE FROM batches", dryrun=True)
+    database.execute_partitioned_dml.assert_not_called()
+
+@mock.patch("purge_ttl.statsd")
+def test_deleter_executes(statsd_mock):
+    database = mock.Mock()
+    statsd_mock.timer.return_value.__enter__.return_value = None
+    statsd_mock.timer.return_value.__exit__.return_value = None
+    database.execute_partitioned_dml.return_value = 42
+
+    purge_ttl.deleter(database, "batches", "DELETE FROM batches", dryrun=False)
+    database.execute_partitioned_dml.assert_called_once()
+
+@mock.patch("purge_ttl.deleter")
+@mock.patch("purge_ttl.add_conditions")
+@mock.patch("purge_ttl.get_expiry_condition") +@mock.patch("purge_ttl.client") +def test_spanner_purge_both(client_mock, get_expiry_condition_mock, add_conditions_mock, deleter_mock): + # Setup + args = SimpleNamespace( + instance_id="inst", + database_id="db", + expiry_mode="now", + auto_split=None, + uid_prefixes=None, + mode="both", + dryrun=True, + collection_ids=[] + ) + + instance = mock.Mock() + database = mock.Mock() + client_mock.instance.return_value = instance + instance.database.return_value = database + + get_expiry_condition_mock.return_value = "expiry < CURRENT_TIMESTAMP()" + add_conditions_mock.side_effect = [ + ("batch_query", {"a": 1}, {"a": 2}), + ("bso_query", {"b": 3}, {"b": 4}), + ] + purge_ttl.spanner_purge(args) + + assert deleter_mock.call_count == 2 + deleter_mock.assert_called() + + deleter_mock.assert_any_call( + database, + name="batches", + query="batch_query", + params={"a": 1}, + param_types={"a": 2}, + prefix=None, + dryrun=True, + ) + + deleter_mock.assert_any_call( + database, + name="bso", + query="bso_query", + params={"b": 3}, + param_types={"b": 4}, + prefix=None, + dryrun=True, + ) + +@mock.patch("argparse.ArgumentParser.parse_args") +def test_get_args_env_and_dsn(parse_args_mock): + # Simulate args with DSN + args = SimpleNamespace( + instance_id="foo", + database_id="bar", + project_id="baz", + sync_database_url="spanner://projects/proj/instances/inst/databases/db", + collection_ids=[], + uid_prefixes=[], + auto_split=None, + mode="both", + expiry_mode="midnight", + dryrun=False, + ) + parse_args_mock.return_value = args + result = purge_ttl.get_args() + assert result.project_id == "proj" + assert result.instance_id == "inst" + assert result.database_id == "db" \ No newline at end of file diff --git a/tools/spanner/test_utils.py b/tools/spanner/test_utils.py new file mode 100644 index 00000000..77b47f74 --- /dev/null +++ b/tools/spanner/test_utils.py @@ -0,0 +1,46 @@ +import pytest + +import utils +from unittest.mock import MagicMock + +@pytest.fixture(autouse=True) +def reset_env(monkeypatch): + # Reset environment variables before each test + for var in [ + "SYNC_SYNCSTORAGE__DATABASE_URL", + "INSTANCE_ID", + "DATABASE_ID", + "GOOGLE_CLOUD_PROJECT" + ]: + monkeypatch.delenv(var, raising=False) + +def test_ids_from_env_parses_url(monkeypatch): + """Test with passed in DSN""" + monkeypatch.setenv("SYNC_SYNCSTORAGE__DATABASE_URL", "spanner://projects/proj/instances/inst/databases/db") + dsn = "SYNC_SYNCSTORAGE__DATABASE_URL" + instance_id, database_id, project_id = utils.ids_from_env(dsn) + assert project_id == "proj" + assert instance_id == "inst" + assert database_id == "db" + +def test_ids_from_env_with_missing_url(monkeypatch): + """Test ensures that default env vars set id values.""" + monkeypatch.setenv("INSTANCE_ID", "foo") + monkeypatch.setenv("DATABASE_ID", "bar") + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "baz") + instance_id, database_id, project_id = utils.ids_from_env() + assert instance_id == "foo" + assert database_id == "bar" + assert project_id == "baz" + + +def test_from_env_with_invalid_url(monkeypatch): + monkeypatch.setenv("SYNC_SYNCSTORAGE__DATABASE_URL", "notaspanner://foo") + monkeypatch.setenv("INSTANCE_ID", "default") + monkeypatch.setenv("DATABASE_ID", "default-db") + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "default-proj") + + instance_id, database_id, project_id = utils.ids_from_env() + assert instance_id == "default" + assert database_id == "default-db" + assert project_id == "default-proj" \ No newline at end of 
diff --git a/tools/spanner/utils.py b/tools/spanner/utils.py
new file mode 100644
index 00000000..eb77b947
--- /dev/null
+++ b/tools/spanner/utils.py
@@ -0,0 +1,67 @@
+# Utility Module for spanner CLI scripts
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+from enum import auto, Enum
+import os
+from urllib import parse
+from typing import Tuple
+
+DSN_URL = "SYNC_SYNCSTORAGE__DATABASE_URL"
+"""
+Environment variable that stores the Sync database URL.
+Depending on deployment, it can point to MySQL or Spanner;
+for these scripts it should always point to Spanner.
+"""
+
+class Mode(Enum):
+    URL = auto()
+    ENV_VAR = auto()
+
+def ids_from_env(dsn=DSN_URL, mode=Mode.ENV_VAR) -> Tuple[str, str, str]:
+    """
+    Extracts the instance, database, and project ids from a Spanner DSN URL,
+    read by default from the SYNC_SYNCSTORAGE__DATABASE_URL environment variable.
+    The defined defaults are in webservices-infra/sync and can be configured there for
+    production runs.
+
+    `dsn` argument is set to default to the `DSN_URL` constant.
+
+    For reference, an example spanner url passed in is in the following format:
+
+    `spanner://projects/moz-fx-sync-prod-xxxx/instances/sync/databases/syncdb`
+    database_id = `syncdb`, instance_id = `sync`, project_id = `moz-fx-sync-prod-xxxx`
+    """
+    # Change these to reflect your Spanner instance install
+    instance_id = None
+    database_id = None
+    project_id = None
+
+    try:
+        if mode == Mode.ENV_VAR:
+            url = os.environ.get(dsn)
+            if not url:
+                raise Exception(f"No env var found for provided DSN: {dsn}")
+        elif mode == Mode.URL:
+            url = dsn
+            if not url:
+                raise Exception(f"No valid url found: {url}")
+        parsed_url = parse.urlparse(url)
+        if parsed_url.scheme == "spanner":
+            path = parsed_url.path.split("/")
+            instance_id = path[-3]
+            project_id = path[-5]
+            database_id = path[-1]
+    except Exception as e:
+        print(f"Exception parsing url: {e}")
+    # Fallbacks if not set
+    if not instance_id:
+        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
+    if not database_id:
+        database_id = os.environ.get("DATABASE_ID", "sync_stage")
+    if not project_id:
+        project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "test-project")
+
+    return (instance_id, database_id, project_id)
diff --git a/tools/spanner/write_batch.py b/tools/spanner/write_batch.py
index 0a795a01..923e3c7e 100644
--- a/tools/spanner/write_batch.py
+++ b/tools/spanner/write_batch.py
@@ -21,6 +21,8 @@ from google.api_core.exceptions import AlreadyExists
 from google.cloud import spanner
 from google.cloud.spanner_v1 import param_types
 
+from utils import ids_from_env
+
 # max batch size for this write is 2000, otherwise we run into:
 """google.api_core.exceptions.InvalidArgument: 400 The transaction
@@ -79,12 +81,12 @@ PAYLOAD = ''.join(
 def load(instance, db, coll_id, name):
     fxa_uid = "DEADBEEF" + uuid.uuid4().hex[8:]
     fxa_kid = "{:013d}-{}".format(22, fxa_uid)
-    print("{} -> Loading {} {}".format(name, fxa_uid, fxa_kid))
+    print(f"{name} -> Loading {fxa_uid} {fxa_kid}")
     name = threading.current_thread().getName()
     spanner_client = spanner.Client()
     instance = spanner_client.instance(instance)
     db = instance.database(db)
-    print('{name} Db: {db}'.format(name=name, db=db))
+    print(f"{name} Db: {db}")
     start = datetime.now()
 
     def create_user(txn):
@@ -110,16 +112,14 @@ def load(instance, db, coll_id, name):
     try:
         db.run_in_transaction(create_user)
-        print('{name} Created user (fxa_uid: {uid}, fxa_kid: {kid})'.format(
-            name=name, uid=fxa_uid, kid=fxa_kid))
+        print(f"{name} Created user (fxa_uid: {fxa_uid}, fxa_kid: {fxa_kid})")
     except AlreadyExists:
-        print('{name} Existing user (fxa_uid: {uid}}, fxa_kid: {kid}})'.format(
-            name=name, uid=fxa_uid, kid=fxa_kid))
+        print(f"{name} Existing user (fxa_uid: {fxa_uid}, fxa_kid: {fxa_kid})")
 
     # approximately 1892 bytes
     rlen = 0
-    print('{name} Loading..'.format(name=name))
+    print(f"{name} Loading..")
     for j in range(BATCHES):
         records = []
         for i in range(BATCH_SIZE):
@@ -176,29 +176,11 @@ def load(instance, db, coll_id, name):
         ))
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
 def loader():
     # Prefix uaids for easy filtering later
     # Each loader thread gets it's own fake user to prevent some hotspot
     # issues.
-    (instance_id, database_id) = from_env()
+    (instance_id, database_id, _) = ids_from_env()
     # switching uid/kid to per load because of weird google trimming
     name = threading.current_thread().getName()
    load(instance_id, database_id, COLL_ID, name)
@@ -206,9 +188,9 @@ def loader():
 
 def main():
     for c in range(THREAD_COUNT):
-        print("Starting thread {}".format(c))
+        print(f"Starting thread {c}")
         t = threading.Thread(
-            name="loader_{}".format(c),
+            name=f"loader_{c}",
            target=loader)
        t.start()
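
A quick usage sketch of the new utils.ids_from_env helper in both of its modes; the spanner:// URL below is illustrative, not a real deployment value, and the behavior follows utils.py as introduced above:

    # Sketch: exercising utils.ids_from_env in both modes (illustrative values).
    import os

    from utils import ids_from_env, Mode

    EXAMPLE_URL = "spanner://projects/proj/instances/inst/databases/db"

    # Mode.ENV_VAR (the default): the argument names an environment
    # variable that holds the spanner:// URL.
    os.environ["SYNC_SYNCSTORAGE__DATABASE_URL"] = EXAMPLE_URL
    assert ids_from_env() == ("inst", "db", "proj")

    # Mode.URL: the argument is the URL itself, as purge_ttl.get_args()
    # does when --sync_database_url is given.
    assert ids_from_env(EXAMPLE_URL, mode=Mode.URL) == ("inst", "db", "proj")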
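For reference, the --auto_split expansion in spanner_purge yields one purge pass per fxa_uid hex prefix; a minimal illustration with auto_split=1 (zfill repairs the empty string that lstrip produces for i=0):

    # Sketch: the prefix list spanner_purge builds for --auto_split 1.
    auto_split = 1
    prefixes = [hex(i).lstrip("0x").zfill(auto_split)
                for i in range(0, 16 ** auto_split)]
    assert prefixes == list("0123456789abcdef")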