Barry Chen 3509177307
use parsed url method to replace the scheme
Co-authored-by: Philip Jenvey <pjenvey@underboss.org>
2025-11-20 06:23:29 -08:00

197 lines
5.8 KiB
Python

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
import argparse
import logging
import os
import sys
from datetime import datetime
from typing import List, Optional, Tuple
from urllib.parse import urlparse
import sqlalchemy
from sqlalchemy import text
from statsd.defaults.env import statsd
# Configure the root logger once at import time: every record goes to stdout
# at INFO level and above, wrapped in a JSON-shaped line with "datetime" and
# "message" keys. The message text is inserted verbatim (it is not
# JSON-escaped by this format string).
logging.basicConfig(
    format='{"datetime": "%(asctime)s", "message": "%(message)s"}',
    stream=sys.stdout,
    level=logging.INFO,
)
def get_db_engine(database_url: str):
    """Build a SQLAlchemy engine for a PostgreSQL database URL.

    Args:
        database_url: Connection URL whose scheme is ``postgresql`` or
            ``postgres``.

    Returns:
        The engine produced by ``sqlalchemy.create_engine`` for the
        (possibly scheme-normalized) URL.

    Raises:
        ValueError: If the URL scheme is neither ``postgresql`` nor
            ``postgres``.
    """
    url_parts = urlparse(database_url)
    scheme = url_parts.scheme
    if scheme == "postgres":
        # newer versions of SQLAlchemy want 'postgresql' instead of 'postgres'
        url_parts = url_parts._replace(scheme="postgresql")
    elif scheme != "postgresql":
        raise ValueError(f"Unsupported database scheme: {scheme}")
    return sqlalchemy.create_engine(url_parts.geturl())
def exec_delete(
    engine,
    name: str,
    query: str,
    params: Optional[dict] = None,
    dryrun: Optional[bool] = False,
):
    """Execute the DELETE query with the given query params.

    The whole call is timed under the statsd metric
    ``syncstorage.purge_ttl.{name}_duration``; the query, its parameters and
    the affected row count are logged.

    Args:
        engine: SQLAlchemy engine used to run the statement.
        name: Label used in the metric name and in the log line.
        query: SQL text using ``:name`` style bind parameters.
        params: Bind parameter values; ``None`` is treated as "no parameters".
        dryrun: When true, the query is logged but not executed and the
            reported row count is 0.
    """
    with statsd.timer(f"syncstorage.purge_ttl.{name}_duration"):
        logging.info(f"Running: {query} :: {params}")
        start = datetime.now()
        result = 0
        if not dryrun:
            # Run inside an explicit transaction: engine.begin() commits when
            # the block exits normally and rolls back if it raises. With a
            # bare engine.connect() SQLAlchemy 2.x does not autocommit, so the
            # DELETE would be rolled back when the connection is closed.
            with engine.begin() as conn:
                result_proxy = conn.execute(text(query), params or {})
                # Read rowcount while the connection is still open.
                result = result_proxy.rowcount
        end = datetime.now()
        logging.info(
            f"{name}: removed {result} rows, {name}_duration: {end - start}"
        )
def add_conditions(args, query: str) -> Tuple[str, dict]:
    """Append an optional ``collection_id`` filter to a query.

    Args:
        args: Object exposing ``collection_ids`` (an iterable of strings,
            or a falsy value for "no filter").
        query: SQL text that already ends in a WHERE condition.

    Returns:
        A ``(query, params)`` tuple. With no usable ids the query is
        returned unchanged with empty params; with one id an equality test
        bound to ``collection_id`` is appended; with several an ``IN`` list
        bound to ``collection_id_0``, ``collection_id_1``, ... is appended.
    """
    requested = args.collection_ids
    if not requested:
        return (query, {})

    # Zero-length entries (e.g. from "[1,,2]") carry no id and are dropped.
    wanted = [cid for cid in requested if len(cid)]
    if not wanted:
        return (query, {})

    filtered = query + " AND collection_id"
    if len(wanted) == 1:
        return (filtered + " = :collection_id", {"collection_id": wanted[0]})

    bind_values = {
        f"collection_id_{pos}": cid for pos, cid in enumerate(wanted)
    }
    placeholders = ", ".join([":" + key for key in bind_values])
    return (filtered + f" IN ({placeholders})", bind_values)
def get_expiry_condition(args) -> str:
    """Return the SQL condition that selects expired rows.

    Args:
        args: Object exposing ``expiry_mode``, either ``"now"`` (compare
            against the current timestamp) or ``"midnight"`` (compare
            against the current UTC timestamp truncated to the day).

    Returns:
        The SQL fragment to place after ``WHERE``.

    Raises:
        Exception: If ``expiry_mode`` is any other value.
    """
    mode = args.expiry_mode
    if mode == "now":
        return "expiry < CURRENT_TIMESTAMP"
    if mode == "midnight":
        return "expiry < DATE_TRUNC('day', CURRENT_TIMESTAMP AT TIME ZONE 'UTC')"
    raise Exception(f"Invalid expiry mode: {args.expiry_mode}")
def purge_records(args) -> None:
    """Delete expired rows from the tables selected by ``args.mode``.

    Builds one engine from ``args.database_url``, then for each selected
    table composes ``DELETE FROM <table> WHERE <expiry condition>`` plus the
    optional collection filter and hands it to ``exec_delete``.

    Args:
        args: Parsed CLI namespace providing ``database_url``, ``mode``
            (``"batches"``, ``"bsos"`` or ``"both"``), ``expiry_mode``,
            ``collection_ids`` and ``dryrun``.
    """
    db_engine = get_db_engine(args.database_url)
    expiry_condition = get_expiry_condition(args)

    # Order matters only in that it mirrors the original: batches, then bsos.
    statements = (
        ("batches", f"DELETE FROM batches WHERE {expiry_condition}"),
        ("bsos", f"DELETE FROM bsos WHERE {expiry_condition}"),
    )
    for table, base_query in statements:
        if args.mode not in (table, "both"):
            continue
        full_query, bind_values = add_conditions(args, base_query)
        exec_delete(
            db_engine,
            name=table,
            query=full_query,
            params=bind_values,
            dryrun=args.dryrun,
        )
def get_args():
    """Parse cli args.

    Options not given on the command line fall back to environment
    variables: ``SYNC_SYNCSTORAGE__DATABASE_URL``, ``COLLECTION_IDS``,
    ``PURGE_MODE`` and ``PURGE_EXPIRY_MODE``.

    Returns:
        The parsed ``argparse.Namespace`` with ``database_url``,
        ``collection_ids``, ``mode``, ``expiry_mode`` and ``dryrun``.

    Exits through ``parser.error`` (usage message, ``SystemExit``) when no
    database URL is available or when ``mode`` / ``expiry_mode`` is not one
    of the allowed values.
    """
    mode_choices = ["batches", "bsos", "both"]
    expiry_mode_choices = ["now", "midnight"]

    parser = argparse.ArgumentParser(description="Purge expired records from the database")
    parser.add_argument(
        "-u",
        "--database_url",
        default=os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL"),
        required=False,
        help="Database URL (postgresql://... or postgres://...)",
    )
    parser.add_argument(
        "--collection_ids",
        "--ids",
        type=parse_args_list,
        default=os.environ.get("COLLECTION_IDS", "[]"),
        help="Array of collection IDs to purge",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=mode_choices,
        default=os.environ.get("PURGE_MODE", "both"),
        help="Purge batches, bsos, or both",
    )
    parser.add_argument(
        "--expiry_mode",
        type=str,
        choices=expiry_mode_choices,
        default=os.environ.get("PURGE_EXPIRY_MODE", "midnight"),
        help="Choose the timestamp used to check if an entry is expired",
    )
    parser.add_argument(
        "--dryrun",
        action="store_true",
        help="A dry run instead of purging data",
    )
    args = parser.parse_args()
    if not args.database_url:
        parser.error("--database_url is required, or set SYNC_SYNCSTORAGE__DATABASE_URL")
    # argparse checks `choices` only for values supplied on the command line;
    # a default taken from the environment is accepted as-is. Validate here so
    # a mistyped PURGE_MODE / PURGE_EXPIRY_MODE fails fast with a usage error
    # instead of being passed on to the purge unchecked.
    if args.mode not in mode_choices:
        parser.error(
            f"invalid mode {args.mode!r} (choose from {', '.join(mode_choices)}); "
            "check --mode or PURGE_MODE"
        )
    if args.expiry_mode not in expiry_mode_choices:
        parser.error(
            f"invalid expiry mode {args.expiry_mode!r} "
            f"(choose from {', '.join(expiry_mode_choices)}); "
            "check --expiry_mode or PURGE_EXPIRY_MODE"
        )
    return args
def parse_args_list(args_list: str) -> List[str]:
    """Parse a string representing a list of items into a list of strings.

    Surrounding whitespace is ignored, both around the whole value and
    around each item, so ``"[1, 2]"`` and ``"[1,2]"`` parse the same way.

    Args:
        args_list (str): String to parse, e.g., "[item1,item2,item3]" or
            "item1".

    Returns:
        List[str]: List of parsed string items. An empty or blank value, or
        "[]", yields an empty list; a value not wrapped in brackets is
        returned as a single-item list.
    """
    cleaned = args_list.strip() if args_list else ""
    if not cleaned or cleaned == "[]":
        return []
    if not (cleaned.startswith("[") and cleaned.endswith("]")):
        # Assume it's a single item
        return [cleaned]
    # Strip each item so "[1, 2]" does not produce " 2" with a leading space.
    return [item.strip() for item in cleaned[1:-1].split(",")]
if __name__ == "__main__":
    # Resolve CLI/env configuration first; a usage error exits before any
    # timing or database work starts.
    args = get_args()
    # Time the full purge run as one statsd metric and also log the elapsed
    # wall-clock time.
    with statsd.timer("syncstorage.purge_ttl.total_duration"):
        start_time = datetime.now()
        logging.info("Starting purge_ttl.py")
        purge_records(args)
        duration = datetime.now() - start_time
        logging.info(f"Completed purge_ttl.py, total_duration: {duration}")