diff --git a/backend/headless/fcbh/ldm/modules/attention.py b/backend/headless/fcbh/ldm/modules/attention.py
index f3e1b6e..c038355 100644
--- a/backend/headless/fcbh/ldm/modules/attention.py
+++ b/backend/headless/fcbh/ldm/modules/attention.py
@@ -160,32 +160,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
 
     mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
 
-    chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
-
     kv_chunk_size_min = None
+    kv_chunk_size = None
+    query_chunk_size = None
 
-    #not sure at all about the math here
-    #TODO: tweak this
-    if mem_free_total > 8192 * 1024 * 1024 * 1.3:
-        query_chunk_size_x = 1024 * 4
-    elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
-        query_chunk_size_x = 1024 * 2
-    else:
-        query_chunk_size_x = 1024
-    kv_chunk_size_min_x = None
-    kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
-    if kv_chunk_size_x < 1024:
-        kv_chunk_size_x = None
+    for x in [4096, 2048, 1024, 512, 256]:
+        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
+        if count >= k_tokens:
+            kv_chunk_size = k_tokens
+            query_chunk_size = x
+            break
 
-    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
-        # the big matmul fits into our memory limit; do everything in 1 chunk,
-        # i.e. send it down the unchunked fast-path
-        query_chunk_size = q_tokens
-        kv_chunk_size = k_tokens
-    else:
-        query_chunk_size = query_chunk_size_x
-        kv_chunk_size = kv_chunk_size_x
-        kv_chunk_size_min = kv_chunk_size_min_x
+    if query_chunk_size is None:
+        query_chunk_size = 512
 
     hidden_states = efficient_dot_product_attention(
         query,
@@ -229,7 +216,7 @@ def attention_split(q, k, v, heads, mask=None):
 
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
-    modifier = 3 if element_size == 2 else 2.5
+    modifier = 3
     mem_required = tensor_size * modifier
     steps = 1
 
@@ -257,10 +244,10 @@ def attention_split(q, k, v, heads, mask=None):
                         s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
-                first_op_done = True
 
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
+                first_op_done = True
 
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
diff --git a/fooocus_version.py b/fooocus_version.py
index e09972d..9360f81 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.766'
+version = '2.1.767'
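
For reference, below is a minimal standalone sketch of the chunk-size heuristic that the first hunk adds to attention_sub_quad. It is not part of the patch; the helper name pick_chunk_sizes and the example numbers are hypothetical stand-ins for the values the real function derives from the query/key tensors and model_management.get_free_memory().

# Runnable sketch of the new chunk-size selection, mirroring the added hunk.
def pick_chunk_sizes(mem_free_total, batch_x_heads, bytes_per_token, k_tokens):
    kv_chunk_size = None
    query_chunk_size = None
    # Pick the largest query chunk whose score matrix (with ~4x headroom)
    # still fits in free memory; in that case the keys are not chunked.
    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break
    if query_chunk_size is None:
        # Nothing fits: fall back to small query chunks and leave kv_chunk_size
        # unset so the sub-quadratic kernel picks its own key/value chunking.
        query_chunk_size = 512
    return query_chunk_size, kv_chunk_size

# Example (made-up numbers): ~8 GiB free, 16 batch*heads, fp16 (2 bytes/element),
# 4096 key tokens -> (4096, 4096), i.e. a single unchunked pass.
print(pick_chunk_sizes(8 * 1024**3, 16, 2, 4096))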