"""Tests for external cache provider API.""" import importlib.util import pytest from typing import Optional def _torch_available() -> bool: """Check if PyTorch is available.""" return importlib.util.find_spec("torch") is not None from comfy_execution.cache_provider import ( CacheProvider, CacheContext, CacheValue, register_cache_provider, unregister_cache_provider, _get_cache_providers, _has_cache_providers, _clear_cache_providers, _serialize_cache_key, _contains_self_unequal, _estimate_value_size, _canonicalize, ) class TestCanonicalize: """Test _canonicalize function for deterministic ordering.""" def test_frozenset_ordering_is_deterministic(self): """Frozensets should produce consistent canonical form regardless of iteration order.""" # Create two frozensets with same content fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) result1 = _canonicalize(fs1) result2 = _canonicalize(fs2) assert result1 == result2 def test_nested_frozenset_ordering(self): """Nested frozensets should also be deterministically ordered.""" inner1 = frozenset([1, 2, 3]) inner2 = frozenset([3, 2, 1]) fs1 = frozenset([("key", inner1)]) fs2 = frozenset([("key", inner2)]) result1 = _canonicalize(fs1) result2 = _canonicalize(fs2) assert result1 == result2 def test_dict_ordering(self): """Dicts should be sorted by key.""" d1 = {"z": 1, "a": 2, "m": 3} d2 = {"a": 2, "m": 3, "z": 1} result1 = _canonicalize(d1) result2 = _canonicalize(d2) assert result1 == result2 def test_tuple_preserved(self): """Tuples should be marked and preserved.""" t = (1, 2, 3) result = _canonicalize(t) assert result[0] == "__tuple__" assert result[1] == [1, 2, 3] def test_list_preserved(self): """Lists should be recursively canonicalized.""" lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] result = _canonicalize(lst) # First element should be dict with sorted keys assert result[0] == {"a": 1, "b": 2} # Second element should be canonicalized frozenset assert result[1][0] == "__frozenset__" def test_primitives_unchanged(self): """Primitive types should pass through unchanged.""" assert _canonicalize(42) == 42 assert _canonicalize(3.14) == 3.14 assert _canonicalize("hello") == "hello" assert _canonicalize(True) is True assert _canonicalize(None) is None def test_bytes_converted(self): """Bytes should be converted to hex string.""" b = b"\x00\xff" result = _canonicalize(b) assert result[0] == "__bytes__" assert result[1] == "00ff" def test_set_ordering(self): """Sets should be sorted like frozensets.""" s1 = {3, 1, 2} s2 = {1, 2, 3} result1 = _canonicalize(s1) result2 = _canonicalize(s2) assert result1 == result2 assert result1[0] == "__set__" def test_unknown_type_raises(self): """Unknown types should raise ValueError (fail-closed).""" class CustomObj: pass with pytest.raises(ValueError): _canonicalize(CustomObj()) def test_object_with_value_attr_raises(self): """Objects with .value attribute (Unhashable-like) should raise ValueError.""" class FakeUnhashable: def __init__(self): self.value = float('nan') with pytest.raises(ValueError): _canonicalize(FakeUnhashable()) class TestSerializeCacheKey: """Test _serialize_cache_key for deterministic hashing.""" def test_same_content_same_hash(self): """Same content should produce same hash.""" key1 = frozenset([("node_1", frozenset([("input", "value")]))]) key2 = frozenset([("node_1", frozenset([("input", "value")]))]) hash1 = _serialize_cache_key(key1) hash2 = _serialize_cache_key(key2) assert hash1 == hash2 def test_different_content_different_hash(self): """Different content should produce different hash.""" key1 = frozenset([("node_1", "value_a")]) key2 = frozenset([("node_1", "value_b")]) hash1 = _serialize_cache_key(key1) hash2 = _serialize_cache_key(key2) assert hash1 != hash2 def test_returns_hex_string(self): """Should return hex string (SHA256 hex digest).""" key = frozenset([("test", 123)]) result = _serialize_cache_key(key) assert isinstance(result, str) assert len(result) == 64 # SHA256 hex digest is 64 chars def test_complex_nested_structure(self): """Complex nested structures should hash deterministically.""" # Note: frozensets can only contain hashable types, so we use # nested frozensets of tuples to represent dict-like structures key = frozenset([ ("node_1", frozenset([ ("input_a", ("tuple", "value")), ("input_b", frozenset([("nested", "dict")])), ])), ("node_2", frozenset([ ("param", 42), ])), ]) # Hash twice to verify determinism hash1 = _serialize_cache_key(key) hash2 = _serialize_cache_key(key) assert hash1 == hash2 def test_dict_in_cache_key(self): """Dicts passed directly to _serialize_cache_key should work.""" key = {"node_1": {"input": "value"}, "node_2": 42} hash1 = _serialize_cache_key(key) hash2 = _serialize_cache_key(key) assert hash1 == hash2 assert isinstance(hash1, str) assert len(hash1) == 64 def test_unknown_type_returns_none(self): """Non-cacheable types should return None (fail-closed).""" class CustomObj: pass assert _serialize_cache_key(CustomObj()) is None class TestContainsSelfUnequal: """Test _contains_self_unequal utility function.""" def test_nan_float_detected(self): """NaN floats should be detected (not equal to itself).""" assert _contains_self_unequal(float('nan')) is True def test_regular_float_not_detected(self): """Regular floats are equal to themselves.""" assert _contains_self_unequal(3.14) is False assert _contains_self_unequal(0.0) is False assert _contains_self_unequal(-1.5) is False def test_infinity_not_detected(self): """Infinity is equal to itself.""" assert _contains_self_unequal(float('inf')) is False assert _contains_self_unequal(float('-inf')) is False def test_nan_in_list(self): """NaN in list should be detected.""" assert _contains_self_unequal([1, 2, float('nan'), 4]) is True assert _contains_self_unequal([1, 2, 3, 4]) is False def test_nan_in_tuple(self): """NaN in tuple should be detected.""" assert _contains_self_unequal((1, float('nan'))) is True assert _contains_self_unequal((1, 2, 3)) is False def test_nan_in_frozenset(self): """NaN in frozenset should be detected.""" assert _contains_self_unequal(frozenset([1, float('nan')])) is True assert _contains_self_unequal(frozenset([1, 2, 3])) is False def test_nan_in_dict_value(self): """NaN in dict value should be detected.""" assert _contains_self_unequal({"key": float('nan')}) is True assert _contains_self_unequal({"key": 42}) is False def test_nan_in_nested_structure(self): """NaN in deeply nested structure should be detected.""" nested = {"level1": [{"level2": (1, 2, float('nan'))}]} assert _contains_self_unequal(nested) is True def test_non_numeric_types(self): """Non-numeric types should not be self-unequal.""" assert _contains_self_unequal("string") is False assert _contains_self_unequal(None) is False assert _contains_self_unequal(True) is False def test_object_with_nan_value_attribute(self): """Objects wrapping NaN in .value should be detected.""" class NanWrapper: def __init__(self): self.value = float('nan') assert _contains_self_unequal(NanWrapper()) is True def test_custom_self_unequal_object(self): """Custom objects where not (x == x) should be detected.""" class NeverEqual: def __eq__(self, other): return False assert _contains_self_unequal(NeverEqual()) is True class TestEstimateValueSize: """Test _estimate_value_size utility function.""" def test_empty_outputs(self): """Empty outputs should have zero size.""" value = CacheValue(outputs=[]) assert _estimate_value_size(value) == 0 @pytest.mark.skipif( not _torch_available(), reason="PyTorch not available" ) def test_tensor_size_estimation(self): """Tensor size should be estimated correctly.""" import torch # 1000 float32 elements = 4000 bytes tensor = torch.zeros(1000, dtype=torch.float32) value = CacheValue(outputs=[[tensor]]) size = _estimate_value_size(value) assert size == 4000 @pytest.mark.skipif( not _torch_available(), reason="PyTorch not available" ) def test_nested_tensor_in_dict(self): """Tensors nested in dicts should be counted.""" import torch tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes value = CacheValue(outputs=[[{"samples": tensor}]]) size = _estimate_value_size(value) assert size == 400 class TestProviderRegistry: """Test cache provider registration and retrieval.""" def setup_method(self): """Clear providers before each test.""" _clear_cache_providers() def teardown_method(self): """Clear providers after each test.""" _clear_cache_providers() def test_register_provider(self): """Provider should be registered successfully.""" provider = MockCacheProvider() register_cache_provider(provider) assert _has_cache_providers() is True providers = _get_cache_providers() assert len(providers) == 1 assert providers[0] is provider def test_unregister_provider(self): """Provider should be unregistered successfully.""" provider = MockCacheProvider() register_cache_provider(provider) unregister_cache_provider(provider) assert _has_cache_providers() is False def test_multiple_providers(self): """Multiple providers can be registered.""" provider1 = MockCacheProvider() provider2 = MockCacheProvider() register_cache_provider(provider1) register_cache_provider(provider2) providers = _get_cache_providers() assert len(providers) == 2 def test_duplicate_registration_ignored(self): """Registering same provider twice should be ignored.""" provider = MockCacheProvider() register_cache_provider(provider) register_cache_provider(provider) # Should be ignored providers = _get_cache_providers() assert len(providers) == 1 def test_clear_providers(self): """_clear_cache_providers should remove all providers.""" provider1 = MockCacheProvider() provider2 = MockCacheProvider() register_cache_provider(provider1) register_cache_provider(provider2) _clear_cache_providers() assert _has_cache_providers() is False assert len(_get_cache_providers()) == 0 class TestCacheContext: """Test CacheContext dataclass.""" def test_context_creation(self): """CacheContext should be created with all fields.""" context = CacheContext( prompt_id="prompt-123", node_id="node-456", class_type="KSampler", cache_key_hash="a" * 64, ) assert context.prompt_id == "prompt-123" assert context.node_id == "node-456" assert context.class_type == "KSampler" assert context.cache_key_hash == "a" * 64 class TestCacheValue: """Test CacheValue dataclass.""" def test_value_creation(self): """CacheValue should be created with outputs.""" outputs = [[{"samples": "tensor_data"}]] value = CacheValue(outputs=outputs) assert value.outputs == outputs class MockCacheProvider(CacheProvider): """Mock cache provider for testing.""" def __init__(self): self.lookups = [] self.stores = [] async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: self.lookups.append(context) return None async def on_store(self, context: CacheContext, value: CacheValue) -> None: self.stores.append((context, value))