diff --git a/tools/integration_tests/conftest.py b/tools/integration_tests/conftest.py index c47e1773..7ceaeb71 100644 --- a/tools/integration_tests/conftest.py +++ b/tools/integration_tests/conftest.py @@ -1,99 +1,29 @@ -"""Pytest configuration and fixtures for integration tests.""" +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +"""Pytest fixtures for storage integration tests. + +Fixture hierarchy +───────────────── +st_ctx — function-scoped composite: sets up Pyramid configurator, creates a + hawk-signed TestApp, seeds a random user, clears that user's data, + and yields a plain dict consumed by test functions. + +Helper functions and constants live in helpers.py. +""" -import contextlib -import logging import os -import random -import time import uuid -import hawkauthlib import pytest -import webtest -from pyramid.interfaces import IAuthenticationPolicy -from pyramid.request import Request -from webtest import TestApp +from tools.integration_tests.helpers import ( + make_auth_state, + make_test_app, + retry_delete, +) from tools.integration_tests.test_support import get_test_configurator -# max number of attempts to check server heartbeat -SYNC_SERVER_STARTUP_MAX_ATTEMPTS = 35 -SYNC_SERVER_URL = os.environ.get("SYNC_SERVER_URL", "http://localhost:8000") - -logger = logging.getLogger("tools.integration-tests") - -if os.environ.get("SYNC_TEST_LOG_HTTP"): - _orig_do_request = webtest.TestApp.do_request - - def _logged_do_request(self, req, *args, **kwargs): - """Wrap request and response logging around original do_request.""" - logger.info(">> %s %s", req.method, req.url) - if req.body: - logger.info(">> BODY: %s", req.body) - resp = _orig_do_request(self, req, *args, **kwargs) - logger.info("<< %s", resp.status) - logger.info("<< BODY: %s", resp.body) - return resp - - webtest.TestApp.do_request = _logged_do_request - - -def _retry_send(func, *args, **kwargs): - """Call a webtest method, retrying once on 409/503.""" - try: - return func(*args, **kwargs) - except webtest.AppError as ex: - if "409 " not in ex.args[0] and "503 " not in ex.args[0]: - raise - time.sleep(0.01) - return func(*args, **kwargs) - - -def retry_post_json(app, *args, **kwargs): - """POST JSON with retry on transient errors.""" - return _retry_send(app.post_json, *args, **kwargs) - - -def retry_put_json(app, *args, **kwargs): - """PUT JSON with retry on transient errors.""" - return _retry_send(app.put_json, *args, **kwargs) - - -def retry_delete(app, *args, **kwargs): - """DELETE with retry on transient errors.""" - return _retry_send(app.delete, *args, **kwargs) - - -def _make_auth_state(config, host_url): - """Generate hawk credentials for a new random user.""" - global_secret = os.environ.get("SYNC_MASTER_SECRET") - policy = config.registry.getUtility(IAuthenticationPolicy) - if global_secret is not None: - policy.secrets._secrets = [global_secret] - user_id = random.randint(1, 100000) - fxa_uid = "DECAFBAD" + str(uuid.uuid4().hex)[8:] - hashed_fxa_uid = str(uuid.uuid4().hex) - fxa_kid = "0000000000000-DECAFBAD" + str(uuid.uuid4().hex)[8:] - req = Request.blank(host_url) - creds = policy.encode_hawk_id( - req, - user_id, - extra={ - "hashed_fxa_uid": hashed_fxa_uid, - "fxa_uid": fxa_uid, - "fxa_kid": fxa_kid, - }, - ) - auth_token, auth_secret = creds - return { - "user_id": user_id, - "fxa_uid": fxa_uid, - "hashed_fxa_uid": hashed_fxa_uid, - "fxa_kid": fxa_kid, - "auth_token": auth_token, - "auth_secret": auth_secret, - } - @pytest.fixture(scope="function") def st_ctx(): @@ -113,49 +43,26 @@ def st_ctx(): ondisk = "sqlite:////tmp/tests-sync-%s.db" % os.environ["MOZSVC_UUID"] os.environ["MOZSVC_ONDISK_SQLURI"] = ondisk - # Locate tests.ini relative to test_storage.py + # Locate tests.ini relative to this file this_dir = os.path.dirname(os.path.abspath(__file__)) config = get_test_configurator(this_dir, ini_file) config.commit() config.make_wsgi_app() host_url = os.environ.get("SYNC_SERVER_URL", "http://localhost:8000") - import urllib.parse as urlparse - host_parts = urlparse.urlparse(host_url) - app = TestApp( - host_url, - extra_environ={ - "HTTP_HOST": host_parts.netloc, - "wsgi.url_scheme": host_parts.scheme or "http", - "SERVER_NAME": host_parts.hostname, - "REMOTE_ADDR": "127.0.0.1", - "SCRIPT_NAME": host_parts.path, - }, - ) - - # Mutable auth state — shared with the do_request closure so that - # switch_user() and the expired-token test can swap credentials at runtime. - auth = _make_auth_state(config, host_url) + auth = make_auth_state(config, host_url) auth_state = { "auth_token": auth["auth_token"], "auth_secret": auth["auth_secret"], } - orig_do_request = app.do_request - - def new_do_request(req, *args, **kwds): - hawkauthlib.sign_request( - req, auth_state["auth_token"], auth_state["auth_secret"] - ) - return orig_do_request(req, *args, **kwds) - - app.do_request = new_do_request + app = make_test_app(host_url, auth_state) root = "/1.5/%d" % auth["user_id"] retry_delete(app, root) - ctx = { + yield { "app": app, "root": root, "user_id": auth["user_id"], @@ -167,56 +74,5 @@ def st_ctx(): "host_url": host_url, } - yield ctx - config.end() del os.environ["MOZSVC_UUID"] - - -@contextlib.contextmanager -def switch_user(st_ctx): - """Context manager: temporarily switch to a fresh random user. - - Updates both st_ctx and the auth_state dict (shared with the - do_request closure) for the duration of the block, then restores - the original user on exit. - """ - orig_root = st_ctx["root"] - orig_user_id = st_ctx["user_id"] - orig_fxa_uid = st_ctx["fxa_uid"] - orig_hashed_fxa_uid = st_ctx["hashed_fxa_uid"] - orig_fxa_kid = st_ctx["fxa_kid"] - orig_auth_token = st_ctx["auth_state"]["auth_token"] - orig_auth_secret = st_ctx["auth_state"]["auth_secret"] - - config = st_ctx["config"] - host_url = st_ctx["host_url"] - app = st_ctx["app"] - - for _ in range(10): - new_auth = _make_auth_state(config, host_url) - if new_auth["user_id"] != orig_user_id: - break - else: - raise RuntimeError("Failed to switch to new user id") - - st_ctx["auth_state"]["auth_token"] = new_auth["auth_token"] - st_ctx["auth_state"]["auth_secret"] = new_auth["auth_secret"] - st_ctx["user_id"] = new_auth["user_id"] - st_ctx["fxa_uid"] = new_auth["fxa_uid"] - st_ctx["hashed_fxa_uid"] = new_auth["hashed_fxa_uid"] - st_ctx["fxa_kid"] = new_auth["fxa_kid"] - new_root = "/1.5/%d" % new_auth["user_id"] - st_ctx["root"] = new_root - retry_delete(app, new_root) - - try: - yield - finally: - st_ctx["auth_state"]["auth_token"] = orig_auth_token - st_ctx["auth_state"]["auth_secret"] = orig_auth_secret - st_ctx["user_id"] = orig_user_id - st_ctx["fxa_uid"] = orig_fxa_uid - st_ctx["hashed_fxa_uid"] = orig_hashed_fxa_uid - st_ctx["fxa_kid"] = orig_fxa_kid - st_ctx["root"] = orig_root diff --git a/tools/integration_tests/helpers.py b/tools/integration_tests/helpers.py new file mode 100644 index 00000000..c777e23c --- /dev/null +++ b/tools/integration_tests/helpers.py @@ -0,0 +1,181 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +"""Helper functions and constants for storage integration tests. + +These are plain module-level utilities — no pytest fixtures here. +Test files import directly from this module; conftest.py imports +whatever it needs to build fixtures. +""" + +import contextlib +import logging +import os +import random +import time +import uuid + +import hawkauthlib +import webtest +from pyramid.interfaces import IAuthenticationPolicy +from pyramid.request import Request +from webtest import TestApp + +# max number of attempts to check server heartbeat +SYNC_SERVER_STARTUP_MAX_ATTEMPTS = 35 +SYNC_SERVER_URL = os.environ.get("SYNC_SERVER_URL", "http://localhost:8000") + +logger = logging.getLogger("tools.integration-tests") + +if os.environ.get("SYNC_TEST_LOG_HTTP"): + _orig_do_request = webtest.TestApp.do_request + + def _logged_do_request(self, req, *args, **kwargs): + """Wrap request and response logging around original do_request.""" + logger.info(">> %s %s", req.method, req.url) + if req.body: + logger.info(">> BODY: %s", req.body) + resp = _orig_do_request(self, req, *args, **kwargs) + logger.info("<< %s", resp.status) + logger.info("<< BODY: %s", resp.body) + return resp + + webtest.TestApp.do_request = _logged_do_request + + +def _retry_send(func, *args, **kwargs): + """Call a webtest method, retrying once on 409/503.""" + try: + return func(*args, **kwargs) + except webtest.AppError as ex: + if "409 " not in ex.args[0] and "503 " not in ex.args[0]: + raise + time.sleep(0.01) + return func(*args, **kwargs) + + +def retry_post_json(app, *args, **kwargs): + """POST JSON with retry on transient errors.""" + return _retry_send(app.post_json, *args, **kwargs) + + +def retry_put_json(app, *args, **kwargs): + """PUT JSON with retry on transient errors.""" + return _retry_send(app.put_json, *args, **kwargs) + + +def retry_delete(app, *args, **kwargs): + """DELETE with retry on transient errors.""" + return _retry_send(app.delete, *args, **kwargs) + + +def make_auth_state(config, host_url): + """Generate hawk credentials for a new random user.""" + global_secret = os.environ.get("SYNC_MASTER_SECRET") + policy = config.registry.getUtility(IAuthenticationPolicy) + if global_secret is not None: + policy.secrets._secrets = [global_secret] + user_id = random.randint(1, 100000) + fxa_uid = "DECAFBAD" + str(uuid.uuid4().hex)[8:] + hashed_fxa_uid = str(uuid.uuid4().hex) + fxa_kid = "0000000000000-DECAFBAD" + str(uuid.uuid4().hex)[8:] + req = Request.blank(host_url) + creds = policy.encode_hawk_id( + req, + user_id, + extra={ + "hashed_fxa_uid": hashed_fxa_uid, + "fxa_uid": fxa_uid, + "fxa_kid": fxa_kid, + }, + ) + auth_token, auth_secret = creds + return { + "user_id": user_id, + "fxa_uid": fxa_uid, + "hashed_fxa_uid": hashed_fxa_uid, + "fxa_kid": fxa_kid, + "auth_token": auth_token, + "auth_secret": auth_secret, + } + + +def make_test_app(host_url, auth_state): + """Build a hawk-signed WebTest TestApp for the given host URL. + + Returns ``(app, root)`` where *root* is the ``/1.5/`` prefix + for the authenticated user embedded in *auth_state*. + """ + import urllib.parse as urlparse + + host_parts = urlparse.urlparse(host_url) + app = TestApp( + host_url, + extra_environ={ + "HTTP_HOST": host_parts.netloc, + "wsgi.url_scheme": host_parts.scheme or "http", + "SERVER_NAME": host_parts.hostname, + "REMOTE_ADDR": "127.0.0.1", + "SCRIPT_NAME": host_parts.path, + }, + ) + + orig_do_request = app.do_request + + def new_do_request(req, *args, **kwds): + hawkauthlib.sign_request( + req, auth_state["auth_token"], auth_state["auth_secret"] + ) + return orig_do_request(req, *args, **kwds) + + app.do_request = new_do_request + return app + + +@contextlib.contextmanager +def switch_user(st_ctx): + """Context manager: temporarily switch to a fresh random user. + + Updates both st_ctx and the auth_state dict (shared with the + do_request closure) for the duration of the block, then restores + the original user on exit. + """ + orig_root = st_ctx["root"] + orig_user_id = st_ctx["user_id"] + orig_fxa_uid = st_ctx["fxa_uid"] + orig_hashed_fxa_uid = st_ctx["hashed_fxa_uid"] + orig_fxa_kid = st_ctx["fxa_kid"] + orig_auth_token = st_ctx["auth_state"]["auth_token"] + orig_auth_secret = st_ctx["auth_state"]["auth_secret"] + + config = st_ctx["config"] + host_url = st_ctx["host_url"] + app = st_ctx["app"] + + for _ in range(10): + new_auth = make_auth_state(config, host_url) + if new_auth["user_id"] != orig_user_id: + break + else: + raise RuntimeError("Failed to switch to new user id") + + st_ctx["auth_state"]["auth_token"] = new_auth["auth_token"] + st_ctx["auth_state"]["auth_secret"] = new_auth["auth_secret"] + st_ctx["user_id"] = new_auth["user_id"] + st_ctx["fxa_uid"] = new_auth["fxa_uid"] + st_ctx["hashed_fxa_uid"] = new_auth["hashed_fxa_uid"] + st_ctx["fxa_kid"] = new_auth["fxa_kid"] + new_root = "/1.5/%d" % new_auth["user_id"] + st_ctx["root"] = new_root + retry_delete(app, new_root) + + try: + yield + finally: + st_ctx["auth_state"]["auth_token"] = orig_auth_token + st_ctx["auth_state"]["auth_secret"] = orig_auth_secret + st_ctx["user_id"] = orig_user_id + st_ctx["fxa_uid"] = orig_fxa_uid + st_ctx["hashed_fxa_uid"] = orig_hashed_fxa_uid + st_ctx["fxa_kid"] = orig_fxa_kid + st_ctx["root"] = orig_root diff --git a/tools/integration_tests/test_storage.py b/tools/integration_tests/test_storage.py index cf6f2233..0fb1a4ed 100644 --- a/tools/integration_tests/test_storage.py +++ b/tools/integration_tests/test_storage.py @@ -28,7 +28,7 @@ from webtest.app import AppError import tokenlib -from tools.integration_tests.conftest import ( +from tools.integration_tests.helpers import ( switch_user, retry_post_json, retry_put_json, diff --git a/tools/integration_tests/tokenserver/conftest.py b/tools/integration_tests/tokenserver/conftest.py index cc23dd94..fae747c6 100644 --- a/tools/integration_tests/tokenserver/conftest.py +++ b/tools/integration_tests/tokenserver/conftest.py @@ -1,359 +1,44 @@ # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this file, # You can obtain one at http://mozilla.org/MPL/2.0/. -"""Fixtures and helpers for tokenserver integration tests. - -All helper functions are module-level so that test functions can -import and call them directly without going through a class instance. +"""Pytest fixtures for tokenserver integration tests. Fixture hierarchy ───────────────── -ts_db_conn — function-scoped SQLAlchemy connection -ts_app — function-scoped WebTest TestApp +ts_db_conn — function-scoped SQLAlchemy connection +ts_app — function-scoped WebTest TestApp ts_service_id — function-scoped service ID (sync-1.5) -ts_ctx — function-scoped composite: clears DB, seeds service + node, - yields a plain dict consumed by test functions -fxa_auth — session-scoped FxA OAuth token for test_e2e.py - (session scope justified: FxA account creation is a slow - network call to an external staging service; one account - suffices for the whole test session) +ts_ctx — function-scoped composite: clears DB, seeds service + node, + yields a plain dict consumed by test functions +fxa_auth — session-scoped FxA OAuth token for test_e2e.py + (session scope justified: FxA account creation is a slow + network call to an external staging service; one account + suffices for the whole test session) + +Helper functions and constants live in helpers.py. """ -import binascii -import json -import math import os import random import string import time -import urllib.parse as urlparse -from base64 import urlsafe_b64encode as b64encode import pytest from sqlalchemy import create_engine -from sqlalchemy.sql import text as sqltext -from tokenlib.utils import decode_token_bytes -from webtest import TestApp -DEFAULT_OAUTH_SCOPE = "https://identity.mozilla.com/apps/oldsync" - -NODE_ID = 800 -NODE_URL = "https://example.com" -FXA_EMAIL_DOMAIN = "api-accounts.stage.mozaws.net" -TOKEN_SIGNING_SECRET = os.environ.get("SYNC_MASTER_SECRET", "secret0") -FXA_METRICS_HASH_SECRET = os.environ.get("SYNC_MASTER_SECRET", "secret0") +from integration_tests.tokenserver.helpers import ( + NODE_ID, + NODE_URL, + add_node, + clear_db, + get_db_mode, + get_expected_node_type, + get_or_add_service, + make_app, +) -# ── DB-mode helper ─────────────────────────────────────────────────────────── - - -def get_db_mode() -> str: - """Derive db_mode from the SYNC_TOKENSERVER__DATABASE_URL env var.""" - return os.environ["SYNC_TOKENSERVER__DATABASE_URL"].split(":")[0] - - -def get_expected_node_type() -> str: - """Derive expected node_type from the SYNC_SYNCSTORAGE__DATABASE_URL env var.""" - syncstorage_url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL", "spanner://") - node_type = syncstorage_url.split(":")[0] - if node_type == "postgresql": - return "postgres" - if node_type.startswith("mysql"): - return "mysql" - return node_type - - -# ── SQL helpers ────────────────────────────────────────────────────────────── - - -def execute_sql(conn, query, params=None): - """Execute a SQL statement and return the cursor.""" - return conn.execute(query, params or {}) - - -def clear_db(conn) -> None: - """Delete all users and nodes. - - Services are intentionally not cleared: tokenserver may have cached - its service_id and a DELETE would invalidate that cache mid-run. - """ - execute_sql(conn, sqltext("DELETE FROM users"), {}).close() - execute_sql(conn, sqltext("DELETE FROM nodes"), {}).close() - - -def get_service_id(conn, service: str): - """Return the ID for the given service name, or None if not found.""" - cursor = execute_sql( - conn, - sqltext("select id from services where service = :service"), - {"service": service}, - ) - row = cursor.fetchone() - cursor.close() - return None if row is None else row[0] - - -def add_service(conn, service: str, pattern: str) -> int: - """Insert a services row and return its ID.""" - db_mode = get_db_mode() - if db_mode == "postgres": - sql = sqltext( - "insert into services (service, pattern) values (:service, :pattern) RETURNING id" - ) - cursor = execute_sql(conn, sql, {"service": service, "pattern": pattern}) - result: int = cursor.fetchone()[0] - else: - sql = sqltext( - "insert into services (service, pattern) values (:service, :pattern)" - ) - cursor = execute_sql(conn, sql, {"service": service, "pattern": pattern}) - result = cursor.lastrowid - cursor.close() - return result - - -def get_or_add_service(conn, service: str, pattern: str) -> int: - """Return existing service ID, inserting a new row if it does not exist.""" - service_id = get_service_id(conn, service) - if service_id is not None: - return int(service_id) - return add_service(conn, service, pattern) - - -def add_node( - conn, - service_id: int, - capacity: int = 100, - available: int = 100, - node: str = NODE_URL, - id: int | None = None, - current_load: int = 0, - backoff: int = 0, - downed: int = 0, -) -> int: - """Insert a nodes row and return its ID.""" - db_mode = get_db_mode() - params = { - "service": service_id, - "node": node, - "available": available, - "capacity": capacity, - "current_load": current_load, - "backoff": backoff, - "downed": downed, - } - if id is not None: - params["id"] = id - cols = "service, node, available, capacity, current_load, backoff, downed, id" - vals = ":service, :node, :available, :capacity, :current_load, :backoff, :downed, :id" - else: - cols = "service, node, available, capacity, current_load, backoff, downed" - vals = ( - ":service, :node, :available, :capacity, :current_load, :backoff, :downed" - ) - - result: int - if db_mode == "postgres": - sql = sqltext(f"insert into nodes ({cols}) values ({vals}) RETURNING id") # nosec B608 - cols/vals are hardcoded literals, not user input - cursor = execute_sql(conn, sql, params) - result = cursor.fetchone()[0] - else: - sql = sqltext(f"insert into nodes ({cols}) values ({vals})") # nosec B608 - cursor = execute_sql(conn, sql, params) - result = cursor.lastrowid - cursor.close() - return result - - -def get_node(conn, node_id: int) -> dict: - """Return a node dict by ID.""" - cursor = execute_sql( - conn, sqltext("select * from nodes where id = :id"), {"id": node_id} - ) - (id_, service, node, available, current_load, capacity, downed, backoff) = ( - cursor.fetchone() - ) - cursor.close() - return { - "id": id_, - "service": service, - "node": node, - "available": available, - "current_load": current_load, - "capacity": capacity, - "downed": downed, - "backoff": backoff, - } - - -def add_user( - conn, - service_id: int, - email: str | None = None, - nodeid: int = NODE_ID, - generation: int = 1234, - keys_changed_at: int | None = 1234, - client_state: str = "aaaa", - created_at: int | None = None, - replaced_at: int | None = None, -) -> int: - """Insert a users row and return its uid.""" - db_mode = get_db_mode() - created_at = created_at or math.trunc(time.time() * 1000) - params = { - "service": service_id, - "email": email or f"test@{FXA_EMAIL_DOMAIN}", - "nodeid": nodeid, - "generation": generation, - "keys_changed_at": keys_changed_at, - "client_state": client_state, - "created_at": created_at, - "replaced_at": replaced_at, - } - result: int - if db_mode == "postgres": - sql = sqltext("""\ - insert into users - (service, email, nodeid, generation, keys_changed_at, - client_state, created_at, replaced_at) - values - (:service, :email, :nodeid, :generation, :keys_changed_at, - :client_state, :created_at, :replaced_at) - RETURNING uid - """) - cursor = execute_sql(conn, sql, params) - result = cursor.fetchone()[0] - else: - sql = sqltext("""\ - insert into users - (service, email, nodeid, generation, keys_changed_at, - client_state, created_at, replaced_at) - values - (:service, :email, :nodeid, :generation, :keys_changed_at, - :client_state, :created_at, :replaced_at) - """) - cursor = execute_sql(conn, sql, params) - result = cursor.lastrowid - cursor.close() - return result - - -def get_user(conn, uid: int) -> dict: - """Return a user dict by uid.""" - cursor = execute_sql( - conn, sqltext("select * from users where uid = :uid"), {"uid": uid} - ) - ( - uid, - service, - email, - generation, - client_state, - created_at, - replaced_at, - nodeid, - keys_changed_at, - ) = cursor.fetchone() - cursor.close() - return { - "uid": uid, - "service": service, - "email": email, - "generation": generation, - "client_state": client_state, - "created_at": created_at, - "replaced_at": replaced_at, - "nodeid": nodeid, - "keys_changed_at": keys_changed_at, - } - - -def get_replaced_users(conn, service_id: int, email: str) -> list: - """Return a list of user dicts for records with a non-null replaced_at.""" - cursor = execute_sql( - conn, - sqltext("""\ - select * from users - where service = :service - and email = :email - and replaced_at is not null - """), - {"service": service_id, "email": email}, - ) - users = [] - for row in cursor.fetchall(): - ( - uid, - service, - email, - generation, - client_state, - created_at, - replaced_at, - nodeid, - keys_changed_at, - ) = row - users.append( - { - "uid": uid, - "service": service, - "email": email, - "generation": generation, - "client_state": client_state, - "created_at": created_at, - "replaced_at": replaced_at, - "nodeid": nodeid, - "keys_changed_at": keys_changed_at, - } - ) - cursor.close() - return users - - -def count_users(conn) -> int: - """Return the count of distinct user UIDs.""" - cursor = execute_sql(conn, sqltext("select COUNT(DISTINCT(uid)) from users"), {}) - (count,) = cursor.fetchone() - cursor.close() - return int(count) - - -# ── Auth helpers ───────────────────────────────────────────────────────────── - - -def build_oauth_headers( - generation: int | None = None, - user: str = "test", - keys_changed_at: int | None = None, - client_state: str | None = None, - status: int = 200, - **additional_headers: str, -) -> dict: - """Build OAuth Bearer + X-KeyID headers for a test request.""" - claims = { - "user": user, - "generation": generation, - "client_id": "fake client id", - "scope": [DEFAULT_OAUTH_SCOPE], - } - if generation is not None: - claims["generation"] = generation - body = {"body": claims, "status": status} - headers = {} - headers["Authorization"] = f"Bearer {json.dumps(body)}" - client_state_bytes = binascii.unhexlify(client_state or "") - client_state_b64 = b64encode(client_state_bytes).strip(b"=").decode("utf-8") - headers["X-KeyID"] = f"{keys_changed_at}-{client_state_b64}" - headers.update(additional_headers) - return headers - - -def unsafe_parse_token(token: str) -> dict: - """Parse a tokenlib token without verifying its HMAC signature.""" - return json.loads(decode_token_bytes(token)[:-32].decode("utf8")) # type: ignore[no-any-return] - - -# ── Fixtures ───────────────────────────────────────────────────────────────── +# ── Fixtures ────────────────────────────────────────────────────────────────── @pytest.fixture(scope="function") @@ -372,18 +57,7 @@ def ts_db_conn(): @pytest.fixture(scope="function") def ts_app(): """Function-scoped WebTest TestApp pointing at the tokenserver host.""" - host = os.environ["TOKENSERVER_HOST"] - host_url = urlparse.urlparse(host) - return TestApp( - host, - extra_environ={ - "HTTP_HOST": host_url.netloc, - "wsgi.url_scheme": host_url.scheme or "http", - "SERVER_NAME": host_url.hostname, - "REMOTE_ADDR": "127.0.0.1", - "SCRIPT_NAME": host_url.path, - }, - ) + return make_app(os.environ["TOKENSERVER_HOST"]) @pytest.fixture(scope="function") @@ -397,7 +71,7 @@ def ts_ctx(ts_db_conn, ts_app, ts_service_id): """Full per-test tokenserver context. Clears the database, seeds the default service and node, then yields - a dict that test functions can destructure (pytest): + a dict that test functions can destructure: def test_foo(ts_ctx): app = ts_ctx["app"] diff --git a/tools/integration_tests/tokenserver/helpers.py b/tools/integration_tests/tokenserver/helpers.py new file mode 100644 index 00000000..71d6e5d1 --- /dev/null +++ b/tools/integration_tests/tokenserver/helpers.py @@ -0,0 +1,353 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this file, +# You can obtain one at http://mozilla.org/MPL/2.0/. +"""Helper functions and constants for tokenserver integration tests. + +These are plain module-level utilities — no pytest fixtures here. +Test files import directly from this module; conftest.py imports +whatever it needs to seed fixtures. +""" + +import binascii +import json +import math +import os +import time +import urllib.parse as urlparse +from base64 import urlsafe_b64encode as b64encode + +from sqlalchemy.sql import text as sqltext +from tokenlib.utils import decode_token_bytes +from webtest import TestApp + +DEFAULT_OAUTH_SCOPE = "https://identity.mozilla.com/apps/oldsync" + +NODE_ID = 800 +NODE_URL = "https://example.com" +FXA_EMAIL_DOMAIN = "api-accounts.stage.mozaws.net" +TOKEN_SIGNING_SECRET = os.environ.get("SYNC_MASTER_SECRET", "secret0") +FXA_METRICS_HASH_SECRET = os.environ.get("SYNC_MASTER_SECRET", "secret0") + + +# ── DB-mode helpers ─────────────────────────────────────────────────────────── + + +def get_db_mode() -> str: + """Derive db_mode from the SYNC_TOKENSERVER__DATABASE_URL env var.""" + return os.environ["SYNC_TOKENSERVER__DATABASE_URL"].split(":")[0] + + +def get_expected_node_type() -> str: + """Derive expected node_type from the SYNC_SYNCSTORAGE__DATABASE_URL env var.""" + syncstorage_url = os.environ.get("SYNC_SYNCSTORAGE__DATABASE_URL", "spanner://") + node_type = syncstorage_url.split(":")[0] + if node_type == "postgresql": + return "postgres" + if node_type.startswith("mysql"): + return "mysql" + return node_type + + +# ── SQL helpers ─────────────────────────────────────────────────────────────── + + +def execute_sql(conn, query, params=None): + """Execute a SQL statement and return the cursor.""" + return conn.execute(query, params or {}) + + +def clear_db(conn) -> None: + """Delete all users and nodes. + + Services are intentionally not cleared: tokenserver may have cached + its service_id and a DELETE would invalidate that cache mid-run. + """ + execute_sql(conn, sqltext("DELETE FROM users"), {}).close() + execute_sql(conn, sqltext("DELETE FROM nodes"), {}).close() + + +def get_service_id(conn, service: str): + """Return the ID for the given service name, or None if not found.""" + cursor = execute_sql( + conn, + sqltext("select id from services where service = :service"), + {"service": service}, + ) + row = cursor.fetchone() + cursor.close() + return None if row is None else row[0] + + +def add_service(conn, service: str, pattern: str) -> int: + """Insert a services row and return its ID.""" + db_mode = get_db_mode() + if db_mode == "postgres": + sql = sqltext( + "insert into services (service, pattern) values (:service, :pattern) RETURNING id" + ) + cursor = execute_sql(conn, sql, {"service": service, "pattern": pattern}) + result: int = cursor.fetchone()[0] + else: + sql = sqltext( + "insert into services (service, pattern) values (:service, :pattern)" + ) + cursor = execute_sql(conn, sql, {"service": service, "pattern": pattern}) + result = cursor.lastrowid + cursor.close() + return result + + +def get_or_add_service(conn, service: str, pattern: str) -> int: + """Return existing service ID, inserting a new row if it does not exist.""" + service_id = get_service_id(conn, service) + if service_id is not None: + return int(service_id) + return add_service(conn, service, pattern) + + +def add_node( + conn, + service_id: int, + capacity: int = 100, + available: int = 100, + node: str = NODE_URL, + id: int | None = None, + current_load: int = 0, + backoff: int = 0, + downed: int = 0, +) -> int: + """Insert a nodes row and return its ID.""" + db_mode = get_db_mode() + params = { + "service": service_id, + "node": node, + "available": available, + "capacity": capacity, + "current_load": current_load, + "backoff": backoff, + "downed": downed, + } + if id is not None: + params["id"] = id + cols = "service, node, available, capacity, current_load, backoff, downed, id" + vals = ":service, :node, :available, :capacity, :current_load, :backoff, :downed, :id" + else: + cols = "service, node, available, capacity, current_load, backoff, downed" + vals = ( + ":service, :node, :available, :capacity, :current_load, :backoff, :downed" + ) + + result: int + if db_mode == "postgres": + sql = sqltext(f"insert into nodes ({cols}) values ({vals}) RETURNING id") # nosec B608 - cols/vals are hardcoded literals, not user input + cursor = execute_sql(conn, sql, params) + result = cursor.fetchone()[0] + else: + sql = sqltext(f"insert into nodes ({cols}) values ({vals})") # nosec B608 + cursor = execute_sql(conn, sql, params) + result = cursor.lastrowid + cursor.close() + return result + + +def get_node(conn, node_id: int) -> dict: + """Return a node dict by ID.""" + cursor = execute_sql( + conn, sqltext("select * from nodes where id = :id"), {"id": node_id} + ) + (id_, service, node, available, current_load, capacity, downed, backoff) = ( + cursor.fetchone() + ) + cursor.close() + return { + "id": id_, + "service": service, + "node": node, + "available": available, + "current_load": current_load, + "capacity": capacity, + "downed": downed, + "backoff": backoff, + } + + +def add_user( + conn, + service_id: int, + email: str | None = None, + nodeid: int = NODE_ID, + generation: int = 1234, + keys_changed_at: int | None = 1234, + client_state: str = "aaaa", + created_at: int | None = None, + replaced_at: int | None = None, +) -> int: + """Insert a users row and return its uid.""" + db_mode = get_db_mode() + created_at = created_at or math.trunc(time.time() * 1000) + params = { + "service": service_id, + "email": email or f"test@{FXA_EMAIL_DOMAIN}", + "nodeid": nodeid, + "generation": generation, + "keys_changed_at": keys_changed_at, + "client_state": client_state, + "created_at": created_at, + "replaced_at": replaced_at, + } + result: int + if db_mode == "postgres": + sql = sqltext("""\ + insert into users + (service, email, nodeid, generation, keys_changed_at, + client_state, created_at, replaced_at) + values + (:service, :email, :nodeid, :generation, :keys_changed_at, + :client_state, :created_at, :replaced_at) + RETURNING uid + """) + cursor = execute_sql(conn, sql, params) + result = cursor.fetchone()[0] + else: + sql = sqltext("""\ + insert into users + (service, email, nodeid, generation, keys_changed_at, + client_state, created_at, replaced_at) + values + (:service, :email, :nodeid, :generation, :keys_changed_at, + :client_state, :created_at, :replaced_at) + """) + cursor = execute_sql(conn, sql, params) + result = cursor.lastrowid + cursor.close() + return result + + +def get_user(conn, uid: int) -> dict: + """Return a user dict by uid.""" + cursor = execute_sql( + conn, sqltext("select * from users where uid = :uid"), {"uid": uid} + ) + ( + uid, + service, + email, + generation, + client_state, + created_at, + replaced_at, + nodeid, + keys_changed_at, + ) = cursor.fetchone() + cursor.close() + return { + "uid": uid, + "service": service, + "email": email, + "generation": generation, + "client_state": client_state, + "created_at": created_at, + "replaced_at": replaced_at, + "nodeid": nodeid, + "keys_changed_at": keys_changed_at, + } + + +def get_replaced_users(conn, service_id: int, email: str) -> list: + """Return a list of user dicts for records with a non-null replaced_at.""" + cursor = execute_sql( + conn, + sqltext("""\ + select * from users + where service = :service + and email = :email + and replaced_at is not null + """), + {"service": service_id, "email": email}, + ) + users = [] + for row in cursor.fetchall(): + ( + uid, + service, + email, + generation, + client_state, + created_at, + replaced_at, + nodeid, + keys_changed_at, + ) = row + users.append( + { + "uid": uid, + "service": service, + "email": email, + "generation": generation, + "client_state": client_state, + "created_at": created_at, + "replaced_at": replaced_at, + "nodeid": nodeid, + "keys_changed_at": keys_changed_at, + } + ) + cursor.close() + return users + + +def count_users(conn) -> int: + """Return the count of distinct user UIDs.""" + cursor = execute_sql(conn, sqltext("select COUNT(DISTINCT(uid)) from users"), {}) + (count,) = cursor.fetchone() + cursor.close() + return int(count) + + +# ── Auth helpers ────────────────────────────────────────────────────────────── + + +def build_oauth_headers( + generation: int | None = None, + user: str = "test", + keys_changed_at: int | None = None, + client_state: str | None = None, + status: int = 200, + **additional_headers: str, +) -> dict: + """Build OAuth Bearer + X-KeyID headers for a test request.""" + claims = { + "user": user, + "generation": generation, + "client_id": "fake client id", + "scope": [DEFAULT_OAUTH_SCOPE], + } + if generation is not None: + claims["generation"] = generation + body = {"body": claims, "status": status} + headers = {} + headers["Authorization"] = f"Bearer {json.dumps(body)}" + client_state_bytes = binascii.unhexlify(client_state or "") + client_state_b64 = b64encode(client_state_bytes).strip(b"=").decode("utf-8") + headers["X-KeyID"] = f"{keys_changed_at}-{client_state_b64}" + headers.update(additional_headers) + return headers + + +def make_app(host: str) -> TestApp: + """Build a WebTest TestApp pointing at the given host URL.""" + host_url = urlparse.urlparse(host) + return TestApp( + host, + extra_environ={ + "HTTP_HOST": host_url.netloc, + "wsgi.url_scheme": host_url.scheme or "http", + "SERVER_NAME": host_url.hostname, + "REMOTE_ADDR": "127.0.0.1", + "SCRIPT_NAME": host_url.path, + }, + ) + + +def unsafe_parse_token(token: str) -> dict: + """Parse a tokenlib token without verifying its HMAC signature.""" + return json.loads(decode_token_bytes(token)[:-32].decode("utf8")) # type: ignore[no-any-return] diff --git a/tools/integration_tests/tokenserver/test_authorization.py b/tools/integration_tests/tokenserver/test_authorization.py index 33fc3a69..b0d99f59 100644 --- a/tools/integration_tests/tokenserver/test_authorization.py +++ b/tools/integration_tests/tokenserver/test_authorization.py @@ -3,7 +3,7 @@ # You can obtain one at http://mozilla.org/MPL/2.0/. """Authorization integration tests for the tokenserver.""" -from integration_tests.tokenserver.conftest import ( +from integration_tests.tokenserver.helpers import ( add_user, build_oauth_headers, get_user, diff --git a/tools/integration_tests/tokenserver/test_e2e.py b/tools/integration_tests/tokenserver/test_e2e.py index 64361bc2..705b3184 100644 --- a/tools/integration_tests/tokenserver/test_e2e.py +++ b/tools/integration_tests/tokenserver/test_e2e.py @@ -15,7 +15,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from integration_tests.tokenserver.conftest import ( +from integration_tests.tokenserver.helpers import ( FXA_METRICS_HASH_SECRET, TOKEN_SIGNING_SECRET, unsafe_parse_token, diff --git a/tools/integration_tests/tokenserver/test_misc.py b/tools/integration_tests/tokenserver/test_misc.py index 0a3e9c9f..7b4c49c6 100644 --- a/tools/integration_tests/tokenserver/test_misc.py +++ b/tools/integration_tests/tokenserver/test_misc.py @@ -3,7 +3,7 @@ # You can obtain one at http://mozilla.org/MPL/2.0/. """Miscellaneous integration tests for the tokenserver.""" -from integration_tests.tokenserver.conftest import ( +from integration_tests.tokenserver.helpers import ( FXA_EMAIL_DOMAIN, NODE_ID, add_user, diff --git a/tools/integration_tests/tokenserver/test_node_assignment.py b/tools/integration_tests/tokenserver/test_node_assignment.py index f4fa7fb9..55de8d75 100644 --- a/tools/integration_tests/tokenserver/test_node_assignment.py +++ b/tools/integration_tests/tokenserver/test_node_assignment.py @@ -3,7 +3,7 @@ # You can obtain one at http://mozilla.org/MPL/2.0/. """Node assignment integration tests for the tokenserver.""" -from integration_tests.tokenserver.conftest import ( +from integration_tests.tokenserver.helpers import ( NODE_ID, add_node, build_oauth_headers,