backend maintain (#1429)

Author: lllyasviel, 2023-12-15 11:37:45 -08:00, committed by GitHub
parent 059037eeb2
commit 26ea508588
2 changed files with 2 additions and 8 deletions


@@ -29,9 +29,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
         with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+            sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
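The change in this hunk (and in the second file below) replaces the rebinding of q and k to fp32 copies with inline casts inside the einsum call, so the fp32 tensors exist only for the duration of the matmul and q/k keep their original dtype. A minimal standalone sketch of the idea, not code from this commit; the shapes and scale below are illustrative:

    # illustrative only: shows the dtype behaviour of the inline-cast form
    import torch
    from torch import einsum

    heads = 8
    q = torch.randn(2 * heads, 77, 64, dtype=torch.float16)
    k = torch.randn(2 * heads, 77, 64, dtype=torch.float16)
    scale = q.shape[-1] ** -0.5  # hypothetical scale, as in scaled dot-product attention

    # the similarity matrix is computed in fp32 to avoid overflow...
    sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale

    # ...while q and k themselves stay in their original half precision
    assert q.dtype == torch.float16 and sim.dtype == torch.float32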
@@ -113,7 +111,6 @@ class SelfAttentionGuidance:
         m = model.clone()
         attn_scores = None
-        mid_block_shape = None
         # TODO: make this work properly with chunked batches
         # currently, we can only save the attn from one UNet call
@@ -136,7 +133,6 @@ class SelfAttentionGuidance:
         def post_cfg_function(args):
             nonlocal attn_scores
-            nonlocal mid_block_shape
             uncond_attn = attn_scores
             sag_scale = scale
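The two hunks above remove mid_block_shape, closure state in SelfAttentionGuidance.patch that was no longer read anywhere: both its initialization and the matching nonlocal declaration inside post_cfg_function are dropped, while attn_scores stays. A rough sketch of the closure pattern, with hypothetical names (patch_like, attn_hook) standing in for the real patch method and attention hook, not the commit's actual code:

    # illustrative only: nonlocal closure state shared between an attention hook
    # and a post-CFG callback, with attn_scores as the only surviving state
    def patch_like(scale):
        attn_scores = None

        def attn_hook(sim):
            nonlocal attn_scores       # the hook writes the captured variable
            attn_scores = sim

        def post_cfg_function(args):
            nonlocal attn_scores
            uncond_attn = attn_scores  # only attn_scores is consumed here
            sag_scale = scale
            return uncond_attn, sag_scale

        return attn_hook, post_cfg_function

    hook, post_cfg = patch_like(scale=0.5)
    hook("dummy attention map")
    print(post_cfg({}))  # -> ('dummy attention map', 0.5)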


@@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
         with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+            sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale