backend maintain (#1429)
parent 059037eeb2
commit 26ea508588
@@ -29,9 +29,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
 
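The change itself, pulled out of the diff into a standalone sketch: the fp32 branch now casts q and k inline inside the einsum call instead of pre-casting them under a disabled-autocast region. The module-level _ATTN_PRECISION flag, the scale formula, and the (batch*heads, tokens, dim_head) layout are assumptions for illustration, not a copy of the surrounding function.

import torch
from torch import einsum

_ATTN_PRECISION = "fp32"  # assumed flag, mirroring the check in the diff

def scaled_similarity(q, k):
    # q, k: (batch*heads, tokens, dim_head); layout assumed for illustration
    scale = q.shape[-1] ** -0.5
    if _ATTN_PRECISION == "fp32":
        # force cast to fp32 to avoid overflowing, without an autocast block
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale
    return sim

q = torch.randn(2, 8, 64, dtype=torch.float16)
k = torch.randn(2, 8, 64, dtype=torch.float16)
print(scaled_similarity(q, k).dtype)  # torch.float32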
@@ -113,7 +111,6 @@ class SelfAttentionGuidance:
         m = model.clone()
 
         attn_scores = None
-        mid_block_shape = None
 
         # TODO: make this work properly with chunked batches
         # currently, we can only save the attn from one UNet call
@@ -136,7 +133,6 @@ class SelfAttentionGuidance:
 
         def post_cfg_function(args):
             nonlocal attn_scores
-            nonlocal mid_block_shape
             uncond_attn = attn_scores
 
             sag_scale = scale
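Both SelfAttentionGuidance hunks drop mid_block_shape, leaving attn_scores as the only state shared between the attention capture and post_cfg_function via nonlocal. Below is a minimal sketch of that closure pattern; the hook name, the args keys, and the final combination step are placeholders for illustration and not ComfyUI's actual patching API.

import torch

def build_sag_callbacks(scale):
    attn_scores = None  # populated during the uncond UNet call

    def record_attn(sim):
        # hypothetical attention hook: stash the similarity map for later use
        nonlocal attn_scores
        attn_scores = sim
        return sim

    def post_cfg_function(args):
        nonlocal attn_scores
        uncond_attn = attn_scores
        sag_scale = scale
        denoised = args["denoised"]      # placeholder key
        if uncond_attn is None:          # nothing captured yet
            return denoised
        # stand-in for SAG's blur/degradation step, just to close the loop:
        # weight the correction by the mean of the captured attention
        correction = (denoised - denoised.mean()) * uncond_attn.mean()
        return denoised + sag_scale * correction

    return record_attn, post_cfg_function

record_attn, post_cfg = build_sag_callbacks(0.5)
record_attn(torch.rand(1, 64, 64))
print(post_cfg({"denoised": torch.randn(1, 4, 8, 8)}).shape)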
@@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
    if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
 
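Both attention hunks keep the "force cast to fp32 to avoid overflowing" behaviour. A quick self-contained illustration of the overflow that comment refers to, using synthetic values rather than anything taken from the model:

import torch

# each query/key entry is 12.0 and the head dimension is 512, so the dot
# product is 512 * 144 = 73728, which exceeds float16's maximum of 65504
q = torch.full((512,), 12.0, dtype=torch.float16)
k = torch.full((512,), 12.0, dtype=torch.float16)

print((q * k).sum())                    # tensor(inf, dtype=torch.float16)
print((q.float() * k.float()).sum())    # tensor(73728.)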