mirror of
https://github.com/mozilla-services/syncstorage-rs.git
synced 2025-08-06 03:46:57 +02:00
feat: spanner scripts parse gcp project (#1714)
Some checks failed
Glean probe-scraper / glean-probe-scraper (push) Has been cancelled
feat: spanner scripts parse gcp project
This commit is contained in:
parent 4ddf5b4169
commit d716ac5d10
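The change replaces each script's private `from_env()` helper (and purge_ttl's `use_dsn()`) with a shared `ids_from_env()` in `tools/spanner/utils.py`, which now also pulls the GCP project id out of the DSN. A minimal sketch of that positional parse, assuming an illustrative URL (real values come from the environment):

from urllib import parse

# Illustrative DSN; in practice it comes from SYNC_SYNCSTORAGE__DATABASE_URL.
url = "spanner://projects/my-project/instances/sync/databases/syncdb"
path = parse.urlparse(url).path.split("/")
# path == ['', 'my-project', 'instances', 'sync', 'databases', 'syncdb']
project_id, instance_id, database_id = path[-5], path[-3], path[-1]
print(project_id, instance_id, database_id)  # my-project sync syncdb

The project id sits two path segments left of the instance id, so the helper can fall back to GOOGLE_CLOUD_PROJECT when the URL is absent or not a spanner scheme.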
@@ -13,6 +13,7 @@ from statsd.defaults.env import statsd
 from urllib import parse
 
 from google.cloud import spanner
+from utils import ids_from_env
 
 # set up logger
 logging.basicConfig(
@@ -23,33 +24,24 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
-def spanner_read_data(query, table):
-    (instance_id, database_id) = from_env()
+def spanner_read_data(query: str, table: str) -> None:
+    """
+    Executes a query on the specified Spanner table to count expired rows,
+    logs the result, and sends metrics to statsd.
+
+    Args:
+        query (str): The SQL query to execute.
+        table (str): The name of the table being queried.
+
+    Returns:
+        None
+    """
+    (instance_id, database_id, project_id) = ids_from_env()
     instance = client.instance(instance_id)
     database = instance.database(database_id)
 
-    logging.info("For {}:{}".format(instance_id, database_id))
+    logging.info(f"For {instance_id}:{database_id} {project_id}")
 
-    # Count bsos expired rows
+    # Count expired rows in the specified table
     with statsd.timer(f"syncstorage.count_expired_{table}_rows.duration"):
         with database.snapshot() as snapshot:
             result = snapshot.execute_sql(query)
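For a sense of how the refactored counter is driven, here is a hypothetical invocation; the actual query string is built elsewhere in the script, so the SQL below is only an assumed shape:

# Hypothetical call; the real query lives outside this hunk.
query = "SELECT COUNT(*) FROM bsos WHERE expiry < CURRENT_TIMESTAMP()"
spanner_read_data(query, "bsos")
# times syncstorage.count_expired_bsos_rows.duration and logs the row count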
@@ -13,6 +13,8 @@ from statsd.defaults.env import statsd
 from urllib import parse
 
 from google.cloud import spanner
+from typing import Tuple
+from utils import ids_from_env
 
 # set up logger
 logging.basicConfig(
@@ -23,31 +25,26 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
-def spanner_read_data(request=None):
-    (instance_id, database_id) = from_env()
+def spanner_read_data() -> None:
+    """
+    Reads data from a Google Cloud Spanner database to count the number of distinct users.
+
+    This function connects to a Spanner instance and database using environment variables,
+    executes a SQL query to count the number of distinct `fxa_uid` entries in the
+    `user_collections` table, and logs the result. It also records the duration of the
+    operation and the user count using statsd metrics.
+
+    Args:
+        None
+
+    Returns:
+        None
+    """
+    (instance_id, database_id, project_id) = ids_from_env()
     instance = client.instance(instance_id)
     database = instance.database(database_id)
 
-    logging.info("For {}:{}".format(instance_id, database_id))
+    logging.info(f"For {instance_id}:{database_id} {project_id}")
 
     # Count users
     with statsd.timer("syncstorage.count_users.duration"):
@@ -56,7 +53,7 @@ def spanner_read_data(request=None):
             result = snapshot.execute_sql(query)
             user_count = result.one()[0]
             statsd.gauge("syncstorage.distinct_fxa_uid", user_count)
-            logging.info("Count found {} distinct users".format(user_count))
+            logging.info(f"Count found {user_count} distinct users")
 
 
 if __name__ == "__main__":
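The docstring pins down what the (unshown) query does; an assumed SQL shape consistent with it:

# Assumed query consistent with the docstring; the literal string is outside this hunk.
query = "SELECT COUNT(DISTINCT fxa_uid) FROM user_collections"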
@@ -17,6 +17,8 @@ from google.cloud.spanner_v1.database import Database
 from google.cloud.spanner_v1 import param_types
 from statsd.defaults.env import statsd
 
+from utils import ids_from_env, Mode
+
 # set up logger
 logging.basicConfig(
     format='{"datetime": "%(asctime)s", "message": "%(message)s"}',
@@ -26,23 +28,6 @@ logging.basicConfig(
 # Change these to match your install.
 client = spanner.Client()
 
 
-def use_dsn(args):
-    try:
-        if not args.sync_database_url:
-            raise Exception("no url")
-        url = args.sync_database_url
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            args.instance_id = path[-3]
-            args.database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-    return args
-
-
 def deleter(database: Database,
             name: str,
             query: str,
@@ -50,17 +35,15 @@ def deleter(database: Database,
             params: Optional[dict]=None,
             param_types: Optional[dict]=None,
             dryrun: Optional[bool]=False):
-    with statsd.timer("syncstorage.purge_ttl.{}_duration".format(name)):
-        logging.info("Running: {} :: {}".format(query, params))
+    with statsd.timer(f"syncstorage.purge_ttl.{name}_duration"):
+        logging.info(f"Running: {query} :: {params}")
         start = datetime.now()
         result = 0
         if not dryrun:
            result = database.execute_partitioned_dml(query, params=params, param_types=param_types)
         end = datetime.now()
         logging.info(
-            "{name}: removed {result} rows, {name}_duration: {time}, prefix: {prefix}".format(
-                name=name, result=result, time=end - start, prefix=prefix))
+            f"{name}: removed {result} rows, {name}_duration: {end - start}, prefix: {prefix}")
 
 
 def add_conditions(args, query: str, prefix: Optional[str]):
     """
@@ -82,7 +65,7 @@ def add_conditions(args, query: str, prefix: Optional[str]):
             types['collection_id'] = param_types.INT64
         else:
             for count,id in enumerate(ids):
-                name = 'collection_id_{}'.format(count)
+                name = f'collection_id_{count}'
                 params[name] = id
                 types[name] = param_types.INT64
             query += " in (@{})".format(
@@ -105,28 +88,43 @@ def get_expiry_condition(args):
     elif args.expiry_mode == "midnight":
         return 'expiry < TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, "UTC")'
     else:
-        raise Exception("Invalid expiry mode: {}".format(args.expiry_mode))
+        raise Exception(f"Invalid expiry mode: {args.expiry_mode}")
 
 
-def spanner_purge(args):
+def spanner_purge(args) -> None:
+    """
+    Purges expired TTL records from Spanner based on the provided arguments.
+
+    This function connects to the specified Spanner instance and database,
+    determines the expiry condition, and deletes expired records from the
+    'batches' and/or 'bsos' tables according to the purge mode. Supports
+    filtering by collection IDs and UID prefixes, and can operate in dry-run mode.
+
+    Args:
+        args (argparse.Namespace): Parsed command-line arguments containing
+            Spanner connection info, purge options, and filters.
+
+    Returns:
+        None
+    """
     instance = client.instance(args.instance_id)
     database = instance.database(args.database_id)
     expiry_condition = get_expiry_condition(args)
     if args.auto_split:
         args.uid_prefixes = [
-            hex(i).lstrip("0x").zfill(args.auto_split) for i in range(
-                0, 16 ** args.auto_split)]
+            hex(i).lstrip("0x").zfill(args.auto_split) for i in range(
+                0, 16 ** args.auto_split)]
     prefixes = args.uid_prefixes if args.uid_prefixes else [None]
 
     for prefix in prefixes:
-        logging.info("For {}:{}, prefix = {}".format(args.instance_id, args.database_id, prefix))
+        logging.info(f"For {args.instance_id}:{args.database_id}, prefix = {prefix}")
 
         if args.mode in ["batches", "both"]:
             # Delete Batches. Also deletes child batch_bsos rows (INTERLEAVE
             # IN PARENT batches ON DELETE CASCADE)
             (batch_query, params, types) = add_conditions(
                 args,
-                'DELETE FROM batches WHERE {}'.format(expiry_condition),
+                f'DELETE FROM batches WHERE {expiry_condition}',
                 prefix,
             )
             deleter(
@@ -143,7 +141,7 @@ def spanner_purge(args):
         # Delete BSOs
         (bso_query, params, types) = add_conditions(
             args,
-            'DELETE FROM bsos WHERE {}'.format(expiry_condition),
+            f'DELETE FROM bsos WHERE {expiry_condition}',
             prefix
         )
         deleter(
@@ -158,6 +156,23 @@ def spanner_purge(args):
 
 
 def get_args():
+    """
+    Parses and returns command-line arguments for the Spanner TTL purge tool.
+    If a DSN URL is provided (usually via `SYNC_SYNCSTORAGE__DATABASE_URL`), its
+    values override the corresponding arguments.
+
+    Returns:
+        argparse.Namespace: Parsed command-line arguments with the following attributes:
+            - instance_id (str): Spanner instance ID (default from INSTANCE_ID env or 'spanner-test').
+            - database_id (str): Spanner database ID (default from DATABASE_ID env or 'sync_schema3').
+            - project_id (str): Google Cloud project ID (default from GOOGLE_CLOUD_PROJECT env or 'spanner-test').
+            - sync_database_url (str): Spanner database DSN (default from SYNC_SYNCSTORAGE__DATABASE_URL env).
+            - collection_ids (list): List of collection IDs to purge (default from COLLECTION_IDS env or empty list).
+            - uid_prefixes (list): List of UID prefixes to limit purges (default from PURGE_UID_PREFIXES env or empty list).
+            - auto_split (int): Number of hex digits used to auto-generate UID prefixes (default from PURGE_AUTO_SPLIT env).
+            - mode (str): Purge mode, one of 'batches', 'bsos', or 'both' (default from PURGE_MODE env or 'both').
+            - expiry_mode (str): Expiry mode, either 'now' or 'midnight' (default from PURGE_EXPIRY_MODE env or 'midnight').
+            - dryrun (bool): If True, do not actually purge records from Spanner.
+    """
     parser = argparse.ArgumentParser(
         description="Purge old TTLs"
     )
@@ -173,6 +188,12 @@ def get_args():
         default=os.environ.get("DATABASE_ID", "sync_schema3"),
         help="Spanner Database ID"
     )
+    parser.add_argument(
+        "-p",
+        "--project_id",
+        default=os.environ.get("GOOGLE_CLOUD_PROJECT", "spanner-test"),
+        help="Spanner Project ID"
+    )
     parser.add_argument(
         "-u",
         "--sync_database_url",
@@ -225,17 +246,23 @@ def get_args():
 
     # override using the DSN URL:
     if args.sync_database_url:
-        args = use_dsn(args)
+        (instance_id, database_id, project_id) = ids_from_env(args.sync_database_url, mode=Mode.URL)
+        args.instance_id = instance_id
+        args.database_id = database_id
+        args.project_id = project_id
     return args
 
 
 def parse_args_list(args_list: str) -> List[str]:
     """
-    Parse a list of items (or a single string) into a list of strings.
-    Example input: [item1,item2,item3]
-    :param args_list: The list/string
-    :return: A list of strings
+    Parses a string representing a list of items into a list of strings.
+
+    Args:
+        args_list (str): String to parse, e.g., "[item1,item2,item3]" or "item1".
+
+    Returns:
+        List[str]: List of parsed string items.
     """
     if args_list[0] != "[" or args_list[-1] != "]":
         # Assume it's a single item
@@ -248,11 +275,10 @@ if __name__ == "__main__":
     args = get_args()
     with statsd.timer("syncstorage.purge_ttl.total_duration"):
         start_time = datetime.now()
-        logging.info('Starting purge_ttl.py')
+        logging.info("Starting purge_ttl.py")
 
         spanner_purge(args)
 
         end_time = datetime.now()
         duration = end_time - start_time
-        logging.info(
-            'Completed purge_ttl.py, total_duration: {}'.format(duration))
+        logging.info(f"Completed purge_ttl.py, total_duration: {duration}")
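Two behaviors in this script are easy to misread, so here is what they evaluate to (taken directly from the code above; the REPL transcript is illustrative):

>>> parse_args_list("[1,2,3]")     # bracketed string -> list of strings
['1', '2', '3']
>>> parse_args_list("foo")         # bare string -> single-item list
['foo']
>>> # auto_split=1 fans the purge out across every 1-hex-digit uid prefix
>>> [hex(i).lstrip("0x").zfill(1) for i in range(0, 16 ** 1)]
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']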
tools/spanner/test_count_expired_rows.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import os
import types
from unittest.mock import MagicMock
import pytest
import logging

import count_expired_rows
from utils import ids_from_env


def test_spanner_read_data_counts_and_logs(monkeypatch, caplog):
    # Prepare mocks
    mock_instance = MagicMock()
    mock_database = MagicMock()
    mock_snapshot_ctx = MagicMock()
    mock_snapshot = MagicMock()
    mock_result = MagicMock()
    mock_result.one.return_value = [42]
    mock_snapshot.execute_sql.return_value = mock_result
    mock_snapshot_ctx.__enter__.return_value = mock_snapshot
    mock_database.snapshot.return_value = mock_snapshot_ctx

    # Patch spanner client and statsd
    monkeypatch.setattr(count_expired_rows, "client", MagicMock())
    count_expired_rows.client.instance.return_value = mock_instance
    mock_instance.database.return_value = mock_database

    mock_statsd = MagicMock()
    monkeypatch.setattr(count_expired_rows, "statsd", mock_statsd)
    mock_statsd.timer.return_value.__enter__.return_value = None
    mock_statsd.timer.return_value.__exit__.return_value = None

    # Patch ids_from_env to return fixed values
    monkeypatch.setattr(count_expired_rows, "ids_from_env", lambda: ("inst", "db", "proj"))

    # Run function
    with caplog.at_level(logging.INFO):
        count_expired_rows.spanner_read_data("SELECT COUNT(*) FROM foo", "foo")

    # Check logs
    assert any("For inst:db proj" in m for m in caplog.messages)
    assert any("Found 42 expired rows in foo" in m for m in caplog.messages)

    # Check statsd calls
    mock_statsd.gauge.assert_called_with("syncstorage.expired_foo_rows", 42)
    mock_statsd.timer.assert_called_with("syncstorage.count_expired_foo_rows.duration")
    mock_database.snapshot.assert_called_once()
    mock_snapshot.execute_sql.assert_called_with("SELECT COUNT(*) FROM foo")
tools/spanner/test_purge_ttl.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import pytest
from unittest import mock
from types import SimpleNamespace

import sys

# Import the functions to test from purge_ttl.py
import purge_ttl


def test_parse_args_list_single_item():
    assert purge_ttl.parse_args_list("foo") == ["foo"]


def test_parse_args_list_multiple_items():
    assert purge_ttl.parse_args_list("[a,b,c]") == ["a", "b", "c"]


def test_get_expiry_condition_now():
    args = SimpleNamespace(expiry_mode="now")
    assert purge_ttl.get_expiry_condition(args) == 'expiry < CURRENT_TIMESTAMP()'


def test_get_expiry_condition_midnight():
    args = SimpleNamespace(expiry_mode="midnight")
    assert purge_ttl.get_expiry_condition(args) == 'expiry < TIMESTAMP_TRUNC(CURRENT_TIMESTAMP(), DAY, "UTC")'


def test_get_expiry_condition_invalid():
    args = SimpleNamespace(expiry_mode="invalid")
    with pytest.raises(Exception):
        purge_ttl.get_expiry_condition(args)


def test_add_conditions_no_collections_no_prefix():
    args = SimpleNamespace(collection_ids=[], uid_prefixes=None)
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
    assert query == "SELECT * FROM foo WHERE 1=1"
    assert params == {}
    assert types == {}


def test_add_conditions_with_collections_single():
    args = SimpleNamespace(collection_ids=["123"])
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
    assert "collection_id = @collection_id" in query
    assert params["collection_id"] == "123"
    assert types["collection_id"] == purge_ttl.param_types.INT64


def test_add_conditions_with_collections_multiple():
    args = SimpleNamespace(collection_ids=["1", "2"])
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", None)
    assert "collection_id in" in query
    assert params["collection_id_0"] == "1"
    assert params["collection_id_1"] == "2"
    assert types["collection_id_0"] == purge_ttl.param_types.INT64


def test_add_conditions_with_prefix():
    args = SimpleNamespace(collection_ids=[])
    query, params, types = purge_ttl.add_conditions(args, "SELECT * FROM foo WHERE 1=1", "abc")
    assert "STARTS_WITH(fxa_uid, @prefix)" in query
    assert params["prefix"] == "abc"
    assert types["prefix"] == purge_ttl.param_types.STRING


@mock.patch("purge_ttl.statsd")
def test_deleter_dryrun(statsd_mock):
    database = mock.Mock()
    statsd_mock.timer.return_value.__enter__.return_value = None
    statsd_mock.timer.return_value.__exit__.return_value = None
    purge_ttl.deleter(database, "batches", "DELETE FROM batches", dryrun=True)
    database.execute_partitioned_dml.assert_not_called()


@mock.patch("purge_ttl.statsd")
def test_deleter_executes(statsd_mock):
    database = mock.Mock()
    statsd_mock.timer.return_value.__enter__.return_value = None
    statsd_mock.timer.return_value.__exit__.return_value = None
    database.execute_partitioned_dml.return_value = 42

    purge_ttl.deleter(database, "batches", "DELETE FROM batches", dryrun=False)
    database.execute_partitioned_dml.assert_called_once()


@mock.patch("purge_ttl.deleter")
@mock.patch("purge_ttl.add_conditions")
@mock.patch("purge_ttl.get_expiry_condition")
@mock.patch("purge_ttl.client")
def test_spanner_purge_both(client_mock, get_expiry_condition_mock, add_conditions_mock, deleter_mock):
    # Setup
    args = SimpleNamespace(
        instance_id="inst",
        database_id="db",
        expiry_mode="now",
        auto_split=None,
        uid_prefixes=None,
        mode="both",
        dryrun=True,
        collection_ids=[]
    )

    instance = mock.Mock()
    database = mock.Mock()
    client_mock.instance.return_value = instance
    instance.database.return_value = database

    get_expiry_condition_mock.return_value = "expiry < CURRENT_TIMESTAMP()"
    add_conditions_mock.side_effect = [
        ("batch_query", {"a": 1}, {"a": 2}),
        ("bso_query", {"b": 3}, {"b": 4}),
    ]
    purge_ttl.spanner_purge(args)

    assert deleter_mock.call_count == 2
    deleter_mock.assert_called()

    deleter_mock.assert_any_call(
        database,
        name="batches",
        query="batch_query",
        params={"a": 1},
        param_types={"a": 2},
        prefix=None,
        dryrun=True,
    )

    deleter_mock.assert_any_call(
        database,
        name="bso",
        query="bso_query",
        params={"b": 3},
        param_types={"b": 4},
        prefix=None,
        dryrun=True,
    )


@mock.patch("argparse.ArgumentParser.parse_args")
def test_get_args_env_and_dsn(parse_args_mock):
    # Simulate args with DSN
    args = SimpleNamespace(
        instance_id="foo",
        database_id="bar",
        project_id="baz",
        sync_database_url="spanner://projects/proj/instances/inst/databases/db",
        collection_ids=[],
        uid_prefixes=[],
        auto_split=None,
        mode="both",
        expiry_mode="midnight",
        dryrun=False,
    )
    parse_args_mock.return_value = args
    result = purge_ttl.get_args()
    assert result.project_id == "proj"
    assert result.instance_id == "inst"
    assert result.database_id == "db"
tools/spanner/test_utils.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import pytest

import utils
from unittest.mock import MagicMock


@pytest.fixture(autouse=True)
def reset_env(monkeypatch):
    # Reset environment variables before each test
    for var in [
        "SYNC_SYNCSTORAGE__DATABASE_URL",
        "INSTANCE_ID",
        "DATABASE_ID",
        "GOOGLE_CLOUD_PROJECT"
    ]:
        monkeypatch.delenv(var, raising=False)


def test_ids_from_env_parses_url(monkeypatch):
    """Test with passed in DSN"""
    monkeypatch.setenv("SYNC_SYNCSTORAGE__DATABASE_URL", "spanner://projects/proj/instances/inst/databases/db")
    dsn = "SYNC_SYNCSTORAGE__DATABASE_URL"
    instance_id, database_id, project_id = utils.ids_from_env(dsn)
    assert project_id == "proj"
    assert instance_id == "inst"
    assert database_id == "db"


def test_ids_from_env_with_missing_url(monkeypatch):
    """Test ensures that default env vars set id values."""
    monkeypatch.setenv("INSTANCE_ID", "foo")
    monkeypatch.setenv("DATABASE_ID", "bar")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "baz")
    instance_id, database_id, project_id = utils.ids_from_env()
    assert instance_id == "foo"
    assert database_id == "bar"
    assert project_id == "baz"


def test_from_env_with_invalid_url(monkeypatch):
    monkeypatch.setenv("SYNC_SYNCSTORAGE__DATABASE_URL", "notaspanner://foo")
    monkeypatch.setenv("INSTANCE_ID", "default")
    monkeypatch.setenv("DATABASE_ID", "default-db")
    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "default-proj")

    instance_id, database_id, project_id = utils.ids_from_env()
    assert instance_id == "default"
    assert database_id == "default-db"
    assert project_id == "default-proj"
tools/spanner/utils.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Utility Module for spanner CLI scripts
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

from enum import auto, Enum
import os
from urllib import parse
from typing import Tuple
from unittest.mock import MagicMock

DSN_URL = "SYNC_SYNCSTORAGE__DATABASE_URL"
"""
Environment variable that stores the Sync database URL.
Depending on deployment, it can point at MySQL or Spanner;
for these scripts it should always point at Spanner.
"""


class Mode(Enum):
    URL = auto()
    ENV_VAR = auto()


def ids_from_env(dsn=DSN_URL, mode=Mode.ENV_VAR) -> Tuple[str, str, str]:
    """
    Extracts the instance, project, and database ids from the DSN url,
    which by default is read from the SYNC_SYNCSTORAGE__DATABASE_URL
    environment variable. The defaults are defined in webservices-infra/sync
    and can be configured there for production runs.

    The `dsn` argument defaults to the `DSN_URL` constant.

    For reference, an example spanner url is in the following format:

    `spanner://projects/moz-fx-sync-prod-xxxx/instances/sync/databases/syncdb`
    database_id = `syncdb`, instance_id = `sync`, project_id = `moz-fx-sync-prod-xxxx`
    """
    # Change these to reflect your Spanner instance install
    instance_id = None
    database_id = None
    project_id = None

    try:
        if mode == Mode.ENV_VAR:
            url = os.environ.get(dsn)
            if not url:
                raise Exception(f"No env var found for provided DSN: {dsn}")
        elif mode == Mode.URL:
            url = dsn
            if not url:
                raise Exception(f"No valid url found: {url}")
        parsed_url = parse.urlparse(url)
        if parsed_url.scheme == "spanner":
            path = parsed_url.path.split("/")
            instance_id = path[-3]
            project_id = path[-5]
            database_id = path[-1]
    except Exception as e:
        print(f"Exception parsing url: {e}")
    # Fallbacks if not set
    if not instance_id:
        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
    if not database_id:
        database_id = os.environ.get("DATABASE_ID", "sync_stage")
    if not project_id:
        project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "test-project")

    return (instance_id, database_id, project_id)
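Both call patterns used by the scripts above, shown side by side (the DSN value is illustrative):

from utils import ids_from_env, Mode

# Default: resolve the URL from the SYNC_SYNCSTORAGE__DATABASE_URL env var.
instance_id, database_id, project_id = ids_from_env()

# Explicit: hand the DSN string over directly, as purge_ttl.get_args() does.
dsn = "spanner://projects/my-project/instances/sync/databases/syncdb"  # illustrative
instance_id, database_id, project_id = ids_from_env(dsn, mode=Mode.URL)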
@@ -21,6 +21,8 @@ from google.api_core.exceptions import AlreadyExists
 from google.cloud import spanner
 from google.cloud.spanner_v1 import param_types
 
+from utils import ids_from_env
+
 
 # max batch size for this write is 2000, otherwise we run into:
 """google.api_core.exceptions.InvalidArgument: 400 The transaction
@@ -79,12 +81,12 @@ PAYLOAD = ''.join(
 def load(instance, db, coll_id, name):
     fxa_uid = "DEADBEEF" + uuid.uuid4().hex[8:]
     fxa_kid = "{:013d}-{}".format(22, fxa_uid)
-    print("{} -> Loading {} {}".format(name, fxa_uid, fxa_kid))
+    print(f"{name} -> Loading {fxa_uid} {fxa_kid}")
     name = threading.current_thread().getName()
     spanner_client = spanner.Client()
     instance = spanner_client.instance(instance)
     db = instance.database(db)
-    print('{name} Db: {db}'.format(name=name, db=db))
+    print(f"{name} Db: {db}")
     start = datetime.now()
 
     def create_user(txn):
@@ -110,16 +112,14 @@ def load(instance, db, coll_id, name):
 
     try:
         db.run_in_transaction(create_user)
-        print('{name} Created user (fxa_uid: {uid}, fxa_kid: {kid})'.format(
-            name=name, uid=fxa_uid, kid=fxa_kid))
+        print(f"{name} Created user (fxa_uid: {fxa_uid}, fxa_kid: {fxa_kid})")
     except AlreadyExists:
-        print('{name} Existing user (fxa_uid: {uid}}, fxa_kid: {kid}})'.format(
-            name=name, uid=fxa_uid, kid=fxa_kid))
+        print(f"{name} Existing user (fxa_uid: {fxa_uid}, fxa_kid: {fxa_kid})")
 
     # approximately 1892 bytes
     rlen = 0
 
-    print('{name} Loading..'.format(name=name))
+    print(f"{name} Loading..")
     for j in range(BATCHES):
         records = []
         for i in range(BATCH_SIZE):
@@ -176,29 +176,11 @@ def load(instance, db, coll_id, name):
         ))
 
 
-def from_env():
-    try:
-        url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL")
-        if not url:
-            raise Exception("no url")
-        purl = parse.urlparse(url)
-        if purl.scheme == "spanner":
-            path = purl.path.split("/")
-            instance_id = path[-3]
-            database_id = path[-1]
-    except Exception as e:
-        # Change these to reflect your Spanner instance install
-        print("Exception {}".format(e))
-        instance_id = os.environ.get("INSTANCE_ID", "spanner-test")
-        database_id = os.environ.get("DATABASE_ID", "sync_stage")
-    return (instance_id, database_id)
-
-
 def loader():
     # Prefix uaids for easy filtering later
     # Each loader thread gets its own fake user to prevent some hotspot
     # issues.
-    (instance_id, database_id) = from_env()
+    (instance_id, database_id, _) = ids_from_env()
     # switching uid/kid to per load because of weird google trimming
     name = threading.current_thread().getName()
     load(instance_id, database_id, COLL_ID, name)
@@ -206,9 +188,9 @@ def loader():
 
 def main():
     for c in range(THREAD_COUNT):
-        print("Starting thread {}".format(c))
+        print(f"Starting thread {c}")
         t = threading.Thread(
-            name="loader_{}".format(c),
+            name=f"loader_{c}",
            target=loader)
         t.start()
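The synthetic identifiers above follow a fixed shape, which keeps the generated load easy to spot and clean up later. A quick illustration of the construction used in load() (the printed values vary per run):

import uuid

# Recognizable prefix plus random hex, 32 characters total.
fxa_uid = "DEADBEEF" + uuid.uuid4().hex[8:]
fxa_kid = "{:013d}-{}".format(22, fxa_uid)
print(fxa_uid)  # e.g. DEADBEEF3f2a9c... (32 hex chars)
print(fxa_kid)  # e.g. 0000000000022-DEADBEEF3f2a9c...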