From 26ea5085887c90a5a5b1766cec1672f90e86fa2f Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Fri, 15 Dec 2023 11:37:45 -0800
Subject: [PATCH] backend maintain (#1429)

---
 ldm_patched/contrib/external_sag.py  | 6 +-----
 ldm_patched/ldm/modules/attention.py | 4 +---
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/ldm_patched/contrib/external_sag.py b/ldm_patched/contrib/external_sag.py
index 59d1890..3505b44 100644
--- a/ldm_patched/contrib/external_sag.py
+++ b/ldm_patched/contrib/external_sag.py
@@ -29,9 +29,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
 
@@ -113,7 +111,6 @@ class SelfAttentionGuidance:
         m = model.clone()
 
         attn_scores = None
-        mid_block_shape = None
 
         # TODO: make this work properly with chunked batches
         # currently, we can only save the attn from one UNet call
@@ -136,7 +133,6 @@ class SelfAttentionGuidance:
 
         def post_cfg_function(args):
             nonlocal attn_scores
-            nonlocal mid_block_shape
             uncond_attn = attn_scores
 
             sag_scale = scale
diff --git a/ldm_patched/ldm/modules/attention.py b/ldm_patched/ldm/modules/attention.py
index f4579ba..49e502e 100644
--- a/ldm_patched/ldm/modules/attention.py
+++ b/ldm_patched/ldm/modules/attention.py
@@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
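
Context for the two attention hunks above: the patch drops the torch.autocast(enabled=False, device_type='cuda') wrapper and instead casts q and k to fp32 directly inside the einsum. The explicit casts are enough on their own to compute the similarity matrix in full precision and avoid fp16 overflow, and they do not reference CUDA the way the removed autocast context did. Below is a minimal, self-contained sketch of that pattern; the tensor shapes and the use of CPU tensors are illustrative assumptions, not taken from the repository.

    import torch
    from torch import einsum

    # Illustrative shapes: (batch * heads, tokens, dim_head), with fp16 inputs
    # as you would see in a half-precision UNet forward pass (assumed values).
    q = torch.randn(8, 77, 64, dtype=torch.float16)
    k = torch.randn(8, 77, 64, dtype=torch.float16)
    scale = q.shape[-1] ** -0.5

    # Pattern applied by the patch: cast q and k to fp32 only for the
    # similarity computation, so large dot products cannot overflow fp16.
    sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale

    # Softmax (and the later weighted sum over v) then runs on the fp32 result.
    attn = sim.softmax(dim=-1)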