backend maintain (#1429)
This commit is contained in:
parent
059037eeb2
commit
26ea508588
@ -29,9 +29,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
|
|||||||
|
|
||||||
# force cast to fp32 to avoid overflowing
|
# force cast to fp32 to avoid overflowing
|
||||||
if _ATTN_PRECISION =="fp32":
|
if _ATTN_PRECISION =="fp32":
|
||||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||||
q, k = q.float(), k.float()
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
|
||||||
else:
|
else:
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
|
|
||||||
@ -113,7 +111,6 @@ class SelfAttentionGuidance:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
attn_scores = None
|
attn_scores = None
|
||||||
mid_block_shape = None
|
|
||||||
|
|
||||||
# TODO: make this work properly with chunked batches
|
# TODO: make this work properly with chunked batches
|
||||||
# currently, we can only save the attn from one UNet call
|
# currently, we can only save the attn from one UNet call
|
||||||
@ -136,7 +133,6 @@ class SelfAttentionGuidance:
|
|||||||
|
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
nonlocal attn_scores
|
nonlocal attn_scores
|
||||||
nonlocal mid_block_shape
|
|
||||||
uncond_attn = attn_scores
|
uncond_attn = attn_scores
|
||||||
|
|
||||||
sag_scale = scale
|
sag_scale = scale
|
||||||
|
@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
|
|||||||
|
|
||||||
# force cast to fp32 to avoid overflowing
|
# force cast to fp32 to avoid overflowing
|
||||||
if _ATTN_PRECISION =="fp32":
|
if _ATTN_PRECISION =="fp32":
|
||||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||||
q, k = q.float(), k.float()
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
|
||||||
else:
|
else:
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user