refactor(assets): database layer — split queries into modules and merge migrations

- Split monolithic queries.py into modular query modules (asset, asset_reference, common, tags)
- Absorb bulk_ops.py and tags.py into query modules
- Merge migrations 0002-0005 into single migration (0002_merge_to_asset_references)
- Update models.py (merge AssetInfo/AssetCacheState into AssetReference)
- Enable SQLite foreign key enforcement
- Add comprehensive query-layer tests

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019c917d-82b5-7448-a04f-9cd59c69d0a2
This commit is contained in:
Luke Mino-Altherr 2026-02-24 11:58:54 -08:00
parent 3ebe1ac22e
commit 3965aca3e6
19 changed files with 3792 additions and 1336 deletions

View File

@ -0,0 +1,264 @@
"""
Merge AssetInfo and AssetCacheState into unified asset_references table.
This migration drops old tables and creates the new unified schema.
All existing data is discarded.
Revision ID: 0002_merge_to_asset_references
Revises: 0001_assets
Create Date: 2025-02-11
"""
from alembic import op
import sqlalchemy as sa
revision = "0002_merge_to_asset_references"
down_revision = "0001_assets"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Replaces the 0001 schema (assets_info, asset_cache_state, asset_info_tags,
# asset_info_meta) with asset_references, asset_reference_tags and
# asset_reference_meta. No rows are copied across: the old tables are dropped
# and every row in assets is deleted.
#
# Drop old tables (order matters due to FK constraints): asset_info_meta and
# asset_info_tags reference assets_info, so they are removed before it.
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
# Truncate assets table (cascades handled by dropping dependent tables first).
# Note this is a DELETE of every row, not a TRUNCATE statement; the assets
# table itself and its definition are kept.
op.execute("DELETE FROM assets")
# Create asset_references table: one table carrying both the columns that
# used to live on asset_cache_state (file_path, mtime_ns, needs_verify) and
# those that used to live on assets_info (owner_id, name, preview_id, ...).
op.create_table(
"asset_references",
sa.Column("id", sa.String(length=36), primary_key=True),
# Owning asset; reference rows are deleted with it (ON DELETE CASCADE).
sa.Column(
"asset_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="CASCADE"),
nullable=False,
),
# file_path is nullable here, whereas it was NOT NULL on asset_cache_state.
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column(
"needs_verify",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
# is_missing and enrichment_level are new columns with no counterpart in the
# tables dropped above.
sa.Column(
"is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
# Optional preview asset; cleared (SET NULL) if that asset is deleted.
sa.Column(
"preview_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
# mtime_ns may be absent, but never negative.
sa.CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
# enrichment_level is restricted to the values 0, 1 and 2.
sa.CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
)
# Uniqueness of file_path is enforced through a unique index rather than a
# table-level UNIQUE constraint.
op.create_index(
"uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
)
op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
op.create_index("ix_asset_references_name", "asset_references", ["name"])
op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
op.create_index(
"ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
)
op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
op.create_index(
"ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
)
op.create_index(
"ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
)
# Create asset_reference_tags table: link rows between asset_references and
# tags, keyed by (asset_reference_id, tag_name).
op.create_table(
"asset_reference_tags",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
# RESTRICT: a tag cannot be deleted while a link row still points at it.
sa.Column(
"tag_name",
sa.String(length=512),
sa.ForeignKey("tags.name", ondelete="RESTRICT"),
nullable=False,
),
sa.Column(
"origin", sa.String(length=32), nullable=False, server_default="manual"
),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint(
"asset_reference_id", "tag_name", name="pk_asset_reference_tags"
),
)
op.create_index(
"ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
)
op.create_index(
"ix_asset_reference_tags_asset_reference_id",
"asset_reference_tags",
["asset_reference_id"],
)
# Create asset_reference_meta table: typed key/value rows per reference,
# keyed by (asset_reference_id, key, ordinal), with one nullable column per
# value type (val_str, val_num, val_bool, val_json).
op.create_table(
"asset_reference_meta",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint(
"asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
),
)
# Indexes on key alone and on key paired with each scalar value column.
op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
op.create_index(
"ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
)
op.create_index(
"ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
)
op.create_index(
"ix_asset_reference_meta_key_val_bool",
"asset_reference_meta",
["key", "val_bool"],
)
def downgrade() -> None:
"""Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.
NOTE: Data is not recoverable. The upgrade discards all rows from the old
tables and truncates assets. After downgrade the old schema will be empty.
A filesystem rescan will repopulate data once the older code is running.
"""
# Drop new tables (order matters due to FK constraints): the meta and tag
# tables reference asset_references, so they are removed before it.
op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
op.drop_table("asset_reference_meta")
op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
op.drop_table("asset_reference_tags")
op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
op.drop_index("ix_asset_references_created_at", table_name="asset_references")
op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
op.drop_index("ix_asset_references_name", table_name="asset_references")
op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
op.drop_index("uq_asset_references_file_path", table_name="asset_references")
op.drop_table("asset_references")
# Truncate assets (upgrade deleted all rows; downgrade starts fresh too).
# This is a DELETE of every row; the assets table itself is kept.
op.execute("DELETE FROM assets")
# Recreate old tables from 0001_assets schema.
# assets_info is created first because asset_info_tags and asset_info_meta
# below hold foreign keys to it.
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
# RESTRICT here (asset_references.asset_id uses CASCADE instead).
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# asset_cache_state: integer autoincrement key and a mandatory, unique
# file_path (nullable in asset_references).
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
# A plain index on file_path is created in addition to the UNIQUE constraint
# above; upgrade() drops it by this name.
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# asset_info_tags: link rows between assets_info and tags.
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# asset_info_meta: typed key/value rows per assets_info row.
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

View File

@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
    """Return how many rows of ``cols`` bound values fit within MAX_BIND_PARAMS.

    A column count below 1 is treated as 1, and the result is never below 1.
    """
    width = cols if cols > 1 else 1
    budget = MAX_BIND_PARAMS // width
    return budget if budget > 1 else 1
def seed_from_paths_batch(
    session: Session,
    *,
    specs: list[dict],
    owner_id: str = "",
) -> dict:
    """Insert seed Asset / AssetCacheState / AssetInfo rows for a batch of files.

    Each spec is a dict with keys:
    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: Optional[str]

    For every spec a new Asset (hash=NULL) is inserted, then an
    AssetCacheState row for its path is attempted with ON CONFLICT DO NOTHING.
    Specs whose path already had a cache-state row ("losers") get their new
    Asset deleted again. The remaining "winners" receive an AssetInfo row and,
    for each AssetInfo that was actually inserted, tag links (origin
    "automatic") and a ``filename`` metadata row when ``fname`` is truthy.

    No commit is issued here.

    Args:
        session: Session used for all statements.
        specs: File descriptions as listed above.
        owner_id: Owner written to every AssetInfo row.

    Returns:
        ``{"inserted_infos": int, "won_states": int, "lost_states": int}``.
    """
    if not specs:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}

    now = utcnow()
    asset_rows: list[dict] = []
    state_rows: list[dict] = []
    path_to_asset: dict[str, str] = {}
    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
    path_list: list[str] = []

    for sp in specs:
        ap = os.path.abspath(sp["abs_path"])
        aid = str(uuid.uuid4())
        iid = str(uuid.uuid4())
        path_list.append(ap)
        path_to_asset[ap] = aid

        asset_rows.append(
            {
                "id": aid,
                "hash": None,
                "size_bytes": sp["size_bytes"],
                "mime_type": None,
                "created_at": now,
            }
        )
        state_rows.append(
            {
                "asset_id": aid,
                "file_path": ap,
                "mtime_ns": sp["mtime_ns"],
            }
        )
        asset_to_info[aid] = {
            "id": iid,
            "owner_id": owner_id,
            "name": sp["info_name"],
            "asset_id": aid,
            "preview_id": None,
            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
            "created_at": now,
            "updated_at": now,
            "last_access_time": now,
            # Underscore-prefixed keys are carried along for the tag/meta
            # step further down.
            "_tags": sp["tags"],
            "_filename": sp["fname"],
        }

    # insert all seed Assets (hash=NULL)
    ins_asset = sqlite.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        session.execute(ins_asset, chunk)

    # try to claim AssetCacheState (file_path)
    # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
    ins_state = (
        sqlite.insert(AssetCacheState)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        session.execute(ins_state, chunk)

    # Query to find which of our paths won (were actually inserted).
    # Each statement binds two values per path (the path and its asset id),
    # so the chunk size is derived from two "columns" per row to stay within
    # MAX_BIND_PARAMS.
    winners_by_path: set[str] = set()
    for chunk in _iter_chunks(path_list, _rows_per_stmt(2)):
        result = session.execute(
            sqlalchemy.select(AssetCacheState.file_path)
            .where(AssetCacheState.file_path.in_(chunk))
            .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
        )
        winners_by_path.update(result.scalars().all())

    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))

    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}

    # insert AssetInfo only for winners
    # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    ins_info = (
        sqlite.insert(AssetInfo)
        .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
    )
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        session.execute(ins_info, chunk)

    # Query to find which info rows were actually inserted (by matching our generated IDs)
    all_info_ids = [row["id"] for row in winner_info_rows]
    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
        result = session.execute(
            sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
        )
        inserted_info_ids.update(result.scalars().all())

    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }
def bulk_insert_tags_and_meta(
    session: Session,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Write tag-link and metadata rows in batches, ignoring rows that already exist.

    Both inserts use ON CONFLICT DO NOTHING on the table's key columns and are
    split into statements sized by ``max_bind_params``.

    Args:
        session: Session the statements are executed on.
        tag_rows: Dicts with keys asset_info_id, tag_name, origin, added_at.
        meta_rows: Dicts with keys asset_info_id, key, ordinal, val_str,
            val_num, val_bool, val_json.
        max_bind_params: Bound-value budget per statement.
    """
    if tag_rows:
        tag_key = [AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
        tag_stmt = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(index_elements=tag_key)
        # Four bound values per tag row.
        for batch in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            session.execute(tag_stmt, batch)

    if meta_rows:
        meta_key = [AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
        meta_stmt = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(index_elements=meta_key)
        # Seven bound values per metadata row.
        for batch in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            session.execute(meta_stmt, batch)

View File

@ -2,8 +2,8 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
@ -16,48 +16,43 @@ from sqlalchemy import (
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
from app.assets.helpers import get_utc_now
from app.database.models import Base, to_dict
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
references: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
foreign_keys=lambda: [AssetReference.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
preview_of: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
@ -71,47 +66,52 @@ class Asset(Base):
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
class AssetReference(Base):
"""Unified model combining file cache state and user-facing metadata.
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
Each row represents either:
- A filesystem reference (file_path is set) with cache state
- An API-created reference (file_path is NULL) without cache state
"""
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__tablename__ = "asset_references"
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
# Cache state fields (from former AssetCacheState)
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
# Info fields (from former AssetInfo)
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
back_populates="references",
foreign_keys=[asset_id],
lazy="selectin",
)
@ -121,35 +121,44 @@ class AssetInfo(Base):
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
back_populates="asset_reference",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
tag_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="asset_reference",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
overlaps="tags,asset_references",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
secondary="asset_reference_tags",
back_populates="asset_references",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
overlaps="tag_links,asset_reference_links,asset_references,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
Index("uq_asset_references_file_path", "file_path", unique=True),
Index("ix_asset_references_asset_id", "asset_id"),
Index("ix_asset_references_owner_id", "owner_id"),
Index("ix_asset_references_name", "name"),
Index("ix_asset_references_is_missing", "is_missing"),
Index("ix_asset_references_enrichment_level", "enrichment_level"),
Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
@ -158,14 +167,17 @@ class AssetInfo(Base):
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
path_part = f" path={self.file_path!r}" if self.file_path else ""
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
class AssetReferenceMeta(Base):
__tablename__ = "asset_reference_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
@ -175,36 +187,40 @@ class AssetInfoMeta(Base):
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
asset_reference: Mapped[AssetReference] = relationship(
back_populates="metadata_entries"
)
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
Index("ix_asset_reference_meta_key", "key"),
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
class AssetReferenceTag(Base):
__tablename__ = "asset_reference_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
Index("ix_asset_reference_tags_tag_name", "tag_name"),
Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
)
@ -214,20 +230,18 @@ class Tag(Base):
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
overlaps="asset_references,tags",
)
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
asset_references: Mapped[list[AssetReference]] = relationship(
secondary="asset_reference_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
overlaps="asset_reference_links,tag_links,tags,asset_reference",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@ -1,976 +0,0 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Return the read-visibility predicate on ``AssetInfo.owner_id``.

    Rows with an empty owner are always included. A non-blank ``owner_id``
    (after stripping whitespace) additionally includes that owner's rows.
    """
    normalized = (owner_id or "").strip()
    if normalized:
        return AssetInfo.owner_id.in_(["", normalized])
    return AssetInfo.owner_id == ""
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
    """Choose the most trustworthy on-disk path among ``states``.

    Only states whose ``file_path`` is set and is an existing file are
    considered. Among those, the first one not flagged ``needs_verify`` wins;
    failing that, the first existing one is used. Returns "" when no state
    points at an existing file.
    """
    existing = []
    for state in states:
        if getattr(state, "file_path", None) and os.path.isfile(state.file_path):
            existing.append(state)
    if not existing:
        return ""

    # A state lacking the needs_verify attribute counts as already verified.
    verified = next(
        (state for state in existing if not getattr(state, "needs_verify", False)),
        None,
    )
    chosen = existing[0] if verified is None else verified
    return chosen.file_path
def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Restrict ``stmt`` by tag membership of each AssetInfo row.

    Every tag in ``include_tags`` must be linked to the row, and no tag in
    ``exclude_tags`` may be. Both lists go through ``normalize_tags`` first.
    """
    required = normalize_tags(include_tags)
    forbidden = normalize_tags(exclude_tags)

    # One EXISTS per required tag, so all of them have to be present.
    for tag_name in required or ():
        has_tag = exists().where(
            sa.and_(
                AssetInfoTag.asset_info_id == AssetInfo.id,
                AssetInfoTag.tag_name == tag_name,
            )
        )
        stmt = stmt.where(has_tag)

    if not forbidden:
        return stmt

    # A single NOT EXISTS covers the whole exclusion list.
    has_forbidden = exists().where(
        sa.and_(
            AssetInfoTag.asset_info_id == AssetInfo.id,
            AssetInfoTag.tag_name.in_(forbidden),
        )
    )
    return stmt.where(~has_forbidden)
def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Apply filters using asset_info_meta projection table.

    Each ``key: value`` pair adds one condition to ``stmt``:

    - ``None`` matches when the key has no row at all, or has a row whose
      val_json, val_str, val_num and val_bool are all NULL.
    - ``bool`` is compared against ``val_bool``.
    - ``int``, ``float`` and ``Decimal`` are compared against ``val_num`` as
      a ``Decimal``.
    - ``str`` is compared against ``val_str``.
    - Anything else is compared against ``val_json``.
    - A list matches when any of its elements matches (OR); an empty list
      adds no condition.

    Args:
        stmt: Select over AssetInfo to narrow down.
        metadata_filter: Mapping of metadata key to wanted value(s). A falsy
            mapping returns ``stmt`` unchanged.

    Returns:
        The narrowed select.
    """
    if not metadata_filter:
        return stmt

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        # EXISTS over the meta rows of this AssetInfo for ``key`` plus extra predicates.
        return sa.exists().where(
            AssetInfoMeta.asset_info_id == AssetInfo.id,
            AssetInfoMeta.key == key,
            *preds,
        )

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            no_row_for_key = sa.not_(
                sa.exists().where(
                    AssetInfoMeta.asset_info_id == AssetInfo.id,
                    AssetInfoMeta.key == key,
                )
            )
            null_row = _exists_for_pred(
                key,
                AssetInfoMeta.val_json.is_(None),
                AssetInfoMeta.val_str.is_(None),
                AssetInfoMeta.val_num.is_(None),
                AssetInfoMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        # bool must be tested before the numeric branch: bool is a subclass of int.
        if isinstance(value, bool):
            return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))

        from decimal import Decimal

        # Decimal is part of the guard so a Decimal filter value reaches the
        # numeric comparison instead of falling through to val_json.
        if isinstance(value, (int, float, Decimal)):
            # Go through str() so floats keep their shortest repr rather than
            # their exact binary expansion.
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetInfoMeta.val_num == num)

        if isinstance(value, str):
            return _exists_for_pred(key, AssetInfoMeta.val_str == value)

        return _exists_for_pred(key, AssetInfoMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt
def asset_exists_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> bool:
    """Return True when at least one Asset row has ``hash == asset_hash``."""
    probe = (
        select(sa.literal(True))
        .select_from(Asset)
        .where(Asset.hash == asset_hash)
        .limit(1)
    )
    hit = session.execute(probe).first()
    return hit is not None
def asset_info_exists_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> bool:
    """Return True when at least one AssetInfo row points at ``asset_id``."""
    probe = select(sa.literal(True)).select_from(AssetInfo)
    probe = probe.where(AssetInfo.asset_id == asset_id).limit(1)
    hit = session.execute(probe).first()
    return hit is not None
def get_asset_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> Asset | None:
    """Return the first Asset whose hash equals ``asset_hash``, or None."""
    lookup = select(Asset).where(Asset.hash == asset_hash).limit(1)
    result = session.execute(lookup)
    return result.scalars().first()
def get_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
) -> AssetInfo | None:
    """Look up an AssetInfo by primary key via ``Session.get``; None if absent."""
    found = session.get(AssetInfo, asset_info_id)
    return found
def list_asset_infos_page(
    session: Session,
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
    """Return one page of visible AssetInfo rows plus their tags and the total.

    The same owner-visibility, name, tag and metadata filters are applied to
    both the page query and the count query. ``sort`` accepts ``name``,
    ``created_at``, ``updated_at``, ``last_access_time`` or ``size``; any other
    value falls back to ``created_at``. Any ``order`` other than ``"desc"``
    sorts ascending.

    Returns:
        ``(infos, tag_map, total)`` where ``tag_map`` maps AssetInfo id to its
        tag names (ordered by ``added_at``) and ``total`` is the unpaged count.
    """

    def _with_filters(query):
        # Filters shared by the page query and the count query.
        if name_contains:
            escaped, esc = escape_like_prefix(name_contains)
            query = query.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
        query = apply_tag_filters(query, include_tags, exclude_tags)
        return apply_metadata_filter(query, metadata_filter)

    page_stmt = _with_filters(
        select(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
        .where(visible_owner_clause(owner_id))
    )

    sort_key = (sort or "created_at").lower()
    direction = (order or "desc").lower()
    sortable_columns = {
        "name": AssetInfo.name,
        "created_at": AssetInfo.created_at,
        "updated_at": AssetInfo.updated_at,
        "last_access_time": AssetInfo.last_access_time,
        "size": Asset.size_bytes,
    }
    column = sortable_columns.get(sort_key, AssetInfo.created_at)
    ordering = column.desc() if direction == "desc" else column.asc()
    page_stmt = page_stmt.order_by(ordering).limit(limit).offset(offset)

    count_stmt = _with_filters(
        select(sa.func.count())
        .select_from(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(visible_owner_clause(owner_id))
    )

    total = int(session.execute(count_stmt).scalar_one() or 0)
    infos = session.execute(page_stmt).unique().scalars().all()

    # Tags are fetched separately because the page query uses noload on them.
    tag_map: dict[str, list[str]] = defaultdict(list)
    page_ids: list[str] = [info.id for info in infos]
    if page_ids:
        tag_rows = session.execute(
            select(AssetInfoTag.asset_info_id, Tag.name)
            .join(Tag, Tag.name == AssetInfoTag.tag_name)
            .where(AssetInfoTag.asset_info_id.in_(page_ids))
            .order_by(AssetInfoTag.added_at)
        )
        for info_id, tag_name in tag_rows.all():
            tag_map[info_id].append(tag_name)

    return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
    session: Session,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
    """Load one visible AssetInfo with its Asset and tag names.

    Tags are outer-joined, so an AssetInfo without tags still yields a row.

    Returns:
        ``(info, asset, tags)`` with ``tags`` de-duplicated and sorted by tag
        name ascending, or ``None`` when no row matches the id and the owner
        visibility clause.
    """
    query = (
        select(AssetInfo, Asset, Tag.name)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
        .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .options(noload(AssetInfo.tags))
        .order_by(Tag.name.asc())
    )
    rows = session.execute(query).all()
    if not rows:
        return None

    info, asset, _ = rows[0]
    # dict.fromkeys keeps first-seen order while dropping repeats; falsy tag
    # names (the NULL produced by the outer join) are skipped.
    tags: list[str] = list(
        dict.fromkeys(tag_name for _i, _a, tag_name in rows if tag_name)
    )
    return info, asset, tags
def fetch_asset_info_and_asset(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
    """Load one visible AssetInfo together with its Asset.

    Returns ``(info, asset)``, or ``None`` when no row matches the id and the
    owner visibility clause. Tag loading on the AssetInfo is disabled.
    """
    query = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetInfo.tags))
    )
    first_row = session.execute(query).first()
    if not first_row:
        return None
    info, asset = first_row[0], first_row[1]
    return info, asset
def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    """Return all AssetCacheState rows for ``asset_id``, ordered by id ascending."""
    query = (
        select(AssetCacheState)
        .where(AssetCacheState.asset_id == asset_id)
        .order_by(AssetCacheState.id.asc())
    )
    result = session.execute(query)
    return result.scalars().all()
def touch_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    """Set ``last_access_time`` on one AssetInfo row.

    Args:
        asset_info_id: Row to update.
        ts: Timestamp to store; ``utcnow()`` is used when falsy.
        only_if_newer: When True, the row is updated only if its current
            ``last_access_time`` is NULL or earlier than ``ts``.
    """
    stamp = ts or utcnow()
    conditions = [AssetInfo.id == asset_info_id]
    if only_if_newer:
        conditions.append(
            sa.or_(
                AssetInfo.last_access_time.is_(None),
                AssetInfo.last_access_time < stamp,
            )
        )
    update_stmt = sa.update(AssetInfo).where(*conditions)
    session.execute(update_stmt.values(last_access_time=stamp))
def create_asset_info_for_existing_asset(
    session: Session,
    *,
    asset_hash: str,
    name: str,
    user_metadata: dict | None = None,
    tags: Sequence[str] | None = None,
    tag_origin: str = "manual",
    owner_id: str = "",
) -> AssetInfo:
    """Create or return an existing AssetInfo for an Asset identified by asset_hash.

    The insert is attempted inside a SAVEPOINT. If it raises IntegrityError,
    the row matching (asset_id, name, owner_id) is looked up and returned
    as-is: in that case ``user_metadata`` and ``tags`` are NOT applied.

    Args:
        asset_hash: Hash of an Asset that must already exist.
        name: Name for the new AssetInfo.
        user_metadata: Metadata to store; a computed ``"filename"`` entry is
            added (overriding any supplied one) when a live path is found.
        tags: When not None, passed to ``set_asset_info_tags`` for the new row.
        tag_origin: ``origin`` forwarded to ``set_asset_info_tags``.
        owner_id: Owner of the new AssetInfo.

    Returns:
        The newly created AssetInfo, or the pre-existing conflicting one.

    Raises:
        ValueError: No Asset exists with ``asset_hash``.
        RuntimeError: The insert conflicted but no matching row was found.
    """
    now = utcnow()
    asset = get_asset_by_hash(session, asset_hash=asset_hash)
    if not asset:
        raise ValueError(f"Unknown asset hash {asset_hash}")

    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        preview_id=None,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    try:
        # SAVEPOINT: a failed insert is rolled back without discarding the
        # caller's outer transaction.
        with session.begin_nested():
            session.add(info)
            session.flush()
    except IntegrityError:
        existing = (
            session.execute(
                select(AssetInfo)
                .options(noload(AssetInfo.tags))
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == name,
                    AssetInfo.owner_id == owner_id,
                )
                .limit(1)
            )
        ).unique().scalars().first()
        if not existing:
            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
        return existing

    # metadata["filename"] hack: store the relative filename of the best live
    # cache path alongside the user-supplied metadata.
    new_meta = dict(user_metadata or {})
    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        # Best effort: any failure while computing the filename is ignored.
        computed_filename = None
    if computed_filename:
        new_meta["filename"] = computed_filename
    if new_meta:
        replace_asset_info_metadata_projection(
            session,
            asset_info_id=info.id,
            user_metadata=new_meta,
        )

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=info.id,
            tags=tags,
            origin=tag_origin,
        )
    return info
def set_asset_info_tags(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> dict:
    """Make the tag links of an AssetInfo equal to ``tags``.

    Missing links are inserted (creating the Tag rows first via
    ``ensure_tags_exist`` with ``tag_type="user"``); links not in the desired
    set are deleted.

    Returns:
        ``{"added": [...], "removed": [...], "total": <normalized tags>}``.
    """
    desired = normalize_tags(tags)
    linked_query = select(AssetInfoTag.tag_name).where(
        AssetInfoTag.asset_info_id == asset_info_id
    )
    current = {tag_name for (tag_name,) in session.execute(linked_query).all()}

    to_add = [tag for tag in desired if tag not in current]
    to_remove = [tag for tag in current if tag not in desired]

    if to_add:
        ensure_tags_exist(session, to_add, tag_type="user")
        new_links = [
            AssetInfoTag(
                asset_info_id=asset_info_id,
                tag_name=tag,
                origin=origin,
                added_at=utcnow(),
            )
            for tag in to_add
        ]
        session.add_all(new_links)
        session.flush()

    if to_remove:
        stale_links = delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id == asset_info_id,
            AssetInfoTag.tag_name.in_(to_remove),
        )
        session.execute(stale_links)
        session.flush()

    return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
    session: Session,
    *,
    asset_info_id: str,
    user_metadata: dict | None = None,
) -> None:
    """Replace an AssetInfo's ``user_metadata`` and rebuild its projection rows.

    The AssetInfo's ``user_metadata`` is overwritten (``{}`` when falsy) and
    ``updated_at`` is bumped. All existing AssetInfoMeta rows for the id are
    deleted and, when metadata is non-empty, re-created from ``project_kv``.

    Raises:
        ValueError: No AssetInfo exists with ``asset_info_id``.
    """
    target = session.get(AssetInfo, asset_info_id)
    if not target:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    target.user_metadata = user_metadata or {}
    target.updated_at = utcnow()
    session.flush()

    purge = delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
    session.execute(purge)
    session.flush()

    if not user_metadata:
        return

    projected: list[AssetInfoMeta] = [
        AssetInfoMeta(
            asset_info_id=asset_info_id,
            key=entry["key"],
            ordinal=int(entry["ordinal"]),
            val_str=entry.get("val_str"),
            val_num=entry.get("val_num"),
            val_bool=entry.get("val_bool"),
            val_json=entry.get("val_json"),
        )
        for meta_key, meta_value in user_metadata.items()
        for entry in project_kv(meta_key, meta_value)
    ]
    if projected:
        session.add_all(projected)
        session.flush()
def ingest_fs_asset(
    session: Session,
    *,
    asset_hash: str,
    abs_path: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: str | None = None,
    info_name: str | None = None,
    owner_id: str = "",
    preview_id: str | None = None,
    user_metadata: dict | None = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    require_existing_tags: bool = False,
) -> dict:
    """
    Idempotently upsert:
      - Asset by content hash (create if missing)
      - AssetCacheState(file_path) pointing to asset_id
      - Optionally AssetInfo + tag links and metadata projection
    Returns flags and ids.

    Args:
        asset_hash: Content hash identifying the Asset.
        abs_path: File path; normalized with ``os.path.abspath`` before use.
        size_bytes: File size; only overwrites an existing Asset's size when > 0.
        mtime_ns: Modification time stored on the cache state.
        mime_type: Stored on a new Asset; overwrites an existing one when truthy.
        info_name: When truthy, an AssetInfo with this name is created or reused.
        owner_id: Owner of the AssetInfo.
        preview_id: Preview Asset id; reset to None when no such Asset exists.
        user_metadata: Keys merged over the AssetInfo's current metadata.
        tags: Tag names to link to the AssetInfo (stripped and lower-cased here).
        tag_origin: ``origin`` stored on newly created tag links.
        require_existing_tags: When True, tags are not auto-created and unknown
            tags raise ValueError.

    Returns:
        Dict with keys ``asset_created``, ``asset_updated``, ``state_created``,
        ``state_updated`` (bools) and ``asset_info_id`` (id or None).

    Raises:
        RuntimeError: The Asset or AssetInfo row cannot be found after upsert.
        ValueError: ``require_existing_tags`` is True and some tags are unknown.
    """
    locator = os.path.abspath(abs_path)
    now = utcnow()

    # An unknown preview id is dropped rather than treated as an error.
    if preview_id:
        if not session.get(Asset, preview_id):
            preview_id = None

    out: dict[str, Any] = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }

    # 1) Asset by hash
    asset = (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()
    if not asset:
        vals = {
            "hash": asset_hash,
            "size_bytes": int(size_bytes),
            "mime_type": mime_type,
            "created_at": now,
        }
        # ON CONFLICT DO NOTHING: a conflicting insert leaves rowcount at 0,
        # and the re-select below picks up whichever row exists.
        res = session.execute(
            sqlite.insert(Asset)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[Asset.hash])
        )
        if int(res.rowcount or 0) > 0:
            out["asset_created"] = True
        asset = (
            session.execute(
                select(Asset).where(Asset.hash == asset_hash).limit(1)
            )
        ).scalars().first()
        if not asset:
            raise RuntimeError("Asset row not found after upsert.")
    else:
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            out["asset_updated"] = True

    # 2) AssetCacheState upsert by file_path (unique)
    vals = {
        "asset_id": asset.id,
        "file_path": locator,
        "mtime_ns": int(mtime_ns),
    }
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )

    res = session.execute(ins)
    if int(res.rowcount or 0) > 0:
        out["state_created"] = True
    else:
        # Row for this path already exists: update it only when the asset id
        # or mtime actually differs, so rowcount reflects a real change.
        upd = (
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path == locator)
            .where(
                sa.or_(
                    AssetCacheState.asset_id != asset.id,
                    AssetCacheState.mtime_ns.is_(None),
                    AssetCacheState.mtime_ns != int(mtime_ns),
                )
            )
            .values(asset_id=asset.id, mtime_ns=int(mtime_ns))
        )
        res2 = session.execute(upd)
        if int(res2.rowcount or 0) > 0:
            out["state_updated"] = True

    # 3) Optional AssetInfo + tags + metadata
    if info_name:
        try:
            # SAVEPOINT so a failed insert does not abort the outer transaction.
            with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_id=asset.id,
                    preview_id=preview_id,
                    created_at=now,
                    updated_at=now,
                    last_access_time=now,
                )
                session.add(info)
                session.flush()
                out["asset_info_id"] = info.id
        except IntegrityError:
            # The insert conflicted; the select below loads the existing row.
            pass

        # Runs whether the insert succeeded or conflicted.
        existing_info = (
            session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")

        if preview_id and existing_info.preview_id != preview_id:
            existing_info.preview_id = preview_id

        existing_info.updated_at = now
        if existing_info.last_access_time < now:
            existing_info.last_access_time = now
        session.flush()
        out["asset_info_id"] = existing_info.id

        # Tags are normalized inline here: stripped, lower-cased, blanks dropped.
        norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
        if norm and out["asset_info_id"] is not None:
            if not require_existing_tags:
                ensure_tags_exist(session, norm, tag_type="user")

            existing_tag_names = set(
                name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
            )
            missing = [t for t in norm if t not in existing_tag_names]
            if missing and require_existing_tags:
                raise ValueError(f"Unknown tags: {missing}")

            existing_links = set(
                tag_name
                for (tag_name,) in (
                    session.execute(
                        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                    )
                ).all()
            )
            # Only link tags that exist and are not linked yet.
            to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
            if to_add:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=out["asset_info_id"],
                            tag_name=t,
                            origin=tag_origin,
                            added_at=now,
                        )
                        for t in to_add
                    ]
                )
                session.flush()

        # metadata["filename"] hack: keep a "filename" entry derived from the
        # best live cache path; it overrides any caller-supplied "filename".
        if out["asset_info_id"] is not None:
            primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
            computed_filename = compute_relative_filename(primary_path) if primary_path else None

            current_meta = existing_info.user_metadata or {}
            new_meta = dict(current_meta)
            if user_metadata is not None:
                for k, v in user_metadata.items():
                    new_meta[k] = v
            if computed_filename:
                new_meta["filename"] = computed_filename

            # Rewrite the projection only when the merged metadata changed.
            if new_meta != current_meta:
                replace_asset_info_metadata_projection(
                    session,
                    asset_info_id=out["asset_info_id"],
                    user_metadata=new_meta,
                )

    # Best effort: failing to clear the 'missing' tag is logged, not raised.
    try:
        remove_missing_tag_for_asset_id(session, asset_id=asset.id)
    except Exception:
        logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
    return out
def update_asset_info_full(
    session: Session,
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: Sequence[str] | None = None,
    user_metadata: dict | None = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    """Update the name, metadata and/or tags of an AssetInfo.

    Args:
        asset_info_id: Id of the AssetInfo to update.
        name: New name; ignored when None or equal to the current name.
        tags: When not None, the tag links are replaced via
            ``set_asset_info_tags``.
        user_metadata: When not None, replaces the stored metadata. A computed
            ``"filename"`` entry overrides any supplied one when a live path
            exists.
        tag_origin: ``origin`` forwarded to ``set_asset_info_tags``.
        asset_info_row: Already-loaded AssetInfo; when truthy it is used
            instead of fetching by id.

    Returns:
        The updated AssetInfo.

    Raises:
        ValueError: ``asset_info_row`` is falsy and no AssetInfo has the id.
    """
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    else:
        info = asset_info_row

    touched = False
    if name is not None and name != info.name:
        info.name = name
        touched = True

    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        # Best effort: any failure while computing the filename is ignored.
        computed_filename = None

    if user_metadata is not None:
        new_meta = dict(user_metadata)
        if computed_filename:
            new_meta["filename"] = computed_filename
        replace_asset_info_metadata_projection(
            session, asset_info_id=asset_info_id, user_metadata=new_meta
        )
        touched = True
    else:
        # No metadata supplied: only refresh the stored "filename" if it is stale.
        if computed_filename:
            current_meta = info.user_metadata or {}
            if current_meta.get("filename") != computed_filename:
                new_meta = dict(current_meta)
                new_meta["filename"] = computed_filename
                replace_asset_info_metadata_projection(
                    session, asset_info_id=asset_info_id, user_metadata=new_meta
                )
                touched = True

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True

    # When user_metadata was given, replace_asset_info_metadata_projection has
    # already bumped updated_at and flushed.
    if touched and user_metadata is None:
        info.updated_at = utcnow()
        session.flush()
    return info
def delete_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str,
) -> bool:
    """Delete one AssetInfo if it matches the owner visibility clause.

    Returns True when at least one row was deleted.
    """
    removal = sa.delete(AssetInfo).where(
        AssetInfo.id == asset_info_id,
        visible_owner_clause(owner_id),
    )
    result = session.execute(removal)
    deleted_rows = int(result.rowcount or 0)
    return deleted_rows > 0
def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",
    owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
    """List tags with the number of visible AssetInfo rows using each one.

    Usage counts only include AssetInfo rows that pass
    ``visible_owner_clause(owner_id)``.

    Args:
        prefix: When truthy, only tags whose name starts with the stripped,
            lower-cased prefix are returned.
        limit: Page size.
        offset: Page offset.
        include_zero: When False, tags with no visible usage are excluded
            from both the rows and the total.
        order: ``"name_asc"`` sorts by name; any other value sorts by count
            descending, then name ascending.
        owner_id: Caller id used for the visibility clause.

    Returns:
        ``(rows, total)`` where each row is ``(name, tag_type, count)`` and
        ``total`` is the unpaged number of matching tags.
    """
    counts_sq = (
        select(
            AssetInfoTag.tag_name.label("tag_name"),
            func.count(AssetInfoTag.asset_info_id).label("cnt"),
        )
        .select_from(AssetInfoTag)
        .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
        .where(visible_owner_clause(owner_id))
        .group_by(AssetInfoTag.tag_name)
        .subquery()
    )
    usage_count = func.coalesce(counts_sq.c.cnt, 0)

    q = (
        select(
            Tag.name,
            Tag.tag_type,
            usage_count.label("count"),
        )
        .select_from(Tag)
        .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
    )
    total_q = select(func.count()).select_from(Tag)

    if prefix:
        # Escape once and reuse the same predicate for rows and total.
        escaped, esc = escape_like_prefix(prefix.strip().lower())
        prefix_match = Tag.name.like(escaped + "%", escape=esc)
        q = q.where(prefix_match)
        total_q = total_q.where(prefix_match)

    if not include_zero:
        q = q.where(usage_count > 0)
        # The total must use the same owner visibility as the counts above;
        # otherwise tags used only by rows the caller cannot see would be
        # counted in ``total`` while never appearing in ``rows``.
        visible_tag_names = (
            select(AssetInfoTag.tag_name)
            .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
            .where(visible_owner_clause(owner_id))
            .group_by(AssetInfoTag.tag_name)
        )
        total_q = total_q.where(Tag.name.in_(visible_tag_names))

    if order == "name_asc":
        q = q.order_by(Tag.name.asc())
    else:
        q = q.order_by(usage_count.desc(), Tag.name.asc())

    rows = session.execute(q.limit(limit).offset(offset)).all()
    total = session.execute(total_q).scalar_one()

    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    """Insert Tag rows for ``names``, ignoring names that already exist.

    Names go through ``normalize_tags`` and are de-duplicated (first
    occurrence kept) before a single ``INSERT ... ON CONFLICT DO NOTHING``.
    Nothing is executed when no names remain.
    """
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    unique_names = list(dict.fromkeys(wanted))
    payload = [{"name": tag_name, "tag_type": tag_type} for tag_name in unique_names]
    statement = sqlite.insert(Tag).values(payload)
    statement = statement.on_conflict_do_nothing(index_elements=[Tag.name])
    session.execute(statement)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
    """Return the tag names linked to the given AssetInfo id."""
    query = select(AssetInfoTag.tag_name).where(
        AssetInfoTag.asset_info_id == asset_info_id
    )
    rows = session.execute(query).all()
    return [tag_name for (tag_name,) in rows]
def add_tags_to_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    asset_info_row: Any = None,
) -> dict:
    """Link additional tags to an AssetInfo without removing existing ones.

    Args:
        asset_info_id: Id of the AssetInfo to tag.
        tags: Tag names; passed through ``normalize_tags``.
        origin: ``origin`` stored on newly created links.
        create_if_missing: When True, Tag rows are created first via
            ``ensure_tags_exist``.
        asset_info_row: When truthy, the existence check for ``asset_info_id``
            is skipped.

    Returns:
        Dict with ``added``, ``already_present`` and ``total_tags``. ``added``
        is derived by re-reading the links afterwards, so it is empty when the
        insert was rolled back.

    Raises:
        ValueError: ``asset_info_row`` is falsy and no AssetInfo has the id.
    """
    if not asset_info_row:
        # Only an existence check; the loaded row is not used further.
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": total}

    if create_if_missing:
        ensure_tags_exist(session, norm, tag_type="user")

    current = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    want = set(norm)
    to_add = sorted(want - current)

    if to_add:
        # SAVEPOINT: an IntegrityError rolls back only this batch of links.
        with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=t,
                            origin=origin,
                            added_at=utcnow(),
                        )
                        for t in to_add
                    ]
                )
                session.flush()
            except IntegrityError:
                nested.rollback()

    # Re-read the links so the report reflects what is actually stored.
    after = set(get_asset_tags(session, asset_info_id=asset_info_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }
def remove_tags_from_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
) -> dict:
    """Unlink the given tags from an AssetInfo.

    Returns:
        Dict with ``removed`` (sorted tags that were linked and got deleted),
        ``not_present`` (sorted requested tags that were not linked) and
        ``total_tags`` (links remaining afterwards).

    Raises:
        ValueError: No AssetInfo exists with ``asset_info_id``.
    """
    if not session.get(AssetInfo, asset_info_id):
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        return {
            "removed": [],
            "not_present": [],
            "total_tags": get_asset_tags(session, asset_info_id=asset_info_id),
        }

    linked_query = sa.select(AssetInfoTag.tag_name).where(
        AssetInfoTag.asset_info_id == asset_info_id
    )
    linked = {tag_name for (tag_name,) in session.execute(linked_query).all()}

    requested = set(norm)
    to_remove = sorted(requested & linked)
    not_present = sorted(requested - linked)

    if to_remove:
        unlink = delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id == asset_info_id,
            AssetInfoTag.tag_name.in_(to_remove),
        )
        session.execute(unlink)
        session.flush()

    remaining = get_asset_tags(session, asset_info_id=asset_info_id)
    return {"removed": to_remove, "not_present": not_present, "total_tags": remaining}
def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    """Delete the ``"missing"`` tag link from every AssetInfo of ``asset_id``."""
    info_ids_for_asset = sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
    removal = sa.delete(AssetInfoTag).where(
        AssetInfoTag.asset_info_id.in_(info_ids_for_asset),
        AssetInfoTag.tag_name == "missing",
    )
    session.execute(removal)
def set_asset_info_preview(
    session: Session,
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear an AssetInfo's ``preview_id`` and bump ``updated_at``.

    Passing ``preview_asset_id=None`` clears the preview.

    Raises:
        ValueError: The AssetInfo does not exist, or a non-None
            ``preview_asset_id`` does not match any Asset.
    """
    target = session.get(AssetInfo, asset_info_id)
    if not target:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    # Only a non-None preview id needs to be validated against Asset.
    if preview_asset_id is not None and not session.get(Asset, preview_asset_id):
        raise ValueError(f"Preview Asset {preview_asset_id} not found")

    target.preview_id = preview_asset_id
    target.updated_at = utcnow()
    session.flush()

View File

@ -0,0 +1,117 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
update_asset_hash_and_mime,
upsert_asset,
)
from app.assets.database.queries.asset_reference import (
CacheStateRow,
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
delete_assets_by_ids,
delete_orphaned_seed_asset,
delete_reference_by_id,
delete_references_by_ids,
fetch_reference_and_asset,
fetch_reference_asset_and_tags,
get_or_create_reference,
get_reference_by_file_path,
get_reference_by_id,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_references_for_prefixes,
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
insert_reference,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
update_enrichment_level,
update_reference_access_time,
update_reference_name,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
add_missing_tag_for_asset_id,
add_tags_to_reference,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
set_reference_tags,
)
# Public names re-exported by this package; every entry is imported above
# from the asset, asset_reference or tags query modules.
__all__ = [
    "AddTagsDict",
    "CacheStateRow",
    "RemoveTagsDict",
    "SetTagsDict",
    "UnenrichedReferenceRow",
    "add_missing_tag_for_asset_id",
    "add_tags_to_reference",
    "asset_exists_by_hash",
    "bulk_insert_assets",
    "bulk_insert_references_ignore_conflicts",
    "bulk_insert_tags_and_meta",
    "bulk_update_enrichment_level",
    "bulk_update_is_missing",
    "bulk_update_needs_verify",
    "convert_metadata_to_rows",
    "delete_assets_by_ids",
    "delete_orphaned_seed_asset",
    "delete_reference_by_id",
    "delete_references_by_ids",
    "ensure_tags_exist",
    "fetch_reference_and_asset",
    "fetch_reference_asset_and_tags",
    "get_asset_by_hash",
    "get_existing_asset_ids",
    "get_or_create_reference",
    "get_reference_by_file_path",
    "get_reference_by_id",
    "get_reference_ids_by_ids",
    "get_reference_tags",
    "get_references_by_paths_and_asset_ids",
    "get_references_for_prefixes",
    "get_unenriched_references",
    "get_unreferenced_unhashed_asset_ids",
    "insert_reference",
    "list_references_by_asset_id",
    "list_references_page",
    "list_tags_with_usage",
    "mark_references_missing_outside_prefixes",
    "reassign_asset_references",
    "reference_exists_for_asset_id",
    "remove_missing_tag_for_asset_id",
    "remove_tags_from_reference",
    "restore_references_by_paths",
    "set_reference_metadata",
    "set_reference_preview",
    "set_reference_tags",
    "update_asset_hash_and_mime",
    "update_enrichment_level",
    "update_reference_access_time",
    "update_reference_name",
    "update_reference_timestamps",
    "update_reference_updated_at",
    "upsert_asset",
    "upsert_reference",
]

View File

@ -0,0 +1,140 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
    session: Session,
    asset_hash: str,
) -> bool:
    """Return True when at least one Asset row has ``hash == asset_hash``."""
    probe = select(sa.literal(True)).select_from(Asset)
    probe = probe.where(Asset.hash == asset_hash).limit(1)
    hit = session.execute(probe).first()
    return hit is not None
def get_asset_by_hash(
    session: Session,
    asset_hash: str,
) -> Asset | None:
    """Return the first Asset whose hash equals ``asset_hash``, or None."""
    lookup = select(Asset).where(Asset.hash == asset_hash).limit(1)
    result = session.execute(lookup)
    return result.scalars().first()
def upsert_asset(
    session: Session,
    asset_hash: str,
    size_bytes: int,
    mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
    """Insert an Asset for ``asset_hash`` or refresh the existing one.

    The insert uses ``ON CONFLICT DO NOTHING`` on the hash. When the row
    already existed, ``size_bytes`` is overwritten only if the new size is
    positive and different, and ``mime_type`` only if a truthy, different
    value was given.

    Returns:
        ``(asset, created, updated)``.

    Raises:
        RuntimeError: No Asset row with the hash can be read back.
    """
    size = int(size_bytes)
    values = {"hash": asset_hash, "size_bytes": size}
    if mime_type:
        values["mime_type"] = mime_type

    insert_stmt = sqlite.insert(Asset).values(**values)
    insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=[Asset.hash])
    created = int(session.execute(insert_stmt).rowcount or 0) > 0

    lookup = select(Asset).where(Asset.hash == asset_hash).limit(1)
    asset = session.execute(lookup).scalars().first()
    if not asset:
        raise RuntimeError("Asset row not found after upsert.")

    if created:
        # A freshly inserted row already carries the requested values.
        return asset, created, False

    updated = False
    if asset.size_bytes != size and size > 0:
        asset.size_bytes = size
        updated = True
    if mime_type and asset.mime_type != mime_type:
        asset.mime_type = mime_type
        updated = True
    return asset, created, updated
def bulk_insert_assets(
    session: Session,
    rows: list[dict],
) -> None:
    """Insert Asset rows in chunks, skipping rows whose hash already exists.

    Chunk size comes from ``calculate_rows_per_statement(5)``; an empty
    ``rows`` list is a no-op.
    """
    if not rows:
        return
    statement = sqlite.insert(Asset).on_conflict_do_nothing(
        index_elements=[Asset.hash]
    )
    batch_size = calculate_rows_per_statement(5)
    for batch in iter_chunks(rows, batch_size):
        session.execute(statement, batch)
def get_existing_asset_ids(
    session: Session,
    asset_ids: list[str],
) -> set[str]:
    """Return the ids from ``asset_ids`` that have an Asset row.

    The lookup is split into chunks of ``MAX_BIND_PARAMS`` ids per query.
    """
    if not asset_ids:
        return set()
    existing: set[str] = set()
    for batch in iter_chunks(asset_ids, MAX_BIND_PARAMS):
        query = select(Asset.id).where(Asset.id.in_(batch))
        fetched = session.execute(query).fetchall()
        existing.update(row[0] for row in fetched)
    return existing
def update_asset_hash_and_mime(
    session: Session,
    asset_id: str,
    asset_hash: str | None = None,
    mime_type: str | None = None,
) -> bool:
    """Set ``hash`` and/or ``mime_type`` on an Asset; None leaves a field as is.

    Returns False when no Asset has ``asset_id``, True otherwise (even if
    neither field was given).
    """
    asset = session.get(Asset, asset_id)
    if not asset:
        return False
    pending = (("hash", asset_hash), ("mime_type", mime_type))
    for attribute, new_value in pending:
        if new_value is not None:
            setattr(asset, attribute, new_value)
    return True
def reassign_asset_references(
    session: Session,
    from_asset_id: str,
    to_asset_id: str,
    reference_id: str,
) -> None:
    """Point the reference ``reference_id`` at ``to_asset_id``.

    Used when merging a stub asset into an existing asset with the same hash.
    Does nothing when the reference does not exist.

    NOTE(review): ``from_asset_id`` is accepted but never read here, so the
    reference's current asset is not checked against it — confirm with callers.
    """
    reference = session.get(AssetReference, reference_id)
    if not reference:
        return
    reference.asset_id = to_asset_id
    session.flush()

View File

@ -0,0 +1,992 @@
"""Query functions for the unified AssetReference table.
This module replaces the separate asset_info.py and cache_state.py query modules,
providing a unified interface for the merged asset_references table.
"""
import os
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import NamedTuple, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, exists, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, noload
from app.assets.database.models import (
Asset,
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
MAX_BIND_PARAMS,
build_visible_owner_clause,
calculate_rows_per_statement,
iter_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
def _check_is_scalar(v):
    """Return True for None, bool, int, float, Decimal or str; False otherwise."""
    return v is None or isinstance(v, (bool, int, float, Decimal, str))
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
    """Build the typed projection row for a single value.

    ``None`` yields a row with all four ``val_*`` columns set to None. Any
    other value fills exactly one column: ``val_bool`` for bool, ``val_num``
    (as ``Decimal``) for int/float/Decimal, ``val_str`` for str, and
    ``val_json`` for everything else.
    """
    row: dict = {"key": key, "ordinal": ordinal}
    if value is None:
        row.update(val_str=None, val_num=None, val_bool=None, val_json=None)
    elif isinstance(value, bool):
        # bool is checked before the numeric branch: bool is a subclass of int.
        row["val_bool"] = bool(value)
    elif isinstance(value, (int, float, Decimal)):
        row["val_num"] = value if isinstance(value, Decimal) else Decimal(str(value))
    elif isinstance(value, str):
        row["val_str"] = value
    else:
        row["val_json"] = value
    return row
def convert_metadata_to_rows(key: str, value) -> list[dict]:
    """Expand one metadata key/value pair into typed projection rows.

    A scalar (including None) gives a single row at ordinal 0. A list gives
    one row per element, numbered from 0: typed rows when every element is
    scalar, ``val_json`` rows otherwise. Any other value gives a single
    ``val_json`` row at ordinal 0.
    """
    # _check_is_scalar accepts None, so the None case needs no separate branch.
    if _check_is_scalar(value):
        return [_scalar_to_row(key, 0, value)]
    if not isinstance(value, list):
        return [{"key": key, "ordinal": 0, "val_json": value}]
    if all(_check_is_scalar(item) for item in value):
        return [_scalar_to_row(key, idx, item) for idx, item in enumerate(value)]
    return [
        {"key": key, "ordinal": idx, "val_json": item}
        for idx, item in enumerate(value)
    ]
def _apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Restrict ``stmt`` to AssetReference rows matching the tag constraints.

    Every tag in ``include_tags`` must be linked to the row (one ``EXISTS``
    per tag); no tag in ``exclude_tags`` may be linked (a single negated
    ``EXISTS``). Both lists are passed through ``normalize_tags`` first.
    """
    required = normalize_tags(include_tags)
    forbidden = normalize_tags(exclude_tags)

    for required_tag in required or ():
        has_required = exists().where(
            (AssetReferenceTag.asset_reference_id == AssetReference.id)
            & (AssetReferenceTag.tag_name == required_tag)
        )
        stmt = stmt.where(has_required)

    if forbidden:
        has_forbidden = exists().where(
            (AssetReferenceTag.asset_reference_id == AssetReference.id)
            & (AssetReferenceTag.tag_name.in_(forbidden))
        )
        stmt = stmt.where(~has_forbidden)

    return stmt
def _apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Narrow ``stmt`` using the asset_reference_meta projection table.

    Each ``key -> value`` entry adds one ``WHERE`` condition. ``None`` matches
    references with no meta row for the key or a row whose value columns are
    all NULL. Other values are compared against the column matching their
    type (``val_bool``, ``val_num``, ``val_str``, else ``val_json``). A list
    value matches when any element matches; an empty list adds no condition.
    """
    if not metadata_filter:
        return stmt

    def _has_meta_row(key: str, *extra) -> sa.sql.ClauseElement:
        # EXISTS over this reference's projection rows for ``key``.
        return sa.exists().where(
            AssetReferenceMeta.asset_reference_id == AssetReference.id,
            AssetReferenceMeta.key == key,
            *extra,
        )

    def _clause_for(key: str, expected) -> sa.sql.ClauseElement:
        if expected is None:
            key_absent = sa.not_(_has_meta_row(key))
            all_columns_null = _has_meta_row(
                key,
                AssetReferenceMeta.val_json.is_(None),
                AssetReferenceMeta.val_str.is_(None),
                AssetReferenceMeta.val_num.is_(None),
                AssetReferenceMeta.val_bool.is_(None),
            )
            return sa.or_(key_absent, all_columns_null)

        # bool first: bool is a subclass of int.
        if isinstance(expected, bool):
            predicate = AssetReferenceMeta.val_bool == bool(expected)
        elif isinstance(expected, (int, float, Decimal)):
            number = expected if isinstance(expected, Decimal) else Decimal(str(expected))
            predicate = AssetReferenceMeta.val_num == number
        elif isinstance(expected, str):
            predicate = AssetReferenceMeta.val_str == expected
        else:
            predicate = AssetReferenceMeta.val_json == expected
        return _has_meta_row(key, predicate)

    for meta_key, expected in metadata_filter.items():
        if not isinstance(expected, list):
            stmt = stmt.where(_clause_for(meta_key, expected))
            continue
        alternatives = [_clause_for(meta_key, item) for item in expected]
        if alternatives:
            stmt = stmt.where(sa.or_(*alternatives))
    return stmt
def get_reference_by_id(
    session: Session,
    reference_id: str,
) -> AssetReference | None:
    """Look up an AssetReference by primary key via ``Session.get``; None if absent."""
    found = session.get(AssetReference, reference_id)
    return found
def get_reference_by_file_path(
    session: Session,
    file_path: str,
) -> AssetReference | None:
    """Return the first reference whose ``file_path`` equals the given path, or None."""
    lookup = (
        select(AssetReference).where(AssetReference.file_path == file_path).limit(1)
    )
    result = session.execute(lookup)
    return result.scalars().first()
def reference_exists_for_asset_id(
    session: Session,
    asset_id: str,
) -> bool:
    """Return True when at least one AssetReference points at ``asset_id``."""
    probe = select(sa.literal(True)).select_from(AssetReference)
    probe = probe.where(AssetReference.asset_id == asset_id).limit(1)
    first_hit = session.execute(probe).first()
    return first_hit is not None
def insert_reference(
    session: Session,
    asset_id: str,
    name: str,
    owner_id: str = "",
    file_path: str | None = None,
    mtime_ns: int | None = None,
    preview_id: str | None = None,
) -> AssetReference | None:
    """Insert a new AssetReference inside a nested transaction.

    Returns the flushed reference, or None when the flush raises
    ``IntegrityError`` (e.g. a unique constraint is violated). The three
    timestamp columns are all initialised to the same "now" value.
    """
    timestamp = get_utc_now()
    candidate = AssetReference(
        asset_id=asset_id,
        name=name,
        owner_id=owner_id,
        file_path=file_path,
        mtime_ns=mtime_ns,
        preview_id=preview_id,
        created_at=timestamp,
        updated_at=timestamp,
        last_access_time=timestamp,
    )
    try:
        # The nested transaction keeps a failed insert from poisoning the
        # caller's outer transaction.
        with session.begin_nested():
            session.add(candidate)
            session.flush()
    except IntegrityError:
        return None
    return candidate
def get_or_create_reference(
    session: Session,
    asset_id: str,
    name: str,
    owner_id: str = "",
    file_path: str | None = None,
    mtime_ns: int | None = None,
    preview_id: str | None = None,
) -> tuple[AssetReference, bool]:
    """Return ``(reference, created)``, inserting the reference when possible.

    An insert is attempted first. If it is rejected, the existing row is
    looked up: by ``file_path`` when one is given, otherwise by
    ``asset_id`` + ``name`` + ``owner_id`` among rows with no file path.

    Raises:
        RuntimeError: the insert failed and no matching row could be found.
    """
    inserted = insert_reference(
        session,
        asset_id=asset_id,
        name=name,
        owner_id=owner_id,
        file_path=file_path,
        mtime_ns=mtime_ns,
        preview_id=preview_id,
    )
    if inserted:
        return inserted, True

    if file_path:
        # Filesystem references are identified by their path.
        found = get_reference_by_file_path(session, file_path)
    else:
        # API references have no path; match on asset, name and owner.
        lookup = (
            select(AssetReference)
            .where(
                AssetReference.asset_id == asset_id,
                AssetReference.name == name,
                AssetReference.owner_id == owner_id,
                AssetReference.file_path.is_(None),
            )
            .limit(1)
        )
        found = session.execute(lookup).unique().scalar_one_or_none()

    if not found:
        raise RuntimeError("Failed to find AssetReference after insert conflict.")
    return found, False
def update_reference_timestamps(
    session: Session,
    reference: AssetReference,
    preview_id: str | None = None,
) -> None:
    """Bump ``updated_at`` on ``reference`` and optionally replace its preview.

    ``preview_id`` is applied only when it is truthy and differs from the
    current value. ``session`` is not used by this function.
    """
    stamp = get_utc_now()
    preview_changed = bool(preview_id) and reference.preview_id != preview_id
    if preview_changed:
        reference.preview_id = preview_id
    reference.updated_at = stamp
def list_references_page(
    session: Session,
    owner_id: str = "",
    limit: int = 100,
    offset: int = 0,
    name_contains: str | None = None,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    metadata_filter: dict | None = None,
    sort: str | None = None,
    order: str | None = None,
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
    """List references with pagination, filtering, and sorting.

    Only references visible to ``owner_id`` and not marked missing are
    considered. ``sort`` is one of ``name``, ``created_at``, ``updated_at``,
    ``last_access_time`` or ``size`` (unknown values fall back to
    ``created_at``); ``order`` is ``desc`` (default) or anything else for
    ascending. Rows that tie on the sort column are further ordered by
    reference id so that limit/offset pages are stable.

    Returns:
        ``(references, tag_map, total_count)`` where ``tag_map`` maps each
        returned reference id to its tag names ordered by ``added_at`` and
        ``total_count`` is the number of rows matching the filters
        regardless of ``limit``/``offset``.
    """

    def _with_filters(stmt):
        # Shared by the page query and the count query so they cannot drift.
        stmt = stmt.where(build_visible_owner_clause(owner_id)).where(
            AssetReference.is_missing == False  # noqa: E712
        )
        if name_contains:
            escaped, esc = escape_sql_like_string(name_contains)
            stmt = stmt.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
        stmt = _apply_tag_filters(stmt, include_tags, exclude_tags)
        return _apply_metadata_filter(stmt, metadata_filter)

    page_stmt = _with_filters(
        select(AssetReference)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .options(noload(AssetReference.tags))
    )
    count_stmt = _with_filters(
        select(sa.func.count())
        .select_from(AssetReference)
        .join(Asset, Asset.id == AssetReference.asset_id)
    )

    sort = (sort or "created_at").lower()
    order = (order or "desc").lower()
    sort_map = {
        "name": AssetReference.name,
        "created_at": AssetReference.created_at,
        "updated_at": AssetReference.updated_at,
        "last_access_time": AssetReference.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get(sort, AssetReference.created_at)
    sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
    # The id tie-breaker makes the ordering total; without it rows with an
    # equal sort key may move between pages from one query to the next.
    page_stmt = (
        page_stmt.order_by(sort_exp, AssetReference.id.asc())
        .limit(limit)
        .offset(offset)
    )

    total = int(session.execute(count_stmt).scalar_one() or 0)
    refs = session.execute(page_stmt).unique().scalars().all()

    id_list: list[str] = [r.id for r in refs]
    tag_map: dict[str, list[str]] = defaultdict(list)
    if id_list:
        # Tags are fetched in one extra query instead of per reference.
        rows = session.execute(
            select(AssetReferenceTag.asset_reference_id, Tag.name)
            .join(Tag, Tag.name == AssetReferenceTag.tag_name)
            .where(AssetReferenceTag.asset_reference_id.in_(id_list))
            .order_by(AssetReferenceTag.added_at)
        )
        for ref_id, tag_name in rows.all():
            tag_map[ref_id].append(tag_name)
    return list(refs), tag_map, total
def fetch_reference_asset_and_tags(
    session: Session,
    reference_id: str,
    owner_id: str = "",
) -> tuple[AssetReference, Asset, list[str]] | None:
    """Load a visible reference, its asset and its tag names in one query.

    Returns None when no row matches ``reference_id`` under the owner
    visibility clause. Tag names come back de-duplicated in ascending name
    order; a reference without tags yields an empty list.
    """
    query = (
        select(AssetReference, Asset, Tag.name)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .join(
            AssetReferenceTag,
            AssetReferenceTag.asset_reference_id == AssetReference.id,
            isouter=True,
        )
        .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True)
        .where(
            AssetReference.id == reference_id,
            build_visible_owner_clause(owner_id),
        )
        .options(noload(AssetReference.tags))
        .order_by(Tag.name.asc())
    )
    result_rows = session.execute(query).all()
    if not result_rows:
        return None

    reference, asset, _ = result_rows[0]
    # The outer joins produce a NULL tag name for untagged references;
    # dict.fromkeys de-duplicates while preserving the query's ordering.
    tag_names = list(
        dict.fromkeys(name for _ref, _asset, name in result_rows if name)
    )
    return reference, asset, tag_names
def fetch_reference_and_asset(
    session: Session,
    reference_id: str,
    owner_id: str = "",
) -> tuple[AssetReference, Asset] | None:
    """Load a visible reference together with its asset, or None if not found."""
    query = (
        select(AssetReference, Asset)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(
            AssetReference.id == reference_id,
            build_visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetReference.tags))
    )
    row = session.execute(query).first()
    if row is None:
        return None
    reference, asset = row[0], row[1]
    return reference, asset
def update_reference_access_time(
    session: Session,
    reference_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    """Set ``last_access_time`` on one reference.

    ``ts`` defaults to the current UTC time. With ``only_if_newer`` the row
    is only written when its stored value is NULL or older than ``ts``.
    """
    when = ts or get_utc_now()
    conditions = [AssetReference.id == reference_id]
    if only_if_newer:
        conditions.append(
            sa.or_(
                AssetReference.last_access_time.is_(None),
                AssetReference.last_access_time < when,
            )
        )
    update_stmt = (
        sa.update(AssetReference).where(*conditions).values(last_access_time=when)
    )
    session.execute(update_stmt)
def update_reference_name(
    session: Session,
    reference_id: str,
    name: str,
) -> None:
    """Rename a reference and stamp ``updated_at`` with the current UTC time."""
    rename_stmt = sa.update(AssetReference).where(AssetReference.id == reference_id)
    session.execute(rename_stmt.values(name=name, updated_at=get_utc_now()))
def update_reference_updated_at(
    session: Session,
    reference_id: str,
    ts: datetime | None = None,
) -> None:
    """Set ``updated_at`` on one reference; ``ts`` defaults to the current UTC time."""
    when = ts or get_utc_now()
    touch_stmt = sa.update(AssetReference).where(AssetReference.id == reference_id)
    session.execute(touch_stmt.values(updated_at=when))
def set_reference_metadata(
    session: Session,
    reference_id: str,
    user_metadata: dict | None = None,
) -> None:
    """Replace a reference's user metadata and rebuild its projection rows.

    The JSON blob on the reference is overwritten (``{}`` when nothing is
    given), all existing ``AssetReferenceMeta`` rows for the reference are
    deleted, and new rows are generated from ``convert_metadata_to_rows``.

    Raises:
        ValueError: no reference with ``reference_id`` exists.
    """
    reference = session.get(AssetReference, reference_id)
    if not reference:
        raise ValueError(f"AssetReference {reference_id} not found")

    reference.user_metadata = user_metadata or {}
    reference.updated_at = get_utc_now()
    session.flush()

    # Drop the old projection before writing the new one.
    purge_stmt = delete(AssetReferenceMeta).where(
        AssetReferenceMeta.asset_reference_id == reference_id
    )
    session.execute(purge_stmt)
    session.flush()

    if not user_metadata:
        return

    projected: list[AssetReferenceMeta] = [
        AssetReferenceMeta(
            asset_reference_id=reference_id,
            key=entry["key"],
            ordinal=int(entry["ordinal"]),
            val_str=entry.get("val_str"),
            val_num=entry.get("val_num"),
            val_bool=entry.get("val_bool"),
            val_json=entry.get("val_json"),
        )
        for meta_key, meta_value in user_metadata.items()
        for entry in convert_metadata_to_rows(meta_key, meta_value)
    ]
    if projected:
        session.add_all(projected)
        session.flush()
def delete_reference_by_id(
    session: Session,
    reference_id: str,
    owner_id: str,
) -> bool:
    """Delete one reference if it is visible to ``owner_id``.

    Returns True when a row was deleted, False otherwise.
    """
    removal = sa.delete(AssetReference).where(
        AssetReference.id == reference_id,
        build_visible_owner_clause(owner_id),
    )
    affected = int(session.execute(removal).rowcount or 0)
    return affected > 0
def set_reference_preview(
    session: Session,
    reference_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear ``preview_id`` on a reference and bump ``updated_at``.

    Passing None clears the preview.

    Raises:
        ValueError: the reference does not exist, or ``preview_asset_id`` is
            given but no Asset with that id exists.
    """
    reference = session.get(AssetReference, reference_id)
    if not reference:
        raise ValueError(f"AssetReference {reference_id} not found")

    # Only a non-None preview needs to be validated against the asset table.
    if preview_asset_id is not None and not session.get(Asset, preview_asset_id):
        raise ValueError(f"Preview Asset {preview_asset_id} not found")

    reference.preview_id = preview_asset_id
    reference.updated_at = get_utc_now()
    session.flush()
class CacheStateRow(NamedTuple):
    """Row from reference query with cache state data.

    Produced by ``get_references_for_prefixes``: reference columns joined
    with the hash and size of the asset the reference points at.
    """

    # AssetReference.id
    reference_id: str
    # AssetReference.file_path (the query only selects rows where it is set)
    file_path: str
    # AssetReference.mtime_ns
    mtime_ns: int | None
    # AssetReference.needs_verify
    needs_verify: bool
    # AssetReference.asset_id
    asset_id: str
    # Asset.hash of the joined asset
    asset_hash: str | None
    # Asset.size_bytes of the joined asset, with a NULL/zero value reported as 0
    size_bytes: int
def list_references_by_asset_id(
    session: Session,
    asset_id: str,
) -> Sequence[AssetReference]:
    """Return every reference pointing at ``asset_id``, ordered by reference id."""
    query = select(AssetReference).where(AssetReference.asset_id == asset_id)
    query = query.order_by(AssetReference.id.asc())
    return session.execute(query).scalars().all()
def upsert_reference(
    session: Session,
    asset_id: str,
    file_path: str,
    name: str,
    mtime_ns: int,
    owner_id: str = "",
) -> tuple[bool, bool]:
    """Insert or refresh the reference stored under ``file_path``.

    An INSERT ... ON CONFLICT DO NOTHING on ``file_path`` is tried first.
    When the row already exists it is updated only if its asset id or mtime
    differs, its mtime is NULL, or it is flagged missing; the update also
    clears ``is_missing``.

    Returns:
        ``(created, updated)`` — at most one of the two is True.
    """
    now = get_utc_now()
    mtime = int(mtime_ns)

    insert_stmt = (
        sqlite.insert(AssetReference)
        .values(
            asset_id=asset_id,
            file_path=file_path,
            name=name,
            owner_id=owner_id,
            mtime_ns=mtime,
            is_missing=False,
            created_at=now,
            updated_at=now,
            last_access_time=now,
        )
        .on_conflict_do_nothing(index_elements=[AssetReference.file_path])
    )
    if int(session.execute(insert_stmt).rowcount or 0) > 0:
        return True, False

    # The path already exists: touch it only when something actually changed.
    is_stale = sa.or_(
        AssetReference.asset_id != asset_id,
        AssetReference.mtime_ns.is_(None),
        AssetReference.mtime_ns != mtime,
        AssetReference.is_missing == True,  # noqa: E712
    )
    update_stmt = (
        sa.update(AssetReference)
        .where(AssetReference.file_path == file_path, is_stale)
        .values(asset_id=asset_id, mtime_ns=mtime, is_missing=False, updated_at=now)
    )
    refreshed = int(session.execute(update_stmt).rowcount or 0) > 0
    return False, refreshed
def mark_references_missing_outside_prefixes(
    session: Session,
    valid_prefixes: list[str],
) -> int:
    """Flag path-backed references that live outside every valid prefix.

    Each prefix is treated as a directory (a trailing ``os.sep`` is added
    when absent). References with a file path that matches none of them and
    are not already missing get ``is_missing=True``. An empty prefix list
    is a no-op.

    Returns:
        Number of references marked as missing.
    """
    if not valid_prefixes:
        return 0

    prefix_matches = []
    for prefix in valid_prefixes:
        directory = prefix if prefix.endswith(os.sep) else prefix + os.sep
        pattern, esc = escape_sql_like_string(directory)
        prefix_matches.append(
            AssetReference.file_path.like(pattern + "%", escape=esc)
        )
    under_valid_root = sa.or_(*prefix_matches)

    mark_stmt = (
        sa.update(AssetReference)
        .where(
            AssetReference.file_path.isnot(None),
            ~under_valid_root,
            AssetReference.is_missing == False,  # noqa: E712
        )
        .values(is_missing=True)
    )
    return session.execute(mark_stmt).rowcount
def restore_references_by_paths(session: Session, file_paths: list[str]) -> int:
    """Restore references that were previously marked as missing.

    Clears ``is_missing`` on every reference whose ``file_path`` is in
    ``file_paths`` and that is currently flagged missing. The paths are
    processed in chunks of ``MAX_BIND_PARAMS`` so the ``IN (...)`` list never
    exceeds the bind-parameter budget used elsewhere in this module.

    Returns:
        Number of references restored.
    """
    if not file_paths:
        return 0
    restored = 0
    # De-duplicate first so a path cannot be sent in two different chunks.
    for chunk in iter_chunks(list(dict.fromkeys(file_paths)), MAX_BIND_PARAMS):
        result = session.execute(
            sa.update(AssetReference)
            .where(AssetReference.file_path.in_(chunk))
            .where(AssetReference.is_missing == True)  # noqa: E712
            .values(is_missing=False)
        )
        restored += result.rowcount
    return restored
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
    """Return ids of assets with ``hash`` NULL that have no active reference.

    A reference counts as active when it is not marked missing, so an asset
    with no references at all, or only missing ones, is reported.
    """
    has_active_reference = (
        sa.select(sa.literal(1))
        .where(AssetReference.asset_id == Asset.id)
        .where(AssetReference.is_missing == False)  # noqa: E712
        .correlate(Asset)
        .exists()
    )
    orphan_query = sa.select(Asset.id).where(
        Asset.hash.is_(None), ~has_active_reference
    )
    return list(session.execute(orphan_query).scalars().all())
def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
    """Delete assets and their references by ID.

    The ids are processed in chunks of ``MAX_BIND_PARAMS`` so each
    ``IN (...)`` list stays within the bind-parameter budget used elsewhere
    in this module. Within a chunk the references are removed before the
    assets they point at.

    Returns:
        Number of assets deleted.
    """
    if not asset_ids:
        return 0
    deleted = 0
    # De-duplicate first so an id cannot be sent in two different chunks.
    for chunk in iter_chunks(list(dict.fromkeys(asset_ids)), MAX_BIND_PARAMS):
        session.execute(
            sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk))
        )
        result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk)))
        deleted += result.rowcount
    return deleted
def get_references_for_prefixes(
    session: Session,
    prefixes: list[str],
    *,
    include_missing: bool = False,
) -> list[CacheStateRow]:
    """Fetch path-backed references located under any of the given directories.

    Args:
        session: Database session
        prefixes: Directory prefixes; each is made absolute and given a
            trailing ``os.sep`` before matching
        include_missing: If False (default), exclude references marked as missing

    Returns:
        Cache state rows (reference columns plus the joined asset's hash and
        size), ordered by asset id then reference id. Empty when no prefixes
        are given.
    """
    if not prefixes:
        return []

    under_prefix = []
    for prefix in prefixes:
        directory = os.path.abspath(prefix)
        if not directory.endswith(os.sep):
            directory += os.sep
        pattern, esc = escape_sql_like_string(directory)
        under_prefix.append(AssetReference.file_path.like(pattern + "%", escape=esc))

    stmt = (
        sa.select(
            AssetReference.id,
            AssetReference.file_path,
            AssetReference.mtime_ns,
            AssetReference.needs_verify,
            AssetReference.asset_id,
            Asset.hash,
            Asset.size_bytes,
        )
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(AssetReference.file_path.isnot(None))
        .where(sa.or_(*under_prefix))
    )
    if not include_missing:
        stmt = stmt.where(AssetReference.is_missing == False)  # noqa: E712
    stmt = stmt.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc())

    return [
        CacheStateRow(
            reference_id=ref_id,
            file_path=path,
            mtime_ns=mtime,
            needs_verify=verify,
            asset_id=owning_asset_id,
            asset_hash=asset_hash,
            size_bytes=int(size or 0),
        )
        for ref_id, path, mtime, verify, owning_asset_id, asset_hash, size
        in session.execute(stmt).all()
    ]
def bulk_update_needs_verify(
    session: Session, reference_ids: list[str], value: bool
) -> int:
    """Set needs_verify flag for multiple references.

    The ids are processed in chunks of ``MAX_BIND_PARAMS`` so each
    ``IN (...)`` list stays within the bind-parameter budget used elsewhere
    in this module.

    Returns: Number of rows updated
    """
    if not reference_ids:
        return 0
    updated = 0
    # De-duplicate so an id split across two chunks is not counted twice.
    for chunk in iter_chunks(list(dict.fromkeys(reference_ids)), MAX_BIND_PARAMS):
        result = session.execute(
            sa.update(AssetReference)
            .where(AssetReference.id.in_(chunk))
            .values(needs_verify=value)
        )
        updated += result.rowcount
    return updated
def bulk_update_is_missing(
    session: Session, reference_ids: list[str], value: bool
) -> int:
    """Set is_missing flag for multiple references.

    The ids are processed in chunks of ``MAX_BIND_PARAMS`` so each
    ``IN (...)`` list stays within the bind-parameter budget used elsewhere
    in this module.

    Returns: Number of rows updated
    """
    if not reference_ids:
        return 0
    updated = 0
    # De-duplicate so an id split across two chunks is not counted twice.
    for chunk in iter_chunks(list(dict.fromkeys(reference_ids)), MAX_BIND_PARAMS):
        result = session.execute(
            sa.update(AssetReference)
            .where(AssetReference.id.in_(chunk))
            .values(is_missing=value)
        )
        updated += result.rowcount
    return updated
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
    """Delete references by their IDs.

    The ids are processed in chunks of ``MAX_BIND_PARAMS`` so each
    ``IN (...)`` list stays within the bind-parameter budget used elsewhere
    in this module.

    Returns: Number of rows deleted
    """
    if not reference_ids:
        return 0
    deleted = 0
    # De-duplicate first so an id cannot be sent in two different chunks.
    for chunk in iter_chunks(list(dict.fromkeys(reference_ids)), MAX_BIND_PARAMS):
        result = session.execute(
            sa.delete(AssetReference).where(AssetReference.id.in_(chunk))
        )
        deleted += result.rowcount
    return deleted
def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
    """Remove an asset together with every reference that points at it.

    The references are deleted unconditionally; the asset row is then
    deleted if it exists. This function does not itself check the asset's
    hash — callers are expected to pass a seed asset.

    Returns: True if asset was deleted, False if not found
    """
    drop_references = sa.delete(AssetReference).where(
        AssetReference.asset_id == asset_id
    )
    session.execute(drop_references)

    asset = session.get(Asset, asset_id)
    if not asset:
        return False
    session.delete(asset)
    return True
class UnenrichedReferenceRow(NamedTuple):
    """Row for references needing enrichment.

    Produced by ``get_unenriched_references`` from ``AssetReference`` columns.
    """

    # AssetReference.id
    reference_id: str
    # AssetReference.asset_id
    asset_id: str
    # AssetReference.file_path (the query only selects rows where it is set)
    file_path: str
    # AssetReference.enrichment_level at query time
    enrichment_level: int
def get_unenriched_references(
    session: Session,
    prefixes: list[str],
    max_level: int = 0,
    limit: int = 1000,
) -> list[UnenrichedReferenceRow]:
    """Fetch non-missing references whose enrichment_level is <= max_level.

    Args:
        session: Database session
        prefixes: Directory prefixes; each is made absolute and given a
            trailing ``os.sep`` before matching
        max_level: Maximum enrichment level to include (0=stubs, 1=metadata done)
        limit: Maximum number of rows to return

    Returns:
        Matching rows ordered by reference id; empty when no prefixes are given.
    """
    if not prefixes:
        return []

    under_prefix = []
    for prefix in prefixes:
        directory = os.path.abspath(prefix)
        if not directory.endswith(os.sep):
            directory += os.sep
        pattern, esc = escape_sql_like_string(directory)
        under_prefix.append(AssetReference.file_path.like(pattern + "%", escape=esc))

    stmt = (
        sa.select(
            AssetReference.id,
            AssetReference.asset_id,
            AssetReference.file_path,
            AssetReference.enrichment_level,
        )
        .where(AssetReference.file_path.isnot(None))
        .where(sa.or_(*under_prefix))
        .where(AssetReference.is_missing == False)  # noqa: E712
        .where(AssetReference.enrichment_level <= max_level)
        .order_by(AssetReference.id.asc())
        .limit(limit)
    )
    return [
        UnenrichedReferenceRow(
            reference_id=ref_id,
            asset_id=owning_asset_id,
            file_path=path,
            enrichment_level=level,
        )
        for ref_id, owning_asset_id, path, level in session.execute(stmt).all()
    ]
def update_enrichment_level(
    session: Session,
    reference_id: str,
    level: int,
) -> None:
    """Store ``level`` as the enrichment level of a single reference."""
    target = sa.update(AssetReference).where(AssetReference.id == reference_id)
    session.execute(target.values(enrichment_level=level))
def bulk_update_enrichment_level(
    session: Session,
    reference_ids: list[str],
    level: int,
) -> int:
    """Update enrichment level for multiple references.

    The ids are processed in chunks of ``MAX_BIND_PARAMS`` so each
    ``IN (...)`` list stays within the bind-parameter budget used elsewhere
    in this module.

    Returns: Number of rows updated
    """
    if not reference_ids:
        return 0
    updated = 0
    # De-duplicate so an id split across two chunks is not counted twice.
    for chunk in iter_chunks(list(dict.fromkeys(reference_ids)), MAX_BIND_PARAMS):
        result = session.execute(
            sa.update(AssetReference)
            .where(AssetReference.id.in_(chunk))
            .values(enrichment_level=level)
        )
        updated += result.rowcount
    return updated
def bulk_insert_references_ignore_conflicts(
    session: Session,
    rows: list[dict],
) -> None:
    """Insert many reference rows, skipping those whose file_path already exists.

    Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc.
    ``is_missing`` is forced to False on every row. Rows are sent in batches
    sized by ``calculate_rows_per_statement(14)``.
    """
    if not rows:
        return
    payload = [row | {"is_missing": False} for row in rows]
    insert_stmt = sqlite.insert(AssetReference).on_conflict_do_nothing(
        index_elements=[AssetReference.file_path]
    )
    batch_size = calculate_rows_per_statement(14)
    for batch in iter_chunks(payload, batch_size):
        session.execute(insert_stmt, batch)
def get_references_by_paths_and_asset_ids(
    session: Session,
    path_to_asset: dict[str, str],
) -> set[str]:
    """Report which (file_path, asset_id) pairs are actually stored.

    Args:
        path_to_asset: Mapping of file_path -> asset_id we tried to insert

    Returns:
        Set of file_paths for which a reference with exactly that asset_id exists
    """
    if not path_to_asset:
        return set()

    candidates = list(path_to_asset.items())
    confirmed: set[str] = set()
    path_and_asset = sa.tuple_(AssetReference.file_path, AssetReference.asset_id)
    # A pair binds two parameters, hence half the usual chunk size.
    for batch in iter_chunks(candidates, MAX_BIND_PARAMS // 2):
        stmt = select(AssetReference.file_path).where(path_and_asset.in_(batch))
        confirmed.update(session.execute(stmt).scalars().all())
    return confirmed
def get_reference_ids_by_ids(
    session: Session,
    reference_ids: list[str],
) -> set[str]:
    """Return the subset of ``reference_ids`` that exist in the database."""
    if not reference_ids:
        return set()
    existing: set[str] = set()
    for batch in iter_chunks(reference_ids, MAX_BIND_PARAMS):
        stmt = select(AssetReference.id).where(AssetReference.id.in_(batch))
        existing |= set(session.execute(stmt).scalars().all())
    return existing

View File

@ -0,0 +1,40 @@
"""Shared utilities for database query modules."""
from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetReference
MAX_BIND_PARAMS = 800


def calculate_rows_per_statement(cols: int) -> int:
    """Return how many rows of ``cols`` columns fit in ``MAX_BIND_PARAMS``.

    Column counts below 1 are treated as 1, and the result is never less
    than 1 even when a single row would exceed the budget.
    """
    effective_cols = cols if cols > 1 else 1
    rows = MAX_BIND_PARAMS // effective_cols
    return rows if rows > 1 else 1
def iter_chunks(seq, n: int):
    """Yield successive n-sized chunks from seq.

    ``seq`` must support ``len()`` and slicing; each chunk is a slice of it,
    and the last chunk may be shorter than ``n``. An empty ``seq`` yields
    nothing.

    Raises:
        ValueError: ``n`` is smaller than 1. Because this is a generator the
            error surfaces on first iteration. A negative size would
            otherwise produce no chunks at all and silently drop the data.
    """
    if n < 1:
        raise ValueError(f"chunk size must be a positive integer, got {n!r}")
    for start in range(0, len(seq), n):
        yield seq[start : start + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
    """Yield slices of ``rows`` small enough to respect the bind-param budget.

    The slice length comes from ``calculate_rows_per_statement(cols_per_row)``.
    Nothing is yielded for an empty ``rows``.
    """
    if not rows:
        return
    batch_size = calculate_rows_per_statement(cols_per_row)
    start = 0
    total = len(rows)
    while start < total:
        yield rows[start : start + batch_size]
        start += batch_size
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Build the read-visibility predicate on ``AssetReference.owner_id``.

    Rows with an empty owner are visible to everyone. A non-blank
    ``owner_id`` (after stripping whitespace) additionally sees its own rows;
    a blank or None owner sees only the owner-less rows.
    """
    normalized = (owner_id or "").strip()
    if normalized:
        return AssetReference.owner_id.in_(["", normalized])
    return AssetReference.owner_id == ""

View File

@ -0,0 +1,366 @@
from typing import Iterable, Sequence, TypedDict
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import (
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
class AddTagsDict(TypedDict):
    """Result of ``add_tags_to_reference``."""

    # Requested tags that were not on the reference before and are on it now.
    added: list[str]
    # Requested tags the reference already carried.
    already_present: list[str]
    # All tags on the reference after the operation.
    total_tags: list[str]
class RemoveTagsDict(TypedDict):
    """Result of ``remove_tags_from_reference``."""

    # Requested tags that were on the reference and have been deleted.
    removed: list[str]
    # Requested tags the reference did not carry.
    not_present: list[str]
    # All tags on the reference after the operation.
    total_tags: list[str]
class SetTagsDict(TypedDict):
    """Result of ``set_reference_tags``."""

    # Desired tags that had to be attached.
    added: list[str]
    # Previously attached tags that were detached.
    removed: list[str]
    # The normalized desired tag list, i.e. the reference's tags afterwards.
    total: list[str]
def ensure_tags_exist(
    session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
    """Insert any of the given tag names that do not exist yet.

    Names are normalized with ``normalize_tags`` and de-duplicated; existing
    tags are left untouched (INSERT ... ON CONFLICT DO NOTHING on
    ``Tag.name``). New rows get ``tag_type``. Nothing is executed when no
    name survives normalization.

    The rows are written in batches from ``iter_row_chunks`` so a very large
    tag list cannot exceed the bind-parameter budget; a list that fits in one
    batch produces the same single statement as before.
    """
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    rows = [{"name": n, "tag_type": tag_type} for n in dict.fromkeys(wanted)]
    # Each row binds two parameters: name and tag_type.
    for chunk in iter_row_chunks(rows, cols_per_row=2):
        ins = (
            sqlite.insert(Tag)
            .values(chunk)
            .on_conflict_do_nothing(index_elements=[Tag.name])
        )
        session.execute(ins)
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
    """Return the tag names attached to ``reference_id`` (no explicit ordering)."""
    stmt = select(AssetReferenceTag.tag_name).where(
        AssetReferenceTag.asset_reference_id == reference_id
    )
    return list(session.execute(stmt).scalars().all())
def set_reference_tags(
    session: Session,
    reference_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> SetTagsDict:
    """Make the reference's tag set equal to ``tags``.

    ``tags`` is normalized and de-duplicated (first occurrence wins). Missing
    tags are created with type ``"user"`` and attached with ``origin``; tags
    that are attached but no longer desired are detached.

    Returns:
        ``added`` in desired order, ``removed`` sorted alphabetically (so the
        result is deterministic), and ``total`` — the de-duplicated desired
        list.
    """
    # De-duplicating guards against inserting the same (reference, tag) twice.
    desired = list(dict.fromkeys(normalize_tags(tags)))
    desired_set = set(desired)

    current = set(
        session.execute(
            select(AssetReferenceTag.tag_name).where(
                AssetReferenceTag.asset_reference_id == reference_id
            )
        )
        .scalars()
        .all()
    )

    to_add = [t for t in desired if t not in current]
    to_remove = sorted(current - desired_set)

    if to_add:
        # Tag rows must exist before the association rows that point at them.
        ensure_tags_exist(session, to_add, tag_type="user")
        session.add_all(
            [
                AssetReferenceTag(
                    asset_reference_id=reference_id,
                    tag_name=t,
                    origin=origin,
                    added_at=get_utc_now(),
                )
                for t in to_add
            ]
        )
        session.flush()

    if to_remove:
        session.execute(
            delete(AssetReferenceTag).where(
                AssetReferenceTag.asset_reference_id == reference_id,
                AssetReferenceTag.tag_name.in_(to_remove),
            )
        )
        session.flush()

    return {"added": to_add, "removed": to_remove, "total": desired}
def add_tags_to_reference(
    session: Session,
    reference_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    reference_row: AssetReference | None = None,
) -> AddTagsDict:
    """Attach tags to a reference without disturbing its existing tags.

    Args:
        session: Database session.
        reference_id: Id of the reference to tag.
        tags: Tag names; normalized with ``normalize_tags`` before use.
        origin: Stored on each newly created association row.
        create_if_missing: When True, missing ``Tag`` rows are created (type
            ``"user"``) before the associations are inserted.
        reference_row: When provided (truthy), the existence check for
            ``reference_id`` is skipped.

    Returns:
        Sorted lists: ``added`` (requested tags present now but not before),
        ``already_present`` (requested tags present before) and
        ``total_tags`` (every tag on the reference afterwards).

    Raises:
        ValueError: ``reference_row`` was not given and the reference does
            not exist.
    """
    if not reference_row:
        ref = session.get(AssetReference, reference_id)
        if not ref:
            raise ValueError(f"AssetReference {reference_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        # Nothing to add: just report the current state.
        total = get_reference_tags(session, reference_id=reference_id)
        return {"added": [], "already_present": [], "total_tags": total}

    if create_if_missing:
        ensure_tags_exist(session, norm, tag_type="user")

    # Snapshot of the tags attached before this call.
    current = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetReferenceTag.tag_name).where(
                    AssetReferenceTag.asset_reference_id == reference_id
                )
            )
        ).all()
    }

    want = set(norm)
    to_add = sorted(want - current)

    if to_add:
        # The insert runs in a nested transaction so an IntegrityError only
        # rolls back this batch rather than the caller's transaction.
        with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetReferenceTag(
                            asset_reference_id=reference_id,
                            tag_name=t,
                            origin=origin,
                            added_at=get_utc_now(),
                        )
                        for t in to_add
                    ]
                )
                session.flush()
            except IntegrityError:
                nested.rollback()

    # Re-read the tags so the result reflects what is actually stored, even
    # if the insert above was rolled back.
    after = set(get_reference_tags(session, reference_id=reference_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }
def remove_tags_from_reference(
    session: Session,
    reference_id: str,
    tags: Sequence[str],
) -> RemoveTagsDict:
    """Detach the given tags from a reference.

    Returns:
        ``removed`` and ``not_present`` as sorted lists of the normalized
        request, split by whether the tag was attached, plus ``total_tags``
        read back after the deletion.

    Raises:
        ValueError: the reference does not exist.
    """
    reference = session.get(AssetReference, reference_id)
    if not reference:
        raise ValueError(f"AssetReference {reference_id} not found")

    requested = normalize_tags(tags)
    if not requested:
        remaining = get_reference_tags(session, reference_id=reference_id)
        return {"removed": [], "not_present": [], "total_tags": remaining}

    attached_stmt = sa.select(AssetReferenceTag.tag_name).where(
        AssetReferenceTag.asset_reference_id == reference_id
    )
    attached = set(session.execute(attached_stmt).scalars().all())

    requested_set = set(requested)
    to_remove = sorted(requested_set & attached)
    not_present = sorted(requested_set - attached)

    if to_remove:
        session.execute(
            delete(AssetReferenceTag).where(
                AssetReferenceTag.asset_reference_id == reference_id,
                AssetReferenceTag.tag_name.in_(to_remove),
            )
        )
        session.flush()

    remaining = get_reference_tags(session, reference_id=reference_id)
    return {"removed": to_remove, "not_present": not_present, "total_tags": remaining}
def add_missing_tag_for_asset_id(
    session: Session,
    asset_id: str,
    origin: str = "automatic",
) -> None:
    """Attach the ``"missing"`` tag to every reference of ``asset_id``.

    Done as a single INSERT ... SELECT over references that do not carry the
    tag yet; conflicts on (asset_reference_id, tag_name) are ignored.
    """
    already_flagged = sa.exists().where(
        (AssetReferenceTag.asset_reference_id == AssetReference.id)
        & (AssetReferenceTag.tag_name == "missing")
    )
    untagged_references = (
        sa.select(
            AssetReference.id.label("asset_reference_id"),
            sa.literal("missing").label("tag_name"),
            sa.literal(origin).label("origin"),
            sa.literal(get_utc_now()).label("added_at"),
        )
        .where(AssetReference.asset_id == asset_id)
        .where(sa.not_(already_flagged))
    )
    insert_stmt = (
        sqlite.insert(AssetReferenceTag)
        .from_select(
            ["asset_reference_id", "tag_name", "origin", "added_at"],
            untagged_references,
        )
        .on_conflict_do_nothing(
            index_elements=[
                AssetReferenceTag.asset_reference_id,
                AssetReferenceTag.tag_name,
            ]
        )
    )
    session.execute(insert_stmt)
def remove_missing_tag_for_asset_id(
    session: Session,
    asset_id: str,
) -> None:
    """Detach the ``"missing"`` tag from every reference of ``asset_id``."""
    references_of_asset = sa.select(AssetReference.id).where(
        AssetReference.asset_id == asset_id
    )
    removal = sa.delete(AssetReferenceTag).where(
        AssetReferenceTag.asset_reference_id.in_(references_of_asset),
        AssetReferenceTag.tag_name == "missing",
    )
    session.execute(removal)
def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",
    owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
    """List tags with how many visible references use each of them.

    ``prefix`` (stripped and lower-cased) restricts tag names by LIKE prefix.
    ``include_zero=False`` drops tags with no visible usage. ``order`` is
    ``"name_asc"`` or anything else for count-descending then name.

    Returns:
        ``(rows, total)`` where each row is ``(name, tag_type, count)`` for
        the requested page and ``total`` counts all tags matching the filters.
    """
    visible = build_visible_owner_clause(owner_id)

    # Per-tag usage among references the owner may see.
    usage_sq = (
        select(
            AssetReferenceTag.tag_name.label("tag_name"),
            func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
        )
        .select_from(AssetReferenceTag)
        .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
        .where(visible)
        .group_by(AssetReferenceTag.tag_name)
        .subquery()
    )
    # Tags without any usage row come out of the outer join as NULL -> 0.
    usage = func.coalesce(usage_sq.c.cnt, 0)

    listing = (
        select(Tag.name, Tag.tag_type, usage.label("count"))
        .select_from(Tag)
        .join(usage_sq, usage_sq.c.tag_name == Tag.name, isouter=True)
    )
    counting = select(func.count()).select_from(Tag)

    if prefix:
        pattern, esc = escape_sql_like_string(prefix.strip().lower())
        name_filter = Tag.name.like(pattern + "%", escape=esc)
        listing = listing.where(name_filter)
        counting = counting.where(name_filter)

    if not include_zero:
        listing = listing.where(usage > 0)
        tags_in_use = (
            select(AssetReferenceTag.tag_name)
            .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
            .where(visible)
            .group_by(AssetReferenceTag.tag_name)
        )
        counting = counting.where(Tag.name.in_(tags_in_use))

    if order == "name_asc":
        listing = listing.order_by(Tag.name.asc())
    else:
        listing = listing.order_by(usage.desc(), Tag.name.asc())

    page = session.execute(listing.limit(limit).offset(offset)).all()
    total = session.execute(counting).scalar_one()

    normalized = [(name, tag_type, int(count or 0)) for (name, tag_type, count) in page]
    return normalized, int(total or 0)
def bulk_insert_tags_and_meta(
    session: Session,
    tag_rows: list[dict],
    meta_rows: list[dict],
) -> None:
    """Bulk-write association rows for tags and projected metadata.

    Both inserts use ON CONFLICT DO NOTHING on the table's key columns and
    are split into batches by ``iter_row_chunks``. Tag rows are written
    before meta rows; an empty list skips that table.

    Args:
        session: Database session
        tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
        meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
    """
    if tag_rows:
        tag_insert = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
            index_elements=[
                AssetReferenceTag.asset_reference_id,
                AssetReferenceTag.tag_name,
            ]
        )
        for batch in iter_row_chunks(tag_rows, cols_per_row=4):
            session.execute(tag_insert, batch)

    if not meta_rows:
        return
    meta_insert = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
        index_elements=[
            AssetReferenceMeta.asset_reference_id,
            AssetReferenceMeta.key,
            AssetReferenceMeta.ordinal,
        ]
    )
    for batch in iter_row_chunks(meta_rows, cols_per_row=7):
        session.execute(meta_insert, batch)

View File

@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    """Insert any of the given tag names that do not exist yet.

    Names are normalized with ``normalize_tags`` and de-duplicated; existing
    tags are left untouched (INSERT ... ON CONFLICT DO NOTHING on
    ``Tag.name``). Nothing is executed when no name survives normalization.

    The function returns None on every path, as its annotation declares; it
    no longer hands back the execute result.
    """
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    rows = [{"name": n, "tag_type": tag_type} for n in dict.fromkeys(wanted)]
    ins = (
        sqlite.insert(Tag)
        .values(rows)
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
    session.execute(ins)
def add_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
    origin: str = "automatic",
) -> None:
    """Attach the ``"missing"`` tag to every AssetInfo of ``asset_id``.

    Done as a single INSERT ... SELECT over infos that do not carry the tag
    yet; conflicts on (asset_info_id, tag_name) are ignored.
    """
    already_flagged = sqlalchemy.exists().where(
        (AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")
    )
    untagged_infos = (
        sqlalchemy.select(
            AssetInfo.id.label("asset_info_id"),
            sqlalchemy.literal("missing").label("tag_name"),
            sqlalchemy.literal(origin).label("origin"),
            sqlalchemy.literal(utcnow()).label("added_at"),
        )
        .where(AssetInfo.asset_id == asset_id)
        .where(sqlalchemy.not_(already_flagged))
    )
    insert_stmt = (
        sqlite.insert(AssetInfoTag)
        .from_select(
            ["asset_info_id", "tag_name", "origin", "added_at"],
            untagged_infos,
        )
        .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
    )
    session.execute(insert_stmt)
def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    """Detach the ``"missing"`` tag from every AssetInfo of ``asset_id``."""
    infos_of_asset = sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)
    removal = sqlalchemy.delete(AssetInfoTag).where(
        AssetInfoTag.asset_info_id.in_(infos_of_asset),
        AssetInfoTag.tag_name == "missing",
    )
    session.execute(removal)

View File

@ -228,7 +228,7 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def collect_models_files() -> list[str]:
out: list[str] = []

View File

@ -14,7 +14,7 @@ try:
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
@ -75,6 +75,13 @@ def init_db():
# Check if we need to upgrade
engine = create_engine(db_url)
# Enable foreign key enforcement for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Turn on foreign key enforcement for each new DBAPI connection.

    Registered as a ``connect`` listener on ``engine``. The cursor is closed
    even if executing the PRAGMA raises.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        cursor.close()
conn = engine.connect()
context = MigrationContext.configure(conn)

View File

@ -0,0 +1,28 @@
"""Helper functions for assets integration tests."""
import time
import requests
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
    """Run a blocking seed pass via ``POST /api/assets/seed?wait=true``.

    A 409 response (scan already running) is retried every 0.25 s. Any other
    status ends the loop and must be 200. Once 60 s have elapsed, the next
    409 raises ``TimeoutError``.
    """
    seed_url = base_url + "/api/assets/seed?wait=true"
    payload = {"roots": ["models", "input", "output"]}
    give_up_at = time.monotonic() + 60
    while True:
        response = session.post(seed_url, json=payload, timeout=60)
        if response.status_code == 409:
            if time.monotonic() > give_up_at:
                raise TimeoutError("seed endpoint stuck in 409 (already running)")
            time.sleep(0.25)
            continue
        assert response.status_code == 200, f"seed endpoint returned {response.status_code}: {response.text}"
        return
def get_asset_filename(asset_hash: str, extension: str) -> str:
    """Return *asset_hash* without a leading ``blake3:`` prefix, followed by *extension*."""
    bare_hash = asset_hash.removeprefix("blake3:")
    return bare_hash + extension

View File

@ -0,0 +1,20 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
    """Yield a session bound to a fresh in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the session
    is handed to the test; the session's ``with`` block ends after the test.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(memory_engine)
    db_session = Session(memory_engine)
    with db_session:
        yield db_session
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
    """No-op replacement for the parent autouse fixture of the same name.

    The query tests do not need the server cleanup, so this fixture simply
    yields without doing any setup or teardown.
    """
    yield None

View File

@ -0,0 +1,144 @@
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.helpers import get_utc_now
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
)
class TestAssetExistsByHash:
    """Checks for ``asset_exists_by_hash``."""

    @pytest.mark.parametrize(
        "setup_hash,query_hash,expected",
        [
            # No asset exists at all.
            pytest.param(None, "nonexistent", False, id="nonexistent"),
            # A stored asset is found by its exact hash.
            pytest.param("blake3:abc123", "blake3:abc123", True, id="existing"),
            # A stored asset with a NULL hash must not match the empty string.
            pytest.param(None, "", False, id="null_hash_no_match"),
        ],
    )
    def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
        """Seed zero or one asset, then compare the lookup result to *expected*."""
        # A row is seeded when a hash is given, and also for the empty-string
        # query so that a NULL-hash row is present in that scenario.
        skip_seed = setup_hash is None and query_hash != ""
        if not skip_seed:
            session.add(Asset(hash=setup_hash, size_bytes=100))
            session.commit()

        found = asset_exists_by_hash(session, asset_hash=query_hash)
        assert found is expected
class TestGetAssetByHash:
    """Checks for ``get_asset_by_hash``."""

    @pytest.mark.parametrize(
        "setup_hash,query_hash,should_find",
        [
            pytest.param(None, "nonexistent", False, id="nonexistent"),
            pytest.param("blake3:def456", "blake3:def456", True, id="existing"),
        ],
    )
    def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
        """Optionally store one asset, then look it up by hash."""
        if setup_hash is not None:
            session.add(Asset(hash=setup_hash, size_bytes=200, mime_type="image/png"))
            session.commit()

        fetched = get_asset_by_hash(session, asset_hash=query_hash)

        if not should_find:
            assert fetched is None
            return

        # The returned row must carry the values stored above.
        assert fetched is not None
        assert fetched.size_bytes == 200
        assert fetched.mime_type == "image/png"
class TestUpsertAsset:
    """Checks for the ``(asset, created, updated)`` result of ``upsert_asset``."""

    @pytest.mark.parametrize(
        "first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
        [
            # New asset creation
            pytest.param(
                None, None, 1024, "application/octet-stream",
                True, False, 1024, "application/octet-stream",
                id="new_asset",
            ),
            # Existing asset, same values - no update
            pytest.param(
                500, "text/plain", 500, "text/plain",
                False, False, 500, "text/plain",
                id="existing_no_change",
            ),
            # Existing asset with size 0, update with new values
            pytest.param(
                0, None, 2048, "image/png",
                False, True, 2048, "image/png",
                id="update_from_zero",
            ),
            # Existing asset, second call with size 0 - no update
            pytest.param(
                1000, None, 0, None,
                False, False, 1000, None,
                id="zero_size_no_update",
            ),
        ],
    )
    def test_upsert_scenarios(
        self,
        session: Session,
        first_size,
        first_mime,
        second_size,
        second_mime,
        expect_created,
        expect_updated,
        final_size,
        final_mime,
    ):
        """Run an optional priming upsert, then verify the upsert under test."""
        asset_hash = f"blake3:test_{first_size}_{second_size}"

        # Priming call: only made when the scenario starts from an existing row.
        needs_priming = first_size is not None
        if needs_priming:
            upsert_asset(
                session,
                asset_hash=asset_hash,
                size_bytes=first_size,
                mime_type=first_mime,
            )
            session.commit()

        # The call whose outcome is being checked.
        outcome = upsert_asset(
            session,
            asset_hash=asset_hash,
            size_bytes=second_size,
            mime_type=second_mime,
        )
        asset, created, updated = outcome
        session.commit()

        assert created is expect_created
        assert updated is expect_updated
        assert asset.size_bytes == final_size
        assert asset.mime_type == final_mime
class TestBulkInsertAssets:
    """Checks for ``bulk_insert_assets``."""

    @staticmethod
    def _row(asset_hash, size_bytes, mime_type, created_at):
        """Build one insert row with a freshly generated UUID as its id."""
        return {
            "id": str(uuid.uuid4()),
            "hash": asset_hash,
            "size_bytes": size_bytes,
            "mime_type": mime_type,
            "created_at": created_at,
        }

    def test_inserts_multiple_assets(self, session: Session):
        """Three rows go in and all three hashes can be read back."""
        now = get_utc_now()
        rows = [
            self._row("blake3:bulk1", 100, "text/plain", now),
            self._row("blake3:bulk2", 200, "image/png", now),
            self._row("blake3:bulk3", 300, None, now),
        ]

        bulk_insert_assets(session, rows)
        session.commit()

        stored = session.query(Asset).all()
        assert len(stored) == 3
        stored_hashes = {row.hash for row in stored}
        assert stored_hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}

    def test_empty_list_is_noop(self, session: Session):
        """An empty row list leaves the table empty."""
        bulk_insert_assets(session, [])
        session.commit()
        assert session.query(Asset).count() == 0

    def test_handles_large_batch(self, session: Session):
        """Insert 200 rows in a single call.

        Meant to exercise the chunking logic (presumably 200 rows exceed what
        MAX_BIND_PARAMS allows in one statement - confirm against the query
        module).
        """
        now = get_utc_now()
        rows = []
        for i in range(200):
            rows.append(self._row(f"blake3:large{i}", i, None, now))

        bulk_insert_assets(session, rows)
        session.commit()

        assert session.query(Asset).count() == 200

View File

@ -0,0 +1,517 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
from app.assets.database.queries import (
reference_exists_for_asset_id,
get_reference_by_id,
insert_reference,
get_or_create_reference,
update_reference_timestamps,
list_references_page,
fetch_reference_asset_and_tags,
fetch_reference_and_asset,
update_reference_access_time,
set_reference_metadata,
delete_reference_by_id,
set_reference_preview,
bulk_insert_references_ignore_conflicts,
get_reference_ids_by_ids,
ensure_tags_exist,
add_tags_to_reference,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
    """Add an ``Asset`` with the given hash and size to *session* and flush it.

    The mime type is always ``application/octet-stream``. The flushed (not
    committed) instance is returned.
    """
    new_asset = Asset(
        hash=hash_val,
        size_bytes=size,
        mime_type="application/octet-stream",
    )
    session.add(new_asset)
    session.flush()
    return new_asset
def _make_reference(
    session: Session,
    asset: Asset,
    name: str = "test",
    owner_id: str = "",
) -> AssetReference:
    """Add an ``AssetReference`` pointing at *asset* to *session* and flush it.

    ``created_at``, ``updated_at`` and ``last_access_time`` all receive the
    same ``get_utc_now()`` value. The flushed (not committed) instance is
    returned.
    """
    stamp = get_utc_now()
    timestamps = dict.fromkeys(("created_at", "updated_at", "last_access_time"), stamp)
    new_ref = AssetReference(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        **timestamps,
    )
    session.add(new_ref)
    session.flush()
    return new_ref
class TestReferenceExistsForAssetId:
    """Checks for ``reference_exists_for_asset_id``."""

    def test_returns_false_when_no_reference(self, session: Session):
        """An asset without any reference reports ``False``."""
        lonely_asset = _make_asset(session, "hash1")
        exists = reference_exists_for_asset_id(session, asset_id=lonely_asset.id)
        assert exists is False

    def test_returns_true_when_reference_exists(self, session: Session):
        """An asset with one reference reports ``True``."""
        referenced_asset = _make_asset(session, "hash1")
        _make_reference(session, referenced_asset)
        exists = reference_exists_for_asset_id(session, asset_id=referenced_asset.id)
        assert exists is True
class TestGetReferenceById:
    """Checks for ``get_reference_by_id``."""

    def test_returns_none_for_nonexistent(self, session: Session):
        """An unknown id yields ``None``."""
        missing = get_reference_by_id(session, reference_id="nonexistent")
        assert missing is None

    def test_returns_reference(self, session: Session):
        """A stored reference is returned with its name intact."""
        owner_asset = _make_asset(session, "hash1")
        stored = _make_reference(session, owner_asset, name="myfile.txt")

        fetched = get_reference_by_id(session, reference_id=stored.id)

        assert fetched is not None
        assert fetched.name == "myfile.txt"
class TestListReferencesPage:
    """Checks for ``list_references_page`` and its ``(refs, tag_map, total)`` result."""

    def test_empty_db(self, session: Session):
        """With no rows the page, the tag map and the total are all empty."""
        page, tags_by_ref, count = list_references_page(session)
        assert page == []
        assert tags_by_ref == {}
        assert count == 0

    def test_returns_references_with_tags(self, session: Session):
        """A tagged reference comes back together with its tag names."""
        owner_asset = _make_asset(session, "hash1")
        tagged = _make_reference(session, owner_asset, name="test.bin")
        ensure_tags_exist(session, ["alpha", "beta"])
        add_tags_to_reference(session, reference_id=tagged.id, tags=["alpha", "beta"])
        session.commit()

        page, tags_by_ref, count = list_references_page(session)

        assert len(page) == 1
        assert page[0].id == tagged.id
        assert set(tags_by_ref[tagged.id]) == {"alpha", "beta"}
        assert count == 1

    def test_name_contains_filter(self, session: Session):
        """``name_contains`` keeps only references whose name has the substring."""
        owner_asset = _make_asset(session, "hash1")
        for file_name in ("model_v1.safetensors", "config.json"):
            _make_reference(session, owner_asset, name=file_name)
        session.commit()

        page, _, count = list_references_page(session, name_contains="model")

        assert count == 1
        assert page[0].name == "model_v1.safetensors"

    def test_owner_visibility(self, session: Session):
        """Ownerless rows are public; owned rows are visible to their owner."""
        owner_asset = _make_asset(session, "hash1")
        _make_reference(session, owner_asset, name="public", owner_id="")
        _make_reference(session, owner_asset, name="private", owner_id="user1")
        session.commit()

        # Empty owner sees only public
        page, _, count = list_references_page(session, owner_id="")
        assert count == 1
        assert page[0].name == "public"

        # Owner sees both
        page, _, count = list_references_page(session, owner_id="user1")
        assert count == 2

    def test_include_tags_filter(self, session: Session):
        """``include_tags`` keeps only references carrying the tag."""
        owner_asset = _make_asset(session, "hash1")
        with_tag = _make_reference(session, owner_asset, name="tagged")
        _make_reference(session, owner_asset, name="untagged")
        ensure_tags_exist(session, ["wanted"])
        add_tags_to_reference(session, reference_id=with_tag.id, tags=["wanted"])
        session.commit()

        page, _, count = list_references_page(session, include_tags=["wanted"])

        assert count == 1
        assert page[0].name == "tagged"

    def test_exclude_tags_filter(self, session: Session):
        """``exclude_tags`` drops references carrying the tag."""
        owner_asset = _make_asset(session, "hash1")
        _make_reference(session, owner_asset, name="keep")
        unwanted = _make_reference(session, owner_asset, name="exclude")
        ensure_tags_exist(session, ["bad"])
        add_tags_to_reference(session, reference_id=unwanted.id, tags=["bad"])
        session.commit()

        page, _, count = list_references_page(session, exclude_tags=["bad"])

        assert count == 1
        assert page[0].name == "keep"

    def test_sorting(self, session: Session):
        """Sorting by size descending and by name ascending both put "large" first."""
        small_asset = _make_asset(session, "hash1", size=100)
        large_asset = _make_asset(session, "hash2", size=500)
        _make_reference(session, small_asset, name="small")
        _make_reference(session, large_asset, name="large")
        session.commit()

        by_size, _, _ = list_references_page(session, sort="size", order="desc")
        assert by_size[0].name == "large"

        by_name, _, _ = list_references_page(session, sort="name", order="asc")
        assert by_name[0].name == "large"
class TestFetchReferenceAssetAndTags:
    """Checks for ``fetch_reference_asset_and_tags``."""

    def test_returns_none_for_nonexistent(self, session: Session):
        """An unknown reference id yields ``None``."""
        outcome = fetch_reference_asset_and_tags(session, "nonexistent")
        assert outcome is None

    def test_returns_tuple(self, session: Session):
        """A known id yields the reference, its asset and its tag list."""
        owner_asset = _make_asset(session, "hash1")
        stored = _make_reference(session, owner_asset, name="test.bin")
        ensure_tags_exist(session, ["tag1"])
        add_tags_to_reference(session, reference_id=stored.id, tags=["tag1"])
        session.commit()

        outcome = fetch_reference_asset_and_tags(session, stored.id)

        assert outcome is not None
        got_ref, got_asset, got_tags = outcome
        assert got_ref.id == stored.id
        assert got_asset.id == owner_asset.id
        assert got_tags == ["tag1"]
class TestFetchReferenceAndAsset:
    """Checks for ``fetch_reference_and_asset``."""

    def test_returns_none_for_nonexistent(self, session: Session):
        """An unknown reference id yields ``None``."""
        outcome = fetch_reference_and_asset(session, reference_id="nonexistent")
        assert outcome is None

    def test_returns_tuple(self, session: Session):
        """A known id yields the reference paired with its asset."""
        owner_asset = _make_asset(session, "hash1")
        stored = _make_reference(session, owner_asset)
        session.commit()

        outcome = fetch_reference_and_asset(session, reference_id=stored.id)

        assert outcome is not None
        got_ref, got_asset = outcome
        assert got_ref.id == stored.id
        assert got_asset.id == owner_asset.id
class TestUpdateReferenceAccessTime:
    """Checks for ``update_reference_access_time``."""

    def test_updates_last_access_time(self, session: Session):
        """The stored ``last_access_time`` moves forward after the update."""
        asset = _make_asset(session, "hash1")
        ref = _make_reference(session, asset)
        original_time = ref.last_access_time
        session.commit()

        # Short pause so the new timestamp is strictly later than the old one.
        # ``time`` comes from the module-level import; no local import needed.
        time.sleep(0.01)

        update_reference_access_time(session, reference_id=ref.id)
        session.commit()

        session.refresh(ref)
        assert ref.last_access_time > original_time
class TestDeleteReferenceById:
    """Checks for ``delete_reference_by_id``."""

    def test_deletes_existing(self, session: Session):
        """Deleting a stored reference returns ``True`` and removes the row."""
        owner_asset = _make_asset(session, "hash1")
        doomed = _make_reference(session, owner_asset)
        session.commit()

        was_deleted = delete_reference_by_id(session, reference_id=doomed.id, owner_id="")

        assert was_deleted is True
        assert get_reference_by_id(session, reference_id=doomed.id) is None

    def test_returns_false_for_nonexistent(self, session: Session):
        """Deleting an unknown id returns ``False``."""
        was_deleted = delete_reference_by_id(
            session, reference_id="nonexistent", owner_id=""
        )
        assert was_deleted is False

    def test_respects_owner_visibility(self, session: Session):
        """A reference owned by ``user1`` cannot be deleted as ``user2``."""
        owner_asset = _make_asset(session, "hash1")
        protected = _make_reference(session, owner_asset, owner_id="user1")
        session.commit()

        was_deleted = delete_reference_by_id(
            session, reference_id=protected.id, owner_id="user2"
        )

        assert was_deleted is False
        assert get_reference_by_id(session, reference_id=protected.id) is not None
class TestSetReferencePreview:
    """Checks for ``set_reference_preview``."""

    def test_sets_preview(self, session: Session):
        """Passing a preview asset id stores it on the reference."""
        main_asset = _make_asset(session, "hash1")
        thumb_asset = _make_asset(session, "preview_hash")
        target = _make_reference(session, main_asset)
        session.commit()

        set_reference_preview(
            session, reference_id=target.id, preview_asset_id=thumb_asset.id
        )
        session.commit()

        session.refresh(target)
        assert target.preview_id == thumb_asset.id

    def test_clears_preview(self, session: Session):
        """Passing ``None`` removes a previously stored preview id."""
        main_asset = _make_asset(session, "hash1")
        thumb_asset = _make_asset(session, "preview_hash")
        target = _make_reference(session, main_asset)
        target.preview_id = thumb_asset.id
        session.commit()

        set_reference_preview(session, reference_id=target.id, preview_asset_id=None)
        session.commit()

        session.refresh(target)
        assert target.preview_id is None

    def test_raises_for_nonexistent_reference(self, session: Session):
        """An unknown reference id raises ``ValueError`` mentioning "not found"."""
        with pytest.raises(ValueError, match="not found"):
            set_reference_preview(
                session, reference_id="nonexistent", preview_asset_id=None
            )

    def test_raises_for_nonexistent_preview(self, session: Session):
        """An unknown preview asset id raises ``ValueError`` mentioning "Preview Asset"."""
        main_asset = _make_asset(session, "hash1")
        target = _make_reference(session, main_asset)
        session.commit()

        with pytest.raises(ValueError, match="Preview Asset"):
            set_reference_preview(
                session, reference_id=target.id, preview_asset_id="nonexistent"
            )
class TestInsertReference:
    """Checks for ``insert_reference``."""

    def test_creates_new_reference(self, session: Session):
        """The inserted reference carries the requested name and owner."""
        owner_asset = _make_asset(session, "hash1")

        inserted = insert_reference(
            session, asset_id=owner_asset.id, owner_id="user1", name="test.bin"
        )
        session.commit()

        assert inserted is not None
        assert inserted.name == "test.bin"
        assert inserted.owner_id == "user1"

    def test_allows_duplicate_names(self, session: Session):
        """Two inserts with the same asset, owner and name give two distinct rows."""
        owner_asset = _make_asset(session, "hash1")
        shared_kwargs = {"asset_id": owner_asset.id, "owner_id": "user1", "name": "dup.bin"}

        first = insert_reference(session, **shared_kwargs)
        session.commit()

        # Duplicate names are now allowed
        second = insert_reference(session, **shared_kwargs)
        session.commit()

        assert first is not None
        assert second is not None
        assert first.id != second.id
class TestGetOrCreateReference:
    """Checks for the ``(reference, created)`` result of ``get_or_create_reference``."""

    def test_creates_new_reference(self, session: Session):
        """The first call reports ``created`` and returns the named reference."""
        owner_asset = _make_asset(session, "hash1")

        new_ref, was_created = get_or_create_reference(
            session, asset_id=owner_asset.id, owner_id="user1", name="new.bin"
        )
        session.commit()

        assert was_created is True
        assert new_ref.name == "new.bin"

    def test_always_creates_new_reference(self, session: Session):
        """A repeated call with identical arguments still creates a new row."""
        owner_asset = _make_asset(session, "hash1")
        shared_kwargs = {
            "asset_id": owner_asset.id,
            "owner_id": "user1",
            "name": "existing.bin",
        }

        first_ref, first_created = get_or_create_reference(session, **shared_kwargs)
        session.commit()

        # Duplicate names are allowed, so always creates new
        second_ref, second_created = get_or_create_reference(session, **shared_kwargs)
        session.commit()

        assert first_created is True
        assert second_created is True
        assert first_ref.id != second_ref.id
class TestUpdateReferenceTimestamps:
    """Checks for ``update_reference_timestamps``."""

    def test_updates_timestamps(self, session: Session):
        """``updated_at`` moves forward after the call."""
        owner_asset = _make_asset(session, "hash1")
        target = _make_reference(session, owner_asset)
        before = target.updated_at
        session.commit()

        # Short pause so the new timestamp is strictly later than the old one.
        time.sleep(0.01)
        update_reference_timestamps(session, target)
        session.commit()

        session.refresh(target)
        assert target.updated_at > before

    def test_updates_preview_id(self, session: Session):
        """A ``preview_id`` passed to the call is stored on the reference."""
        owner_asset = _make_asset(session, "hash1")
        thumb_asset = _make_asset(session, "preview_hash")
        target = _make_reference(session, owner_asset)
        session.commit()

        update_reference_timestamps(session, target, preview_id=thumb_asset.id)
        session.commit()

        session.refresh(target)
        assert target.preview_id == thumb_asset.id
class TestSetReferenceMetadata:
    """Checks for ``set_reference_metadata`` and the rows it writes."""

    @staticmethod
    def _meta_rows(session: Session, reference_id):
        """Return every ``AssetReferenceMeta`` row stored for *reference_id*."""
        query = session.query(AssetReferenceMeta).filter_by(asset_reference_id=reference_id)
        return query.all()

    def test_sets_metadata(self, session: Session):
        """Metadata lands on the reference and in the metadata table."""
        owner_asset = _make_asset(session, "hash1")
        target = _make_reference(session, owner_asset)
        session.commit()

        set_reference_metadata(
            session, reference_id=target.id, user_metadata={"key": "value"}
        )
        session.commit()

        session.refresh(target)
        assert target.user_metadata == {"key": "value"}

        # Check metadata table
        rows = self._meta_rows(session, target.id)
        assert len(rows) == 1
        assert rows[0].key == "key"
        assert rows[0].val_str == "value"

    def test_replaces_existing_metadata(self, session: Session):
        """A second call replaces the rows written by the first."""
        owner_asset = _make_asset(session, "hash1")
        target = _make_reference(session, owner_asset)
        session.commit()

        for payload in ({"old": "data"}, {"new": "data"}):
            set_reference_metadata(
                session, reference_id=target.id, user_metadata=payload
            )
            session.commit()

        rows = self._meta_rows(session, target.id)
        assert len(rows) == 1
        assert rows[0].key == "new"

    def test_clears_metadata_with_empty_dict(self, session: Session):
        """An empty dict removes all metadata rows for the reference."""
        owner_asset = _make_asset(session, "hash1")
        target = _make_reference(session, owner_asset)
        session.commit()

        set_reference_metadata(
            session, reference_id=target.id, user_metadata={"key": "value"}
        )
        session.commit()

        set_reference_metadata(session, reference_id=target.id, user_metadata={})
        session.commit()

        session.refresh(target)
        assert target.user_metadata == {}
        rows = self._meta_rows(session, target.id)
        assert len(rows) == 0

    def test_raises_for_nonexistent(self, session: Session):
        """An unknown reference id raises ``ValueError`` mentioning "not found"."""
        with pytest.raises(ValueError, match="not found"):
            set_reference_metadata(
                session, reference_id="nonexistent", user_metadata={"key": "value"}
            )
class TestBulkInsertReferencesIgnoreConflicts:
    """Checks for ``bulk_insert_references_ignore_conflicts`` with name-only rows."""

    @staticmethod
    def _row(name, asset_id, stamp):
        """Build one ownerless insert row with a fresh UUID and empty metadata."""
        return {
            "id": str(uuid.uuid4()),
            "owner_id": "",
            "name": name,
            "asset_id": asset_id,
            "preview_id": None,
            "user_metadata": {},
            "created_at": stamp,
            "updated_at": stamp,
            "last_access_time": stamp,
        }

    def test_inserts_multiple_references(self, session: Session):
        """Two rows go in and two references can be read back."""
        owner_asset = _make_asset(session, "hash1")
        now = get_utc_now()
        rows = [self._row(name, owner_asset.id, now) for name in ("bulk1.bin", "bulk2.bin")]

        bulk_insert_references_ignore_conflicts(session, rows)
        session.commit()

        stored = session.query(AssetReference).all()
        assert len(stored) == 2

    def test_allows_duplicate_names(self, session: Session):
        """A row reusing an existing name is inserted alongside the original."""
        owner_asset = _make_asset(session, "hash1")
        _make_reference(session, owner_asset, name="existing.bin", owner_id="")
        session.commit()

        now = get_utc_now()
        rows = [
            self._row("existing.bin", owner_asset.id, now),
            self._row("new.bin", owner_asset.id, now),
        ]
        bulk_insert_references_ignore_conflicts(session, rows)
        session.commit()

        # Duplicate names allowed, so all 3 rows exist
        stored = session.query(AssetReference).all()
        assert len(stored) == 3

    def test_empty_list_is_noop(self, session: Session):
        """An empty row list leaves the table empty."""
        bulk_insert_references_ignore_conflicts(session, [])
        assert session.query(AssetReference).count() == 0
class TestGetReferenceIdsByIds:
    """Checks for ``get_reference_ids_by_ids``."""

    def test_returns_existing_ids(self, session: Session):
        """Only ids that exist in the table are returned."""
        owner_asset = _make_asset(session, "hash1")
        first = _make_reference(session, owner_asset, name="a.bin")
        second = _make_reference(session, owner_asset, name="b.bin")
        session.commit()

        candidates = [first.id, second.id, "nonexistent"]
        present = get_reference_ids_by_ids(session, candidates)

        assert present == {first.id, second.id}

    def test_empty_list_returns_empty(self, session: Session):
        """An empty id list yields an empty set."""
        present = get_reference_ids_by_ids(session, [])
        assert present == set()

View File

@ -0,0 +1,499 @@
"""Tests for cache_state (AssetReference file path) query functions."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import (
list_references_by_asset_id,
upsert_reference,
get_unreferenced_unhashed_asset_ids,
delete_assets_by_ids,
get_references_for_prefixes,
bulk_update_needs_verify,
delete_references_by_ids,
delete_orphaned_seed_asset,
bulk_insert_references_ignore_conflicts,
get_references_by_paths_and_asset_ids,
mark_references_missing_outside_prefixes,
restore_references_by_paths,
)
from app.assets.helpers import select_best_live_path, get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
    """Add an ``Asset`` with the given hash and size to *session* and flush it.

    The flushed (not committed) instance is returned.
    """
    new_asset = Asset(size_bytes=size, hash=hash_val)
    session.add(new_asset)
    session.flush()
    return new_asset
def _make_reference(
    session: Session,
    asset: Asset,
    file_path: str,
    name: str = "test",
    mtime_ns: int | None = None,
    needs_verify: bool = False,
) -> AssetReference:
    """Add a file-backed ``AssetReference`` for *asset* to *session* and flush it.

    ``created_at``, ``updated_at`` and ``last_access_time`` all receive the
    same ``get_utc_now()`` value. The flushed (not committed) instance is
    returned.
    """
    stamp = get_utc_now()
    timestamps = dict.fromkeys(("created_at", "updated_at", "last_access_time"), stamp)
    new_ref = AssetReference(
        asset_id=asset.id,
        file_path=file_path,
        name=name,
        mtime_ns=mtime_ns,
        needs_verify=needs_verify,
        **timestamps,
    )
    session.add(new_ref)
    session.flush()
    return new_ref
class TestListReferencesByAssetId:
    """Checks for ``list_references_by_asset_id``."""

    def test_returns_empty_for_no_references(self, session: Session):
        """An asset without references yields nothing."""
        lonely_asset = _make_asset(session, "hash1")
        found = list_references_by_asset_id(session, asset_id=lonely_asset.id)
        assert list(found) == []

    def test_returns_references_for_asset(self, session: Session):
        """Both references of an asset are returned."""
        shared_asset = _make_asset(session, "hash1")
        _make_reference(session, shared_asset, "/path/a.bin", name="a")
        _make_reference(session, shared_asset, "/path/b.bin", name="b")
        session.commit()

        found = list_references_by_asset_id(session, asset_id=shared_asset.id)

        found_paths = [ref.file_path for ref in found]
        assert set(found_paths) == {"/path/a.bin", "/path/b.bin"}

    def test_does_not_return_other_assets_references(self, session: Session):
        """References of a different asset are left out."""
        wanted_asset = _make_asset(session, "hash1")
        other_asset = _make_asset(session, "hash2")
        _make_reference(session, wanted_asset, "/path/asset1.bin", name="a1")
        _make_reference(session, other_asset, "/path/asset2.bin", name="a2")
        session.commit()

        found = list_references_by_asset_id(session, asset_id=wanted_asset.id)

        found_paths = [ref.file_path for ref in found]
        assert found_paths == ["/path/asset1.bin"]
class TestSelectBestLivePath:
    """Checks for ``select_best_live_path`` using real files under ``tmp_path``."""

    def test_returns_empty_for_empty_list(self):
        """No references means an empty path."""
        assert select_best_live_path([]) == ""

    def test_returns_empty_when_no_files_exist(self, session: Session):
        """A reference whose file does not exist yields an empty path."""
        owner_asset = _make_asset(session, "hash1")
        dangling = _make_reference(session, owner_asset, "/nonexistent/path.bin")
        session.commit()

        chosen = select_best_live_path([dangling])

        assert chosen == ""

    def test_prefers_verified_path(self, session: Session, tmp_path):
        """needs_verify=False should be preferred."""
        owner_asset = _make_asset(session, "hash1")

        verified_file = tmp_path / "verified.bin"
        verified_file.write_bytes(b"data")
        unverified_file = tmp_path / "unverified.bin"
        unverified_file.write_bytes(b"data")

        trusted = _make_reference(
            session, owner_asset, str(verified_file), name="verified", needs_verify=False
        )
        untrusted = _make_reference(
            session, owner_asset, str(unverified_file), name="unverified", needs_verify=True
        )
        session.commit()

        # The unverified reference is listed first on purpose.
        chosen = select_best_live_path([untrusted, trusted])

        assert chosen == str(verified_file)

    def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
        """If all references need verification, return first existing path."""
        owner_asset = _make_asset(session, "hash1")
        existing_file = tmp_path / "exists.bin"
        existing_file.write_bytes(b"data")
        only_ref = _make_reference(
            session, owner_asset, str(existing_file), needs_verify=True
        )
        session.commit()

        chosen = select_best_live_path([only_ref])

        assert chosen == str(existing_file)
class TestSelectBestLivePathWithMocking:
    """Checks for ``select_best_live_path`` using a stand-in reference object."""

    def test_handles_missing_file_path_attr(self):
        """Gracefully handle references with None file_path."""

        class PathlessRef:
            # Stand-in exposing only the two attributes set here.
            file_path = None
            needs_verify = False

        candidates = [PathlessRef()]
        chosen = select_best_live_path(candidates)
        assert chosen == ""
class TestUpsertReference:
    """Checks for the ``(created, updated)`` result of ``upsert_reference``."""

    @pytest.mark.parametrize(
        "initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
        [
            # New reference creation
            pytest.param(None, 12345, True, False, 12345, id="new_reference"),
            # Existing reference, same mtime - no update
            pytest.param(100, 100, False, False, 100, id="existing_no_change"),
            # Existing reference, different mtime - update
            pytest.param(100, 200, False, True, 200, id="existing_update_mtime"),
        ],
    )
    def test_upsert_scenarios(
        self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
    ):
        """Run an optional priming upsert, then verify the upsert under test."""
        owner_asset = _make_asset(session, "hash1")
        file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
        name = f"file_{initial_mtime}_{second_mtime}"

        # Create initial reference if needed
        starts_with_row = initial_mtime is not None
        if starts_with_row:
            upsert_reference(
                session,
                asset_id=owner_asset.id,
                file_path=file_path,
                name=name,
                mtime_ns=initial_mtime,
            )
            session.commit()

        # The upsert call we're testing
        outcome = upsert_reference(
            session,
            asset_id=owner_asset.id,
            file_path=file_path,
            name=name,
            mtime_ns=second_mtime,
        )
        created, updated = outcome
        session.commit()

        assert created is expect_created
        assert updated is expect_updated
        stored = session.query(AssetReference).filter_by(file_path=file_path).one()
        assert stored.mtime_ns == final_mtime

    def test_upsert_restores_missing_reference(self, session: Session):
        """Upserting a reference that was marked missing should restore it."""
        owner_asset = _make_asset(session, "hash1")
        file_path = "/restored/file.bin"
        flagged = _make_reference(session, owner_asset, file_path, mtime_ns=100)
        flagged.is_missing = True
        session.commit()

        created, updated = upsert_reference(
            session,
            asset_id=owner_asset.id,
            file_path=file_path,
            name="restored",
            mtime_ns=100,
        )
        session.commit()

        assert created is False
        assert updated is True
        stored = session.query(AssetReference).filter_by(file_path=file_path).one()
        assert stored.is_missing is False
class TestRestoreReferencesByPaths:
    """Checks for ``restore_references_by_paths``."""

    def test_restores_missing_references(self, session: Session):
        """A reference flagged missing is restored and counted once."""
        owner_asset = _make_asset(session, "hash1")
        missing_path = "/missing/file.bin"
        active_path = "/active/file.bin"
        flagged = _make_reference(session, owner_asset, missing_path, name="missing")
        flagged.is_missing = True
        _make_reference(session, owner_asset, active_path, name="active")
        session.commit()

        restored_count = restore_references_by_paths(session, [missing_path])
        session.commit()

        assert restored_count == 1
        stored = session.query(AssetReference).filter_by(file_path=missing_path).one()
        assert stored.is_missing is False

    def test_empty_list_restores_nothing(self, session: Session):
        """An empty path list restores zero references."""
        restored_count = restore_references_by_paths(session, [])
        assert restored_count == 0
class TestMarkReferencesMissingOutsidePrefixes:
    """Checks for ``mark_references_missing_outside_prefixes``."""

    def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path):
        """Only the reference outside the allowed directory is flagged missing."""
        owner_asset = _make_asset(session, "hash1")

        valid_dir = tmp_path / "valid"
        valid_dir.mkdir()
        invalid_dir = tmp_path / "invalid"
        invalid_dir.mkdir()

        valid_path = str(valid_dir / "file.bin")
        invalid_path = str(invalid_dir / "file.bin")
        _make_reference(session, owner_asset, valid_path, name="valid")
        _make_reference(session, owner_asset, invalid_path, name="invalid")
        session.commit()

        marked_count = mark_references_missing_outside_prefixes(session, [str(valid_dir)])
        session.commit()

        assert marked_count == 1
        stored = session.query(AssetReference).all()
        assert len(stored) == 2
        by_path = {ref.file_path: ref for ref in stored}
        assert by_path[valid_path].is_missing is False
        assert by_path[invalid_path].is_missing is True

    def test_empty_prefixes_marks_nothing(self, session: Session):
        """An empty prefix list flags zero references."""
        owner_asset = _make_asset(session, "hash1")
        _make_reference(session, owner_asset, "/some/path.bin")
        session.commit()

        marked_count = mark_references_missing_outside_prefixes(session, [])

        assert marked_count == 0
class TestGetUnreferencedUnhashedAssetIds:
    """Checks for ``get_unreferenced_unhashed_asset_ids``."""

    def test_returns_unreferenced_unhashed_assets(self, session: Session):
        """Unhashed assets with no reference, or only a missing one, are reported."""
        # Unhashed asset (hash=None) with no references (no file_path)
        bare_asset = _make_asset(session, hash_val=None)

        # Unhashed asset with active reference (not unreferenced)
        live_asset = _make_asset(session, hash_val=None)
        _make_reference(session, live_asset, "/has/ref.bin", name="has_ref")

        # Unhashed asset with only missing reference (should be unreferenced)
        stale_asset = _make_asset(session, hash_val=None)
        stale_ref = _make_reference(
            session, stale_asset, "/missing/ref.bin", name="missing_ref"
        )
        stale_ref.is_missing = True

        # Regular asset (hash not None) - should not be returned
        _make_asset(session, hash_val="blake3:regular")
        session.commit()

        reported = get_unreferenced_unhashed_asset_ids(session)

        assert bare_asset.id in reported
        assert stale_asset.id in reported
        assert live_asset.id not in reported
class TestDeleteAssetsByIds:
    """Checks for ``delete_assets_by_ids``."""

    def test_deletes_assets_and_references(self, session: Session):
        """Deleting an asset leaves neither the asset nor its reference behind."""
        doomed = _make_asset(session, "hash1")
        _make_reference(session, doomed, "/test/path.bin", name="test")
        session.commit()

        removed = delete_assets_by_ids(session, [doomed.id])
        session.commit()

        assert removed == 1
        assert session.query(Asset).count() == 0
        assert session.query(AssetReference).count() == 0

    def test_empty_list_deletes_nothing(self, session: Session):
        """An empty id list deletes zero assets."""
        _make_asset(session, "hash1")
        session.commit()

        removed = delete_assets_by_ids(session, [])

        assert removed == 0
        assert session.query(Asset).count() == 1
class TestGetReferencesForPrefixes:
    """Checks for ``get_references_for_prefixes``."""

    def test_returns_references_matching_prefix(self, session: Session, tmp_path):
        """Only the reference under the requested directory is returned."""
        owner_asset = _make_asset(session, "hash1")

        wanted_dir = tmp_path / "dir1"
        wanted_dir.mkdir()
        other_dir = tmp_path / "dir2"
        other_dir.mkdir()

        wanted_path = str(wanted_dir / "file.bin")
        other_path = str(other_dir / "file.bin")
        _make_reference(session, owner_asset, wanted_path, name="file1", mtime_ns=100)
        _make_reference(session, owner_asset, other_path, name="file2", mtime_ns=200)
        session.commit()

        matches = get_references_for_prefixes(session, [str(wanted_dir)])

        assert len(matches) == 1
        assert matches[0].file_path == wanted_path

    def test_empty_prefixes_returns_empty(self, session: Session):
        """An empty prefix list yields an empty list."""
        owner_asset = _make_asset(session, "hash1")
        _make_reference(session, owner_asset, "/some/path.bin")
        session.commit()

        matches = get_references_for_prefixes(session, [])

        assert matches == []
class TestBulkSetNeedsVerify:
    """Checks for ``bulk_update_needs_verify``."""

    def test_sets_needs_verify_flag(self, session: Session):
        """Both references are counted and end up with ``needs_verify`` set."""
        owner_asset = _make_asset(session, "hash1")
        first = _make_reference(session, owner_asset, "/path1.bin", needs_verify=False)
        second = _make_reference(session, owner_asset, "/path2.bin", needs_verify=False)
        session.commit()

        changed = bulk_update_needs_verify(session, [first.id, second.id], True)
        session.commit()

        assert changed == 2
        session.refresh(first)
        session.refresh(second)
        assert first.needs_verify is True
        assert second.needs_verify is True

    def test_empty_list_updates_nothing(self, session: Session):
        """An empty id list updates zero references."""
        changed = bulk_update_needs_verify(session, [], True)
        assert changed == 0
class TestDeleteReferencesByIds:
    """Checks for ``delete_references_by_ids``."""

    def test_deletes_references_by_id(self, session: Session):
        """Only the reference whose id was passed is removed."""
        owner_asset = _make_asset(session, "hash1")
        doomed = _make_reference(session, owner_asset, "/path1.bin")
        _make_reference(session, owner_asset, "/path2.bin")
        session.commit()

        removed = delete_references_by_ids(session, [doomed.id])
        session.commit()

        assert removed == 1
        assert session.query(AssetReference).count() == 1

    def test_empty_list_deletes_nothing(self, session: Session):
        """An empty id list deletes zero references."""
        removed = delete_references_by_ids(session, [])
        assert removed == 0
class TestDeleteOrphanedSeedAsset:
    """Checks for ``delete_orphaned_seed_asset``."""

    @pytest.mark.parametrize(
        "create_asset,expected_deleted,expected_count",
        [
            # Existing asset gets deleted
            pytest.param(True, True, 0, id="deletes_existing"),
            # Nonexistent returns False
            pytest.param(False, False, 0, id="nonexistent_returns_false"),
        ],
    )
    def test_delete_orphaned_seed_asset(
        self, session: Session, create_asset, expected_deleted, expected_count
    ):
        """Delete either a freshly created unhashed asset or an unknown id."""
        if create_asset:
            seeded = _make_asset(session, hash_val=None)
            asset_id = seeded.id
            _make_reference(session, seeded, "/test/path.bin", name="test")
            session.commit()
        else:
            asset_id = "nonexistent-id"

        was_deleted = delete_orphaned_seed_asset(session, asset_id)
        if create_asset:
            session.commit()

        assert was_deleted is expected_deleted
        assert session.query(Asset).count() == expected_count
class TestBulkInsertReferencesIgnoreConflicts:
    """Checks for ``bulk_insert_references_ignore_conflicts`` with file-path rows."""

    @staticmethod
    def _row(asset_id, file_path, name, mtime_ns, stamp):
        """Build one insert row; all three timestamps share *stamp*."""
        return {
            "asset_id": asset_id,
            "file_path": file_path,
            "name": name,
            "mtime_ns": mtime_ns,
            "created_at": stamp,
            "updated_at": stamp,
            "last_access_time": stamp,
        }

    def test_inserts_multiple_references(self, session: Session):
        """Two rows with distinct paths are both stored."""
        owner_asset = _make_asset(session, "hash1")
        now = get_utc_now()
        rows = [
            self._row(owner_asset.id, "/bulk1.bin", "bulk1", 100, now),
            self._row(owner_asset.id, "/bulk2.bin", "bulk2", 200, now),
        ]

        bulk_insert_references_ignore_conflicts(session, rows)
        session.commit()

        assert session.query(AssetReference).count() == 2

    def test_ignores_conflicts(self, session: Session):
        """A row reusing an existing file path is skipped and the original kept."""
        owner_asset = _make_asset(session, "hash1")
        _make_reference(session, owner_asset, "/existing.bin", mtime_ns=100)
        session.commit()

        now = get_utc_now()
        rows = [
            self._row(owner_asset.id, "/existing.bin", "existing", 999, now),
            self._row(owner_asset.id, "/new.bin", "new", 200, now),
        ]
        bulk_insert_references_ignore_conflicts(session, rows)
        session.commit()

        assert session.query(AssetReference).count() == 2
        survivor = session.query(AssetReference).filter_by(file_path="/existing.bin").one()
        assert survivor.mtime_ns == 100  # Original value preserved

    def test_empty_list_is_noop(self, session: Session):
        """An empty row list leaves the table empty."""
        bulk_insert_references_ignore_conflicts(session, [])
        assert session.query(AssetReference).count() == 0
class TestGetReferencesByPathsAndAssetIds:
    """Tests for ``get_references_by_paths_and_asset_ids``."""

    def test_returns_matching_paths(self, session: Session):
        """Paths stored under the requested asset ids are all returned."""
        first = _make_asset(session, "hash1")
        second = _make_asset(session, "hash2")
        _make_reference(session, first, "/path1.bin")
        _make_reference(session, second, "/path2.bin")
        session.commit()

        wanted = {"/path1.bin": first.id, "/path2.bin": second.id}
        found = get_references_by_paths_and_asset_ids(session, wanted)

        assert found == {"/path1.bin", "/path2.bin"}

    def test_excludes_non_matching_asset_ids(self, session: Session):
        """A stored path paired with a different asset id is not returned."""
        owner = _make_asset(session, "hash1")
        other = _make_asset(session, "hash2")
        _make_reference(session, owner, "/path1.bin")
        session.commit()

        # Path exists but with different asset_id
        wanted = {"/path1.bin": other.id}
        found = get_references_by_paths_and_asset_ids(session, wanted)

        assert found == set()

    def test_empty_dict_returns_empty(self, session: Session):
        """An empty mapping is expected to yield an empty set."""
        assert get_references_by_paths_and_asset_ids(session, {}) == set()

View File

@ -0,0 +1,184 @@
"""Tests for metadata filtering logic in asset_reference queries."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
from app.assets.database.queries import list_references_page
from app.assets.database.queries.asset_reference import convert_metadata_to_rows
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str) -> Asset:
    """Add a 1024-byte ``Asset`` with the given hash, flush, and return it."""
    created = Asset(hash=hash_val, size_bytes=1024)
    session.add(created)
    session.flush()
    return created
def _make_reference(
    session: Session,
    asset: Asset,
    name: str,
    metadata: dict | None = None,
) -> AssetReference:
    """Create a flushed ``AssetReference`` plus one meta row per metadata value.

    The reference is created with an empty ``owner_id`` and the same
    timestamp for all three time fields. When ``metadata`` is truthy, every
    row produced by ``convert_metadata_to_rows`` is stored as an
    ``AssetReferenceMeta`` linked to the new reference.
    """
    stamp = get_utc_now()
    reference = AssetReference(
        owner_id="",
        name=name,
        asset_id=asset.id,
        user_metadata=metadata,
        created_at=stamp,
        updated_at=stamp,
        last_access_time=stamp,
    )
    session.add(reference)
    # Flush first so ``reference.id`` is available for the meta rows below.
    session.flush()

    if not metadata:
        return reference

    for meta_key, meta_value in metadata.items():
        for converted in convert_metadata_to_rows(meta_key, meta_value):
            session.add(
                AssetReferenceMeta(
                    asset_reference_id=reference.id,
                    key=converted["key"],
                    ordinal=converted.get("ordinal", 0),
                    val_str=converted.get("val_str"),
                    val_num=converted.get("val_num"),
                    val_bool=converted.get("val_bool"),
                    val_json=converted.get("val_json"),
                )
            )
    session.flush()
    return reference
class TestMetadataFilterByType:
    """Table-driven checks of metadata filtering for each scalar value type."""

    @pytest.mark.parametrize(
        "match_meta,nomatch_meta,filter_key,filter_val",
        [
            # String matching
            pytest.param(
                {"category": "models"}, {"category": "images"}, "category", "models",
                id="string",
            ),
            # Integer matching
            pytest.param({"epoch": 5}, {"epoch": 10}, "epoch", 5, id="int"),
            # Float matching
            pytest.param({"score": 0.95}, {"score": 0.5}, "score", 0.95, id="float"),
            # Boolean True matching
            pytest.param(
                {"enabled": True}, {"enabled": False}, "enabled", True,
                id="bool_true",
            ),
            # Boolean False matching
            pytest.param(
                {"enabled": False}, {"enabled": True}, "enabled", False,
                id="bool_false",
            ),
        ],
    )
    def test_filter_matches_correct_value(
        self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
    ):
        """Only the reference whose metadata equals the filter value is listed."""
        asset = _make_asset(session, "hash1")
        _make_reference(session, asset, "match", match_meta)
        _make_reference(session, asset, "nomatch", nomatch_meta)
        session.commit()

        wanted = {filter_key: filter_val}
        refs, _, total = list_references_page(session, metadata_filter=wanted)

        assert total == 1
        assert refs[0].name == "match"

    @pytest.mark.parametrize(
        "stored_meta,filter_key,filter_val",
        [
            # String no match
            pytest.param(
                {"category": "models"}, "category", "other", id="string_no_match"
            ),
            # Int no match
            pytest.param({"epoch": 5}, "epoch", 99, id="int_no_match"),
            # Float no match
            pytest.param({"score": 0.5}, "score", 0.99, id="float_no_match"),
        ],
    )
    def test_filter_returns_empty_when_no_match(
        self, session: Session, stored_meta, filter_key, filter_val
    ):
        """A filter value that differs from the stored one gives a zero total."""
        asset = _make_asset(session, "hash1")
        _make_reference(session, asset, "item", stored_meta)
        session.commit()

        wanted = {filter_key: filter_val}
        _, _, total = list_references_page(session, metadata_filter=wanted)

        assert total == 0
class TestMetadataFilterNull:
    """Checks for filtering on a ``None`` value (missing or explicit null)."""

    @pytest.mark.parametrize(
        "match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
        [
            # Null matches missing key
            pytest.param(
                "missing_key", {}, "has_key", {"optional": "value"}, "optional",
                id="missing_key",
            ),
            # Null matches explicit null
            pytest.param(
                "explicit_null", {"nullable": None},
                "has_value", {"nullable": "present"},
                "nullable",
                id="explicit_null",
            ),
        ],
    )
    def test_null_filter_matches(
        self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
    ):
        """A ``None`` filter selects only the reference without a real value."""
        asset = _make_asset(session, "hash1")
        _make_reference(session, asset, match_name, match_meta)
        _make_reference(session, asset, nomatch_name, nomatch_meta)
        session.commit()

        null_filter = {filter_key: None}
        refs, _, total = list_references_page(session, metadata_filter=null_filter)

        assert total == 1
        assert refs[0].name == match_name
class TestMetadataFilterList:
    """Checks for list-valued filters, which act as an OR over the values."""

    def test_filter_by_list_matches_any(self, session: Session):
        """List values should match ANY of the values (OR)."""
        asset = _make_asset(session, "hash1")
        fixtures = (("cat_a", "a"), ("cat_b", "b"), ("cat_c", "c"))
        for ref_name, category in fixtures:
            _make_reference(session, asset, ref_name, {"category": category})
        session.commit()

        either = {"category": ["a", "b"]}
        refs, _, total = list_references_page(session, metadata_filter=either)

        assert total == 2
        listed_names = {ref.name for ref in refs}
        assert listed_names == {"cat_a", "cat_b"}
class TestMetadataFilterMultipleKeys:
    """Checks that several filter keys are combined with AND semantics."""

    def test_multiple_keys_must_all_match(self, session: Session):
        """Multiple keys should ALL match (AND)."""
        asset = _make_asset(session, "hash1")
        fixtures = (
            ("match", {"type": "model", "version": 2}),
            ("wrong_type", {"type": "config", "version": 2}),
            ("wrong_version", {"type": "model", "version": 1}),
        )
        for ref_name, meta in fixtures:
            _make_reference(session, asset, ref_name, meta)
        session.commit()

        both = {"type": "model", "version": 2}
        refs, _, total = list_references_page(session, metadata_filter=both)

        assert total == 1
        assert refs[0].name == "match"
class TestMetadataFilterEmptyDict:
    """Checks the behaviour of an empty metadata filter."""

    def test_empty_filter_returns_all(self, session: Session):
        """An empty filter dict is expected to count every reference."""
        asset = _make_asset(session, "hash1")
        for ref_name, meta in (("a", {"key": "val"}), ("b", {})):
            _make_reference(session, asset, ref_name, meta)
        session.commit()

        _refs, _, total = list_references_page(session, metadata_filter={})

        assert total == 2

View File

@ -0,0 +1,366 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag
from app.assets.database.queries import (
ensure_tags_exist,
get_reference_tags,
set_reference_tags,
add_tags_to_reference,
remove_tags_from_reference,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
bulk_insert_tags_and_meta,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
    """Add a 1024-byte ``Asset`` (hash optional), flush, and return it."""
    created = Asset(hash=hash_val, size_bytes=1024)
    session.add(created)
    session.flush()
    return created
def _make_reference(
    session: Session,
    asset: Asset,
    name: str = "test",
    owner_id: str = "",
) -> AssetReference:
    """Add an ``AssetReference`` for ``asset``, flush, and return it.

    All three timestamp fields receive the same ``get_utc_now()`` value.
    """
    stamp = get_utc_now()
    reference = AssetReference(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        created_at=stamp,
        updated_at=stamp,
        last_access_time=stamp,
    )
    session.add(reference)
    session.flush()
    return reference
class TestEnsureTagsExist:
    """Tests for ``ensure_tags_exist``."""

    def test_creates_new_tags(self, session: Session):
        """Both requested tag names are expected to be stored."""
        ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
        session.commit()

        stored_names = {tag.name for tag in session.query(Tag).all()}
        assert stored_names == {"alpha", "beta"}

    def test_is_idempotent(self, session: Session):
        """Requesting the same tag twice is expected to store it once."""
        for _ in range(2):
            ensure_tags_exist(session, ["alpha"], tag_type="user")
        session.commit()

        tag_count = session.query(Tag).count()
        assert tag_count == 1

    def test_normalizes_tags(self, session: Session):
        """Padded and mixed-case inputs are expected to collapse to two names."""
        ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
        session.commit()

        stored_names = {tag.name for tag in session.query(Tag).all()}
        assert stored_names == {"alpha", "beta"}

    def test_empty_list_is_noop(self, session: Session):
        """An empty name list is expected to create no tags."""
        ensure_tags_exist(session, [])
        session.commit()

        tag_count = session.query(Tag).count()
        assert tag_count == 0

    def test_tag_type_is_set(self, session: Session):
        """The ``tag_type`` argument is expected to be stored on the tag."""
        ensure_tags_exist(session, ["system-tag"], tag_type="system")
        session.commit()

        stored = session.query(Tag).filter_by(name="system-tag").one()
        assert stored.tag_type == "system"
class TestGetReferenceTags:
    """Tests for ``get_reference_tags``."""

    def test_returns_empty_for_no_tags(self, session: Session):
        """A reference with no tag links is expected to yield an empty list."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)

        assert get_reference_tags(session, reference_id=reference.id) == []

    def test_returns_tags_for_reference(self, session: Session):
        """Tag links added directly to the session are expected to be returned."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        ensure_tags_exist(session, ["tag1", "tag2"])

        links = [
            AssetReferenceTag(
                asset_reference_id=reference.id,
                tag_name=tag_name,
                origin="manual",
                added_at=get_utc_now(),
            )
            for tag_name in ("tag1", "tag2")
        ]
        session.add_all(links)
        session.flush()

        found = get_reference_tags(session, reference_id=reference.id)
        assert set(found) == {"tag1", "tag2"}
class TestSetReferenceTags:
    """Tests for ``set_reference_tags`` and its added/removed/total report."""

    def test_adds_new_tags(self, session: Session):
        """Setting tags on an untagged reference reports them all as added."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)

        outcome = set_reference_tags(session, reference_id=reference.id, tags=["a", "b"])
        session.commit()

        assert set(outcome["added"]) == {"a", "b"}
        assert outcome["removed"] == []
        assert set(outcome["total"]) == {"a", "b"}

    def test_removes_old_tags(self, session: Session):
        """Tags left out of the new set are expected to be reported as removed."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        set_reference_tags(session, reference_id=reference.id, tags=["a", "b", "c"])

        outcome = set_reference_tags(session, reference_id=reference.id, tags=["a"])
        session.commit()

        assert outcome["added"] == []
        assert set(outcome["removed"]) == {"b", "c"}
        assert outcome["total"] == ["a"]

    def test_replaces_tags(self, session: Session):
        """An overlapping new set reports one added tag and one removed tag."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        set_reference_tags(session, reference_id=reference.id, tags=["a", "b"])

        outcome = set_reference_tags(session, reference_id=reference.id, tags=["b", "c"])
        session.commit()

        assert outcome["added"] == ["c"]
        assert outcome["removed"] == ["a"]
        assert set(outcome["total"]) == {"b", "c"}
class TestAddTagsToReference:
    """Tests for ``add_tags_to_reference``."""

    def test_adds_tags(self, session: Session):
        """New tags are reported as added, with nothing already present."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)

        outcome = add_tags_to_reference(session, reference_id=reference.id, tags=["x", "y"])
        session.commit()

        assert set(outcome["added"]) == {"x", "y"}
        assert outcome["already_present"] == []

    def test_reports_already_present(self, session: Session):
        """A tag added earlier is expected under ``already_present``."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        add_tags_to_reference(session, reference_id=reference.id, tags=["x"])

        outcome = add_tags_to_reference(session, reference_id=reference.id, tags=["x", "y"])
        session.commit()

        assert outcome["added"] == ["y"]
        assert outcome["already_present"] == ["x"]

    def test_raises_for_missing_reference(self, session: Session):
        """An unknown reference id is expected to raise ``ValueError``."""
        with pytest.raises(ValueError, match="not found"):
            add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
class TestRemoveTagsFromReference:
    """Tests for ``remove_tags_from_reference``."""

    def test_removes_tags(self, session: Session):
        """Removing two of three tags is expected to leave the third."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        add_tags_to_reference(session, reference_id=reference.id, tags=["a", "b", "c"])

        outcome = remove_tags_from_reference(
            session, reference_id=reference.id, tags=["a", "b"]
        )
        session.commit()

        assert set(outcome["removed"]) == {"a", "b"}
        assert outcome["not_present"] == []
        assert outcome["total_tags"] == ["c"]

    def test_reports_not_present(self, session: Session):
        """A tag that was never attached is expected under ``not_present``."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        add_tags_to_reference(session, reference_id=reference.id, tags=["a"])

        outcome = remove_tags_from_reference(
            session, reference_id=reference.id, tags=["a", "x"]
        )
        session.commit()

        assert outcome["removed"] == ["a"]
        assert outcome["not_present"] == ["x"]

    def test_raises_for_missing_reference(self, session: Session):
        """An unknown reference id is expected to raise ``ValueError``."""
        with pytest.raises(ValueError, match="not found"):
            remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"])
class TestMissingTagFunctions:
    """Tests for ``add_missing_tag_for_asset_id`` and its removal counterpart."""

    @staticmethod
    def _asset_flagged_missing(session: Session):
        """Create an asset and reference, then apply the "missing" tag once.

        Returns the ``(asset, reference)`` pair for further assertions.
        """
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        ensure_tags_exist(session, ["missing"], tag_type="system")
        add_missing_tag_for_asset_id(session, asset_id=asset.id)
        return asset, reference

    def test_add_missing_tag_for_asset_id(self, session: Session):
        """The asset's reference is expected to carry the "missing" tag."""
        _, reference = self._asset_flagged_missing(session)
        session.commit()

        found = get_reference_tags(session, reference_id=reference.id)
        assert "missing" in found

    def test_add_missing_tag_is_idempotent(self, session: Session):
        """Applying the tag twice is expected to leave a single tag link."""
        asset, reference = self._asset_flagged_missing(session)
        add_missing_tag_for_asset_id(session, asset_id=asset.id)
        session.commit()

        link_rows = (
            session.query(AssetReferenceTag)
            .filter_by(asset_reference_id=reference.id, tag_name="missing")
            .all()
        )
        assert len(link_rows) == 1

    def test_remove_missing_tag_for_asset_id(self, session: Session):
        """After removal the reference is expected to lack the "missing" tag."""
        asset, reference = self._asset_flagged_missing(session)
        remove_missing_tag_for_asset_id(session, asset_id=asset.id)
        session.commit()

        found = get_reference_tags(session, reference_id=reference.id)
        assert "missing" not in found
class TestListTagsWithUsage:
    """Tests for ``list_tags_with_usage``."""

    @staticmethod
    def _seed_used_and_unused(session: Session) -> None:
        """Create tags "used" and "unused", attach only "used", then commit."""
        ensure_tags_exist(session, ["used", "unused"])
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        add_tags_to_reference(session, reference_id=reference.id, tags=["used"])
        session.commit()

    def test_returns_tags_with_counts(self, session: Session):
        """Both tags are listed, with counts of 1 and 0 and a total of 2."""
        self._seed_used_and_unused(session)

        rows, total = list_tags_with_usage(session)

        usage = {tag_name: uses for tag_name, _, uses in rows}
        assert usage["used"] == 1
        assert usage["unused"] == 0
        assert total == 2

    def test_exclude_zero_counts(self, session: Session):
        """With ``include_zero=False`` the unattached tag is left out."""
        self._seed_used_and_unused(session)

        rows, _ = list_tags_with_usage(session, include_zero=False)

        listed = {tag_name for tag_name, _, _ in rows}
        assert "used" in listed
        assert "unused" not in listed

    def test_prefix_filter(self, session: Session):
        """Only tags starting with the given prefix are expected in the rows."""
        ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
        session.commit()

        rows, _ = list_tags_with_usage(session, prefix="alph")

        listed = {tag_name for tag_name, _, _ in rows}
        assert listed == {"alpha", "alphabet"}

    def test_order_by_name(self, session: Session):
        """``order="name_asc"`` is expected to sort rows by tag name."""
        ensure_tags_exist(session, ["zebra", "alpha", "middle"])
        session.commit()

        rows, _ = list_tags_with_usage(session, order="name_asc")

        ordered_names = [tag_name for tag_name, _, _ in rows]
        assert ordered_names == ["alpha", "middle", "zebra"]

    def test_owner_visibility(self, session: Session):
        """Tag counts are expected to depend on the requesting ``owner_id``."""
        ensure_tags_exist(session, ["shared-tag", "owner-tag"])
        asset = _make_asset(session, "hash1")
        shared_ref = _make_reference(session, asset, name="shared", owner_id="")
        owner_ref = _make_reference(session, asset, name="owned", owner_id="user1")
        add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"])
        add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"])
        session.commit()

        def visible_counts(owner_id):
            rows, _ = list_tags_with_usage(
                session, owner_id=owner_id, include_zero=False
            )
            return {tag_name: uses for tag_name, _, uses in rows}

        # Empty owner sees only shared
        anonymous_view = visible_counts("")
        assert anonymous_view.get("shared-tag", 0) == 1
        assert anonymous_view.get("owner-tag", 0) == 0

        # User1 sees both
        user_view = visible_counts("user1")
        assert user_view.get("shared-tag", 0) == 1
        assert user_view.get("owner-tag", 0) == 1
class TestBulkInsertTagsAndMeta:
    """Tests for ``bulk_insert_tags_and_meta``."""

    def test_inserts_tags(self, session: Session):
        """Two tag rows are expected to become two tags on the reference."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
        session.commit()

        stamp = get_utc_now()
        tag_rows = [
            {
                "asset_reference_id": reference.id,
                "tag_name": tag_name,
                "origin": "manual",
                "added_at": stamp,
            }
            for tag_name in ("bulk-tag1", "bulk-tag2")
        ]
        bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
        session.commit()

        found = get_reference_tags(session, reference_id=reference.id)
        assert set(found) == {"bulk-tag1", "bulk-tag2"}

    def test_inserts_meta(self, session: Session):
        """A single meta row is expected to be stored with its key and value."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        session.commit()

        single_row = {
            "asset_reference_id": reference.id,
            "key": "meta-key",
            "ordinal": 0,
            "val_str": "meta-value",
            "val_num": None,
            "val_bool": None,
            "val_json": None,
        }
        bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[single_row])
        session.commit()

        stored = (
            session.query(AssetReferenceMeta)
            .filter_by(asset_reference_id=reference.id)
            .all()
        )
        assert len(stored) == 1
        assert stored[0].key == "meta-key"
        assert stored[0].val_str == "meta-value"

    def test_ignores_conflicts(self, session: Session):
        """A duplicate tag row must not add a link or change the origin."""
        asset = _make_asset(session, "hash1")
        reference = _make_reference(session, asset)
        ensure_tags_exist(session, ["existing-tag"])
        add_tags_to_reference(session, reference_id=reference.id, tags=["existing-tag"])
        session.commit()

        stamp = get_utc_now()
        duplicate = {
            "asset_reference_id": reference.id,
            "tag_name": "existing-tag",
            "origin": "duplicate",
            "added_at": stamp,
        }
        bulk_insert_tags_and_meta(session, tag_rows=[duplicate], meta_rows=[])
        session.commit()

        # Should still have only one tag link
        link_rows = (
            session.query(AssetReferenceTag)
            .filter_by(asset_reference_id=reference.id, tag_name="existing-tag")
            .all()
        )
        assert len(link_rows) == 1
        # Origin should be original, not overwritten
        assert link_rows[0].origin == "manual"

    def test_empty_lists_is_noop(self, session: Session):
        """Empty inputs are expected to leave both tables empty."""
        bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])

        for model in (AssetReferenceTag, AssetReferenceMeta):
            assert session.query(model).count() == 0