mirror of
https://github.com/mozilla-services/syncstorage-rs.git
synced 2025-08-06 11:56:58 +02:00
feat: spanner scripts parse gcp project (#1714)
Some checks failed
Glean probe-scraper / glean-probe-scraper (push) Has been cancelled
feat: spanner scripts parse gcp project
This commit is contained in:
parent 4ddf5b4169
commit d716ac5d10
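
This commit replaces the per-script `from_env()` helpers with a shared `ids_from_env()` in `tools/spanner/utils.py` that also pulls the GCP project id out of the Spanner DSN. A minimal sketch of the parsing idea, using an illustrative DSN (the full helper, with env-var fallbacks, appears in utils.py below):

```python
from urllib import parse

# Illustrative DSN of the form used by the sync scripts:
url = "spanner://projects/moz-fx-sync-prod-xxxx/instances/sync/databases/syncdb"

parsed = parse.urlparse(url)
if parsed.scheme == "spanner":
    # parsed.path == "/moz-fx-sync-prod-xxxx/instances/sync/databases/syncdb"
    path = parsed.path.split("/")
    project_id = path[-5]   # "moz-fx-sync-prod-xxxx"
    instance_id = path[-3]  # "sync"
    database_id = path[-1]  # "syncdb"
```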
tools/spanner/count_expired_rows.py:

@@ -13,6 +13,7 @@ from statsd.defaults.env import statsd
 from urllib import parse
 
 from google.cloud import spanner
+from utils import ids_from_env
 
 # set up logger
 logging.basicConfig(

@@ -23,33 +24,24 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
-def spanner_read_data(query, table):
-    (instance_id, database_id) = from_env()
+def spanner_read_data(query: str, table: str) -> None:
+    """
+    Executes a query on the specified Spanner table to count expired rows,
+    logs the result, and sends metrics to statsd.
+    Args:
+        query (str): The SQL query to execute.
+        table (str): The name of the table being queried.
+    Returns:
+        None
+    """
+    (instance_id, database_id, project_id) = ids_from_env()
     instance = client.instance(instance_id)
     database = instance.database(database_id)
 
-    logging.info("For {}:{}".format(instance_id, database_id))
+    logging.info(f"For {instance_id}:{database_id} {project_id}")
 
-    # Count bsos expired rows
+    # Count expired rows in the specified table
     with statsd.timer(f"syncstorage.count_expired_{table}_rows.duration"):
         with database.snapshot() as snapshot:
             result = snapshot.execute_sql(query)
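For context, a hypothetical call mirroring the unit test added later in this commit (the query and table name are illustrative):

```python
# Illustrative invocation; spanner_read_data logs the count and emits statsd metrics:
spanner_read_data("SELECT COUNT(*) FROM bsos WHERE expiry < CURRENT_TIMESTAMP()", "bsos")
```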
tools/spanner/count_users.py:

@@ -13,6 +13,8 @@ from statsd.defaults.env import statsd
 from urllib import parse
 
 from google.cloud import spanner
+from typing import Tuple
+from utils import ids_from_env
 
 # set up logger
 logging.basicConfig(

@@ -23,31 +25,25 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
-def spanner_read_data(request=None):
-    (instance_id, database_id) = from_env()
+def spanner_read_data() -> None:
+    """
+    Reads data from a Google Cloud Spanner database to count the number of distinct users.
+
+    This function connects to a Spanner instance and database using environment variables,
+    executes a SQL query to count the number of distinct `fxa_uid` entries in the `user_collections` table,
+    and logs the result. It also records the duration of the operation and the user count using statsd metrics.
+
+    Args:
+        None
+    Returns:
+        None
+    """
+    (instance_id, database_id, project_id) = ids_from_env()
     instance = client.instance(instance_id)
     database = instance.database(database_id)
 
-    logging.info("For {}:{}".format(instance_id, database_id))
+    logging.info(f"For {instance_id}:{database_id} {project_id}")
 
     # Count users
     with statsd.timer("syncstorage.count_users.duration"):

@@ -56,7 +52,7 @@ def spanner_read_data(request=None):
         result = snapshot.execute_sql(query)
         user_count = result.one()[0]
         statsd.gauge("syncstorage.distinct_fxa_uid", user_count)
-        logging.info("Count found {} distinct users".format(user_count))
+        logging.info(f"Count found {user_count} distinct users")
 
 
 if __name__ == "__main__":
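The counting query itself sits outside this hunk; per the docstring, it is presumably of this shape (illustrative):

```python
# Illustrative; the docstring describes counting distinct fxa_uid values
# in the user_collections table:
query = "SELECT COUNT(DISTINCT fxa_uid) FROM user_collections"
```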
tools/spanner/purge_ttl.py:

@@ -17,6 +17,8 @@ from google.cloud.spanner_v1.database import Database
 from google.cloud.spanner_v1 import param_types
 from statsd.defaults.env import statsd
 
+from utils import ids_from_env, Mode
+
 # set up logger
 logging.basicConfig(
     format='{"datetime": "%(asctime)s", "message": "%(message)s"}',
@@ -26,23 +28,6 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
 
-def use_dsn(args):
-    try:
-        if not args.sync_database_url:
-            raise Exception("no url")
-        url = args.sync_database_url
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            args.instance_id = path[-3]
-            args.database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-    return args
-
-
 def deleter(database: Database,
             name: str,
             query: str,
@@ -50,17 +35,15 @@ def deleter(database: Database,
             params: Optional[dict]=None,
             param_types: Optional[dict]=None,
             dryrun: Optional[bool]=False):
-    with statsd.timer("syncstorage.purge_ttl.{}_duration".format(name)):
-        logging.info("Running: {} :: {}".format(query, params))
+    with statsd.timer(f"syncstorage.purge_ttl.{name}_duration"):
+        logging.info(f"Running: {query} :: {params}")
         start = datetime.now()
         result = 0
         if not dryrun:
             result = database.execute_partitioned_dml(query, params=params, param_types=param_types)
         end = datetime.now()
         logging.info(
-            "{name}: removed {result} rows, {name}_duration: {time}, prefix: {prefix}".format(
-                name=name, result=result, time=end - start, prefix=prefix))
+            f"{name}: removed {result} rows, {name}_duration: {end - start}, prefix: {prefix}")
 
 
 def add_conditions(args, query: str, prefix: Optional[str]):
     """
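Here deleter() delegates to Spanner's partitioned DML, which runs the DELETE server side and returns the affected row count; dryrun=True logs and times the query without executing it. A minimal sketch of a call, mirroring the shapes exercised in test_purge_ttl.py below (the table and condition are illustrative):

```python
# Illustrative call; `database` is a google.cloud.spanner database handle.
deleter(
    database,
    name="batches",
    query="DELETE FROM batches WHERE expiry < CURRENT_TIMESTAMP()",
    prefix=None,  # no fxa_uid prefix filter
    dryrun=True,  # log and time the query, but skip execute_partitioned_dml
)
```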
@@ -82,7 +65,7 @@ def add_conditions(args, query: str, prefix: Optional[str]):
         types['collection_id'] = param_types.INT64
     else:
         for count,id in enumerate(ids):
-            name = 'collection_id_{}'.format(count)
+            name = f'collection_id_{count}'
             params[name] = id
             types[name] = param_types.INT64
         query += " in (@{})".format(
@@ -105,28 +88,43 @@ def get_expiry_condition(args):
     elif args.expiry_mode == "midnight":
         return 'expiry < TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, "UTC")'
     else:
-        raise Exception("Invalid expiry mode: {}".format(args.expiry_mode))
+        raise Exception(f"Invalid expiry mode: {args.expiry_mode}")
 
 
-def spanner_purge(args):
+def spanner_purge(args) -> None:
+    """
+    Purges expired TTL records from Spanner based on the provided arguments.
+
+    This function connects to the specified Spanner instance and database,
+    determines the expiry condition, and deletes expired records from the
+    'batches' and/or 'bsos' tables according to the purge mode. Supports
+    filtering by collection IDs and UID prefixes, and can operate in dry-run mode.
+
+    Args:
+        args (argparse.Namespace): Parsed command-line arguments containing
+            Spanner connection info, purge options, and filters.
+
+    Returns:
+        None
+    """
     instance = client.instance(args.instance_id)
     database = instance.database(args.database_id)
     expiry_condition = get_expiry_condition(args)
     if args.auto_split:
         args.uid_prefixes = [
             hex(i).lstrip("0x").zfill(args.auto_split) for i in range(
                 0, 16 ** args.auto_split)]
     prefixes = args.uid_prefixes if args.uid_prefixes else [None]
 
     for prefix in prefixes:
-        logging.info("For {}:{}, prefix = {}".format(args.instance_id, args.database_id, prefix))
+        logging.info(f"For {args.instance_id}:{args.database_id}, prefix = {prefix}")
 
         if args.mode in ["batches", "both"]:
             # Delete Batches. Also deletes child batch_bsos rows (INTERLEAVE
             # IN PARENT batches ON DELETE CASCADE)
             (batch_query, params, types) = add_conditions(
                 args,
-                'DELETE FROM batches WHERE {}'.format(expiry_condition),
+                f'DELETE FROM batches WHERE {expiry_condition}',
                 prefix,
             )
             deleter(
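The auto_split branch above generates every hexadecimal UID prefix of the requested width; note that hex(i).lstrip("0x") strips all leading '0' and 'x' characters, and zfill() then pads back to the fixed width. A worked example:

```python
auto_split = 2
prefixes = [hex(i).lstrip("0x").zfill(auto_split) for i in range(0, 16 ** auto_split)]
# hex(0) -> "0x0" -> "" -> "00"; hex(1) -> "0x1" -> "1" -> "01"; hex(255) -> "0xff" -> "ff"
# prefixes == ["00", "01", ..., "fe", "ff"], i.e. 256 two-character prefixes
```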
@@ -143,7 +141,7 @@ def spanner_purge(args):
         # Delete BSOs
         (bso_query, params, types) = add_conditions(
             args,
-            'DELETE FROM bsos WHERE {}'.format(expiry_condition),
+            f'DELETE FROM bsos WHERE {expiry_condition}',
             prefix
         )
         deleter(
@@ -158,6 +156,23 @@ def spanner_purge(args):
 
 
 def get_args():
+    """
+    Parses and returns command-line arguments for the Spanner TTL purge tool.
+    If a DSN URL is provided, usually `SYNC_SYNCSTORAGE__DATABASE_URL`, its values override the corresponding arguments.
+
+    Returns:
+        argparse.Namespace: Parsed command-line arguments with the following attributes:
+            - instance_id (str): Spanner instance ID (default from INSTANCE_ID env or 'spanner-test').
+            - database_id (str): Spanner database ID (default from DATABASE_ID env or 'sync_schema3').
+            - project_id (str): Google Cloud project ID (default from GOOGLE_CLOUD_PROJECT env or 'spanner-test').
+            - sync_database_url (str): Spanner database DSN (default from SYNC_SYNCSTORAGE__DATABASE_URL env).
+            - collection_ids (list): List of collection IDs to purge (default from COLLECTION_IDS env or empty list).
+            - uid_prefixes (list): List of UID prefixes to limit purges (default from PURGE_UID_PREFIXES env or empty list).
+            - auto_split (int): Number of digits to auto-generate UID prefixes (default from PURGE_AUTO_SPLIT env).
+            - mode (str): Purge mode, one of 'batches', 'bsos', or 'both' (default from PURGE_MODE env or 'both').
+            - expiry_mode (str): Expiry mode, either 'now' or 'midnight' (default from PURGE_EXPIRY_MODE env or 'midnight').
+            - dryrun (bool): If True, do not actually purge records from Spanner.
+    """
     parser = argparse.ArgumentParser(
         description="Purge old TTLs"
     )
@@ -173,6 +188,12 @@ def get_args():
         default=os.environ.get("DATABASE_ID", "sync_schema3"),
         help="Spanner Database ID"
     )
+    parser.add_argument(
+        "-p",
+        "--project_id",
+        default=os.environ.get("GOOGLE_CLOUD_PROJECT", "spanner-test"),
+        help="Spanner Project ID"
+    )
     parser.add_argument(
         "-u",
         "--sync_database_url",
@@ -225,17 +246,23 @@ def get_args():
 
     # override using the DSN URL:
     if args.sync_database_url:
-        args = use_dsn(args)
+        (instance_id, database_id, project_id) = ids_from_env(args.sync_database_url, mode=Mode.URL)
+        args.instance_id = instance_id
+        args.database_id = database_id
+        args.project_id = project_id
     return args
 
 
 def parse_args_list(args_list: str) -> List[str]:
     """
-    Parse a list of items (or a single string) into a list of strings.
-    Example input: [item1,item2,item3]
-    :param args_list: The list/string
-    :return: A list of strings
+    Parses a string representing a list of items into a list of strings.
+
+    Args:
+        args_list (str): String to parse, e.g., "[item1,item2,item3]" or "item1".
+
+    Returns:
+        List[str]: List of parsed string items.
     """
     if args_list[0] != "[" or args_list[-1] != "]":
         # Assume it's a single item
@@ -248,11 +275,10 @@ if __name__ == "__main__":
     args = get_args()
     with statsd.timer("syncstorage.purge_ttl.total_duration"):
         start_time = datetime.now()
-        logging.info('Starting purge_ttl.py')
+        logging.info("Starting purge_ttl.py")
 
         spanner_purge(args)
 
         end_time = datetime.now()
         duration = end_time - start_time
-        logging.info(
-            'Completed purge_ttl.py, total_duration: {}'.format(duration))
+        logging.info(f"Completed purge_ttl.py, total_duration: {duration}")
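The new override path delegates DSN parsing to the shared helper; a sketch of what it yields (the values match test_get_args_env_and_dsn below):

```python
from utils import ids_from_env, Mode

# Parsing a DSN string directly, as get_args now does for --sync_database_url:
instance_id, database_id, project_id = ids_from_env(
    "spanner://projects/proj/instances/inst/databases/db", mode=Mode.URL)
# -> instance_id == "inst", database_id == "db", project_id == "proj"
```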
tools/spanner/test_count_expired_rows.py (new file, 47 lines):

import os
import types
from unittest.mock import MagicMock
import pytest
import logging

import count_expired_rows
from utils import ids_from_env


def test_spanner_read_data_counts_and_logs(monkeypatch, caplog):
    # Prepare mocks
    mock_instance = MagicMock()
    mock_database = MagicMock()
    mock_snapshot_ctx = MagicMock()
    mock_snapshot = MagicMock()
    mock_result = MagicMock()
    mock_result.one.return_value = [42]
    mock_snapshot.execute_sql.return_value = mock_result
    mock_snapshot_ctx.__enter__.return_value = mock_snapshot
    mock_database.snapshot.return_value = mock_snapshot_ctx

    # Patch spanner client and statsd
    monkeypatch.setattr(count_expired_rows, "client", MagicMock())
    count_expired_rows.client.instance.return_value = mock_instance
    mock_instance.database.return_value = mock_database

    mock_statsd = MagicMock()
    monkeypatch.setattr(count_expired_rows, "statsd", mock_statsd)
    mock_statsd.timer.return_value.__enter__.return_value = None
    mock_statsd.timer.return_value.__exit__.return_value = None

    # Patch ids_from_env to return fixed values
    monkeypatch.setattr(count_expired_rows, "ids_from_env", lambda: ("inst", "db", "proj"))

    # Run function
    with caplog.at_level(logging.INFO):
        count_expired_rows.spanner_read_data("SELECT COUNT(*) FROM foo", "foo")

    # Check logs
    assert any("For inst:db proj" in m for m in caplog.messages)
    assert any("Found 42 expired rows in foo" in m for m in caplog.messages)

    # Check statsd calls
    mock_statsd.gauge.assert_called_with("syncstorage.expired_foo_rows", 42)
    mock_statsd.timer.assert_called_with("syncstorage.count_expired_foo_rows.duration")
    mock_database.snapshot.assert_called_once()
    mock_snapshot.execute_sql.assert_called_with("SELECT COUNT(*) FROM foo")
tools/spanner/test_purge_ttl.py (new file, 147 lines):

import pytest
from unittest import mock
from types import SimpleNamespace

import sys

# Import the functions to test from purge_ttl.py
import purge_ttl


def test_parse_args_list_single_item():
    assert purge_ttl.parse_args_list("foo") == ["foo"]


def test_parse_args_list_multiple_items():
    assert purge_ttl.parse_args_list("[a,b,c]") == ["a", "b", "c"]


def test_get_expiry_condition_now():
    args = SimpleNamespace(expiry_mode="now")
    assert purge_ttl.get_expiry_condition(args) == 'expiry < CURRENT_TIMESTAMP()'


def test_get_expiry_condition_midnight():
    args = SimpleNamespace(expiry_mode="midnight")
    assert purge_ttl.get_expiry_condition(args) == 'expiry < TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, "UTC")'


def test_get_expiry_condition_invalid():
    args = SimpleNamespace(expiry_mode="invalid")
    with pytest.raises(Exception):
        purge_ttl.get_expiry_condition(args)


def test_add_conditions_no_collections_no_prefix():
    args = SimpleNamespace(collection_ids=[], uid_prefixes=None)
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
    assert query == "SELECT * FROM foo WHERE 1=1"
    assert params == {}
    assert types == {}


def test_add_conditions_with_collections_single():
    args = SimpleNamespace(collection_ids=["123"])
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
    assert "collection_id = @collection_id" in query
    assert params["collection_id"] == "123"
    assert types["collection_id"] == purge_ttl.param_types.INT64


def test_add_conditions_with_collections_multiple():
    args = SimpleNamespace(collection_ids=["1", "2"])
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
    assert "collection_id in" in query
    assert params["collection_id_0"] == "1"
    assert params["collection_id_1"] == "2"
    assert types["collection_id_0"] == purge_ttl.param_types.INT64


def test_add_conditions_with_prefix():
    args = SimpleNamespace(collection_ids=[])
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", "abc")
    assert "STARTS_WITH(fxa_uid, @prefix)" in query
    assert params["prefix"] == "abc"
    assert types["prefix"] == purge_ttl.param_types.STRING


@mock.patch("purge_ttl.statsd")
def test_deleter_dryrun(statsd_mock):
    database = mock.Mock()
    statsd_mock.timer.return_value.__enter__.return_value = None
    statsd_mock.timer.return_value.__exit__.return_value = None
    purge_ttl.deleter(database, "batches", "DELETE FROM batches", dryrun=True)
    database.execute_partitioned_dml.assert_not_called()


@mock.patch("purge_ttl.statsd")
def test_deleter_executes(statsd_mock):
    database = mock.Mock()
    statsd_mock.timer.return_value.__enter__.return_value = None
    statsd_mock.timer.return_value.__exit__.return_value = None
    database.execute_partitioned_dml.return_value = 42

    purge_ttl.deleter(database, "batches", "DELETE FROM batches", dryrun=False)
    database.execute_partitioned_dml.assert_called_once()


@mock.patch("purge_ttl.deleter")
@mock.patch("purge_ttl.add_conditions")
@mock.patch("purge_ttl.get_expiry_condition")
@mock.patch("purge_ttl.client")
def test_spanner_purge_both(client_mock, get_expiry_condition_mock, add_conditions_mock, deleter_mock):
    # Setup
    args = SimpleNamespace(
        instance_id="inst",
        database_id="db",
        expiry_mode="now",
        auto_split=None,
        uid_prefixes=None,
        mode="both",
        dryrun=True,
        collection_ids=[]
    )

    instance = mock.Mock()
    database = mock.Mock()
    client_mock.instance.return_value = instance
    instance.database.return_value = database

    get_expiry_condition_mock.return_value = "expiry < CURRENT_TIMESTAMP()"
    add_conditions_mock.side_effect = [
        ("batch_query", {"a": 1}, {"a": 2}),
        ("bso_query", {"b": 3}, {"b": 4}),
    ]
    purge_ttl.spanner_purge(args)

    assert deleter_mock.call_count == 2
    deleter_mock.assert_called()

    deleter_mock.assert_any_call(
        database,
        name="batches",
        query="batch_query",
        params={"a": 1},
        param_types={"a": 2},
        prefix=None,
        dryrun=True,
    )

    deleter_mock.assert_any_call(
        database,
        name="bso",
        query="bso_query",
        params={"b": 3},
        param_types={"b": 4},
        prefix=None,
        dryrun=True,
    )


@mock.patch("argparse.ArgumentParser.parse_args")
def test_get_args_env_and_dsn(parse_args_mock):
    # Simulate args with DSN
    args = SimpleNamespace(
        instance_id="foo",
        database_id="bar",
        project_id="baz",
        sync_database_url="spanner://projects/proj/instances/inst/databases/db",
        collection_ids=[],
        uid_prefixes=[],
        auto_split=None,
        mode="both",
        expiry_mode="midnight",
        dryrun=False,
    )
    parse_args_mock.return_value = args
    result = purge_ttl.get_args()
    assert result.project_id == "proj"
    assert result.instance_id == "inst"
    assert result.database_id == "db"
tools/spanner/test_utils.py (new file, 46 lines):

import pytest

import utils
from unittest.mock import MagicMock


@pytest.fixture(autouse=True)
def reset_env(monkeypatch):
    # Reset environment variables before each test
    for var in [
        "SYNC_SYNCSTORAGE__DATABASE_URL",
        "INSTANCE_ID",
        "DATABASE_ID",
        "GOOGLE_CLOUD_PROJECT"
    ]:
        monkeypatch.delenv(var, raising=False)


def test_ids_from_env_parses_url(monkeypatch):
    """Test with passed in DSN"""
    monkeypatch.setenv("SYNC_SYNCSTORAGE__DATABASE_URL", "spanner://projects/proj/instances/inst/databases/db")
    dsn = "SYNC_SYNCSTORAGE__DATABASE_URL"
    instance_id, database_id, project_id = utils.ids_from_env(dsn)
    assert project_id == "proj"
    assert instance_id == "inst"
    assert database_id == "db"


def test_ids_from_env_with_missing_url(monkeypatch):
    """Test ensures that default env vars set id values."""
    monkeypatch.setenv("INSTANCE_ID", "foo")
    monkeypatch.setenv("DATABASE_ID", "bar")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "baz")
    instance_id, database_id, project_id = utils.ids_from_env()
    assert instance_id == "foo"
    assert database_id == "bar"
    assert project_id == "baz"


def test_from_env_with_invalid_url(monkeypatch):
    monkeypatch.setenv("SYNC_SYNCSTORAGE__DATABASE_URL", "notaspanner://foo")
    monkeypatch.setenv("INSTANCE_ID", "default")
    monkeypatch.setenv("DATABASE_ID", "default-db")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "default-proj")

    instance_id, database_id, project_id = utils.ids_from_env()
    assert instance_id == "default"
    assert database_id == "default-db"
    assert project_id == "default-proj"
tools/spanner/utils.py (new file, 68 lines):

# Utility Module for spanner CLI scripts
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

from enum import auto, Enum
import os
from urllib import parse
from typing import Tuple
from unittest.mock import MagicMock

DSN_URL = "SYNC_SYNCSTORAGE__DATABASE_URL"
"""
Environment variable that stores the Sync database URL.
Depending on deployment it can point to MySQL or Spanner;
in this context it should always point to Spanner for these scripts.
"""


class Mode(Enum):
    URL = auto()
    ENV_VAR = auto()


def ids_from_env(dsn=DSN_URL, mode=Mode.ENV_VAR) -> Tuple[str, str, str]:
    """
    Extracts the instance, project, and database ids from the DSN URL,
    which is defined by the SYNC_SYNCSTORAGE__DATABASE_URL environment variable.
    The defaults are defined in webservices-infra/sync and can be configured
    there for production runs.

    The `dsn` argument defaults to the `DSN_URL` constant.

    For reference, an example spanner URL has the following format:

    `spanner://projects/moz-fx-sync-prod-xxxx/instances/sync/databases/syncdb`
    database_id = `syncdb`, instance_id = `sync`, project_id = `moz-fx-sync-prod-xxxx`
    """
    # Change these to reflect your Spanner instance install
    instance_id = None
    database_id = None
    project_id = None

    try:
        if mode == Mode.ENV_VAR:
            url = os.environ.get(dsn)
            if not url:
                raise Exception(f"No env var found for provided DSN: {dsn}")
        elif mode == Mode.URL:
            url = dsn
            if not url:
                raise Exception(f"No valid url found: {url}")
        parsed_url = parse.urlparse(url)
        if parsed_url.scheme == "spanner":
            path = parsed_url.path.split("/")
            instance_id = path[-3]
            project_id = path[-5]
            database_id = path[-1]
    except Exception as e:
        print(f"Exception parsing url: {e}")
    # Fallbacks if not set
    if not instance_id:
        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
    if not database_id:
        database_id = os.environ.get("DATABASE_ID", "sync_stage")
    if not project_id:
        project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "test-project")

    return (instance_id, database_id, project_id)
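Typical usage from the scripts above is the default environment-variable mode; a short sketch:

```python
from utils import ids_from_env

# Default Mode.ENV_VAR: reads the URL from SYNC_SYNCSTORAGE__DATABASE_URL and
# falls back to INSTANCE_ID / DATABASE_ID / GOOGLE_CLOUD_PROJECT when unset.
instance_id, database_id, project_id = ids_from_env()
```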
tools/spanner/write_batch.py:

@@ -21,6 +21,8 @@ from google.api_core.exceptions import AlreadyExists
 from google.cloud import spanner
 from google.cloud.spanner_v1 import param_types
 
+from utils import ids_from_env
+
 
 # max batch size for this write is 2000, otherwise we run into:
 """google.api_core.exceptions.InvalidArgument: 400 The transaction

@@ -79,12 +81,12 @@ PAYLOAD = ''.join(
 def load(instance, db, coll_id, name):
     fxa_uid = "DEADBEEF" + uuid.uuid4().hex[8:]
     fxa_kid = "{:013d}-{}".format(22, fxa_uid)
-    print("{} -> Loading {} {}".format(name, fxa_uid, fxa_kid))
+    print(f"{name} -> Loading {fxa_uid} {fxa_kid}")
     name = threading.current_thread().getName()
     spanner_client = spanner.Client()
     instance = spanner_client.instance(instance)
     db = instance.database(db)
-    print('{name} Db: {db}'.format(name=name, db=db))
+    print(f"{name} Db: {db}")
     start = datetime.now()
 
     def create_user(txn):

@@ -110,16 +112,14 @@ def load(instance, db, coll_id, name):
 
     try:
         db.run_in_transaction(create_user)
-        print('{name} Created user (fxa_uid: {uid}, fxa_kid: {kid})'.format(
-            name=name, uid=fxa_uid, kid=fxa_kid))
+        print(f"{name} Created user (fxa_uid: {fxa_uid}, fxa_kid: {fxa_kid})")
     except AlreadyExists:
-        print('{name} Existing user (fxa_uid: {uid}}, fxa_kid: {kid}})'.format(
-            name=name, uid=fxa_uid, kid=fxa_kid))
+        print(f"{name} Existing user (fxa_uid: {fxa_uid}, fxa_kid: {fxa_kid})")
 
     # approximately 1892 bytes
     rlen = 0
 
-    print('{name} Loading..'.format(name=name))
+    print(f"{name} Loading..")
     for j in range(BATCHES):
         records = []
         for i in range(BATCH_SIZE):

@@ -176,29 +176,11 @@ def load(instance, db, coll_id, name):
     ))
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
 def loader():
     # Prefix uaids for easy filtering later
     # Each loader thread gets it's own fake user to prevent some hotspot
     # issues.
-    (instance_id, database_id) = from_env()
+    (instance_id, database_id, _) = ids_from_env()
     # switching uid/kid to per load because of weird google trimming
     name = threading.current_thread().getName()
     load(instance_id, database_id, COLL_ID, name)

@@ -206,9 +188,9 @@ def loader():
 
 def main():
     for c in range(THREAD_COUNT):
-        print("Starting thread {}".format(c))
+        print(f"Starting thread {c}")
         t = threading.Thread(
-            name="loader_{}".format(c),
+            name=f"loader_{c}",
             target=loader)
         t.start()