fix many inpaint bugs (#731)

lllyasviel authored 2023-10-18 06:22:08 -07:00, committed by GitHub
parent d2c8f16082
commit 9660daff94
6 changed files with 131 additions and 96 deletions


@@ -1 +1 @@
version = '2.1.702'
version = '2.1.703'


@@ -400,43 +400,42 @@ def worker():
pipeline.final_unet.model.diffusion_model.in_inpaint = True
# print(f'Inpaint task: {str((height, width))}')
# outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
# return
progressbar(13, 'VAE encoding ...')
inpaint_pixels = core.numpy_to_pytorch(inpaint_worker.current_task.image_ready)
initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=inpaint_pixels)
inpaint_latent = initial_latent['samples']
B, C, H, W = inpaint_latent.shape
inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None])
inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8))
inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear')
progressbar(13, 'VAE Inpaint encoding ...')
latent_after_swap = None
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
latent_inpaint, latent_mask = core.encode_vae_inpaint(
mask=inpaint_pixel_mask,
vae=pipeline.final_vae,
pixels=inpaint_pixel_image)
latent_swap = None
if pipeline.final_refiner_vae is not None:
progressbar(13, 'VAE SD15 encoding ...')
latent_after_swap = core.encode_vae(vae=pipeline.final_refiner_vae, pixels=inpaint_pixels)['samples']
progressbar(13, 'VAE Inpaint SD15 encoding ...')
latent_swap = core.encode_vae(
vae=pipeline.final_refiner_vae,
pixels=inpaint_pixel_fill)['samples']
inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask,
latent_after_swap=latent_after_swap)
progressbar(13, 'VAE encoding ...')
latent_fill = core.encode_vae(
vae=pipeline.final_vae,
pixels=inpaint_pixel_fill)['samples']
progressbar(13, 'VAE inpaint encoding ...')
inpaint_worker.current_task.load_latent(latent_fill=latent_fill,
latent_inpaint=latent_inpaint,
latent_mask=latent_mask,
latent_swap=latent_swap,
inpaint_head_model_path=inpaint_head_model_path)
inpaint_mask = (inpaint_worker.current_task.mask_ready > 0).astype(np.float32)
inpaint_mask = torch.tensor(inpaint_mask).float()
vae_dict = core.encode_vae_inpaint(
mask=inpaint_mask, vae=pipeline.final_vae, pixels=inpaint_pixels)
inpaint_latent = vae_dict['samples']
inpaint_mask = vae_dict['noise_mask']
inpaint_worker.current_task.load_inpaint_guidance(latent=inpaint_latent, mask=inpaint_mask,
model_path=inpaint_head_model_path)
B, C, H, W = inpaint_latent.shape
final_height, final_width = inpaint_worker.current_task.image_raw.shape[:2]
B, C, H, W = latent_fill.shape
height, width = H * 8, W * 8
final_height, final_width = inpaint_worker.current_task.image.shape[:2]
initial_latent = {'samples': latent_fill}
print(f'Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
if 'cn' in goals:

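The worker() hunk above replaces the old image_ready/mask_ready path: the cropped interested_image, interested_mask and interested_fill arrays are encoded separately, sampling now starts from the latent of the pre-filled crop (initial_latent = {'samples': latent_fill}), and load_latent() receives the inpaint latent, the latent mask, the optional SD1.5 swap latent and the inpaint head checkpoint path in a single call. The size bookkeeping can be checked with made-up numbers (a rough sketch; the actual resize happens in InpaintWorker further down):

    import numpy as np

    # The interested crop is resized so its area is roughly 1024^2 with both
    # sides multiples of 16 (see InpaintWorker.__init__ below); the latent grid
    # is 1/8 of that, and worker() recovers pixel size as height, width = H * 8, W * 8.
    H0, W0 = 1000, 1500                              # hypothetical crop size
    k = ((1024.0 ** 2.0) / float(H0 * W0)) ** 0.5
    H = int(np.ceil(float(H0) * k / 16.0)) * 16      # 848
    W = int(np.ceil(float(W0) * k / 16.0)) * 16      # 1264
    print(H, W, H // 8, W // 8)                      # 848 1264 106 158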

@@ -18,7 +18,7 @@ import fcbh.samplers
import fcbh.latent_formats
from fcbh.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint, \
from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \
ControlNetApplyAdvanced
from fcbh_extras.nodes_freelunch import FreeU_V2
from fcbh.sample import prepare_mask
@@ -32,7 +32,6 @@ opVAEDecode = VAEDecode()
opVAEEncode = VAEEncode()
opVAEDecodeTiled = VAEDecodeTiled()
opVAEEncodeTiled = VAEEncodeTiled()
opVAEEncodeForInpaint = VAEEncodeForInpaint()
opControlNetApplyAdvanced = ControlNetApplyAdvanced()
opFreeU = FreeU_V2()
@@ -130,7 +129,21 @@ def encode_vae(vae, pixels, tiled=False):
@torch.no_grad()
@torch.inference_mode()
def encode_vae_inpaint(vae, pixels, mask):
return opVAEEncodeForInpaint.encode(pixels=pixels, vae=vae, mask=mask)[0]
assert mask.ndim == 3 and pixels.ndim == 4
assert mask.shape[-1] == pixels.shape[-2]
assert mask.shape[-2] == pixels.shape[-3]
w = mask.round()[..., None]
pixels = pixels * (1 - w) + 0.5 * w
latent = vae.encode(pixels)
B, C, H, W = latent.shape
latent_mask = mask[:, None, :, :]
latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round()
return latent, latent_mask
class VAEApprox(torch.nn.Module):

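The node-based VAEEncodeForInpaint is dropped in favor of this local implementation: masked pixels are set to neutral gray (0.5) before encoding, and the pixel-space mask is reduced to the latent grid with an 8x8 max-pool, so every latent cell the mask touches is fully marked (the old worker() code used average pooling, which produced fractional mask values). A toy check of the mask handling, independent of any VAE weights:

    import torch

    # A (B, H, W) pixel mask becomes a (B, 1, H/8, W/8) latent mask; max-pooling
    # marks any 8x8 pixel block that intersects the mask.
    B, H, W = 1, 64, 64
    mask = torch.zeros(B, H, W)
    mask[:, 20:28, 20:21] = 1.0                      # a thin 8x1 pixel stroke

    latent_mask = mask[:, None, :, :]
    latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H, W), mode='bilinear').round()
    latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round()

    print(latent_mask.shape)                         # torch.Size([1, 1, 8, 8])
    print(latent_mask[0, 0, 2:4, 2:3])               # touched cells are exactly 1.0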

@@ -445,6 +445,9 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
if refiner_swap_method == 'vae':
if modules.inpaint_worker.current_task is not None:
modules.inpaint_worker.current_task.unswap()
sample_hijack.history_record = []
core.ksampler(
model=final_unet,
@@ -517,9 +520,6 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
noise=refiner_noise
)
if modules.inpaint_worker.current_task is not None:
modules.inpaint_worker.current_task.swap()
target_model = final_refiner_vae
if target_model is None:
target_model = final_vae

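These two hunks adjust the latent swap for the 'vae' refiner swap method: an unswap() guard now runs before the base pass, and one later swap() call site is removed. Together with the new swapped flag added to InpaintWorker (further down), repeated calls can no longer flip the latents back by accident. A minimal illustration of the guarded toggle, with strings standing in for the SDXL and SD1.5 latents:

    latent, latent_after_swap = 'sdxl_latent', 'sd15_latent'
    swapped = False

    def swap():
        global latent, latent_after_swap, swapped
        if swapped:                                  # new guard: second call is a no-op
            return
        latent, latent_after_swap = latent_after_swap, latent
        swapped = True

    swap(); swap()
    print(latent)    # 'sd15_latent' with the guard; an unguarded swap would be back to 'sdxl_latent'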

@@ -43,6 +43,12 @@ def morphological_open(x):
return x_int32.clip(0, 255).astype(np.uint8)
def up255(x, t=0):
y = np.zeros_like(x).astype(np.uint8)
y[x > t] = 255
return y
def imsave(x, path):
x = Image.fromarray(x)
x.save(path)
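The new up255() helper re-hardens a mask: every value strictly above the threshold becomes 255, everything else 0. It is used below as up255(resample_image(mask, W, H), t=127), because resampling introduces intermediate gray values. A quick check:

    import numpy as np

    def up255(x, t=0):                               # as added above
        y = np.zeros_like(x).astype(np.uint8)
        y[x > t] = 255
        return y

    print(up255(np.array([0, 100, 127, 128, 255], dtype=np.uint8), t=127))
    # [  0   0   0 255 255]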
@@ -75,21 +81,25 @@ def compute_initial_abcd(x):
b = np.max(indices[0]) + 65
c = np.min(indices[1]) - 64
d = np.max(indices[1]) + 65
abp = (b + a) // 2
abm = (b - a) // 2
cdp = (d + c) // 2
cdm = (d - c) // 2
l = max(abm, cdm)
a = abp - l
b = abp + l
c = cdp - l
d = cdp + l
a, b, c, d = regulate_abcd(x, a, b, c, d)
return a, b, c, d
def area_abcd(a, b, c, d):
return (b - a) * (d - c)
def solve_abcd(x, a, b, c, d, outpaint):
H, W = x.shape[:2]
if outpaint:
return 0, H, 0, W
min_area = (min(H, W) ** 2) * 0.5
while True:
if area_abcd(a, b, c, d) >= min_area:
if b - a > H * 0.618 and d - c > W * 0.618:
break
add_h = (b - a) < (d - c)
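The added lines center the padded bounding box and expand it into a square (side 2 * max(half-height, half-width)) before regulate_abcd clips it to the image; solve_abcd then appears to grow the region until both sides exceed 0.618 of the corresponding image dimension, replacing the old minimum-area stop. A worked example of the squaring step with made-up mask bounds:

    # mask occupies rows 100..199 and cols 300..339 (a tall, thin stroke)
    a, b = 100 - 64, 199 + 65                        # 36, 264   padded rows
    c, d = 300 - 64, 339 + 65                        # 236, 404  padded cols

    abp, abm = (b + a) // 2, (b - a) // 2            # 150, 114  row center / half-height
    cdp, cdm = (d + c) // 2, (d - c) // 2            # 320, 84   col center / half-width
    l = max(abm, cdm)                                # 114

    a, b = abp - l, abp + l                          # 36, 264
    c, d = cdp - l, cdp + l                          # 206, 434
    print(b - a, d - c)                              # 228 228 -> a square crop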
@@ -119,7 +129,7 @@ def fooocus_fill(image, mask):
area = np.where(mask < 127)
store = raw_image[area]
for k, repeats in [(64, 4), (32, 4), (16, 4), (4, 4), (2, 4)]:
for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
for _ in range(repeats):
current_image = box_blur(current_image, k)
current_image[area] = store
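The fill schedule changes from small kernels only to a coarse-to-fine sweep (512 down to 3, odd sizes at the small end), and the known pixels are re-imprinted after every blur so color diffuses into the hole from its border. A self-contained stand-in, assuming box_blur behaves like a plain box filter (cv2.blur is used here in its place; the real schedule targets the roughly 1024x1024 interested crop):

    import numpy as np
    import cv2

    def toy_fill(image, mask, schedule=((512, 2), (256, 2), (128, 4), (64, 4),
                                        (33, 8), (15, 8), (5, 16), (3, 16))):
        current = image.copy()
        known = np.where(mask < 127)                 # pixels that must be kept
        store = image[known]
        for k, repeats in schedule:
            for _ in range(repeats):
                current = cv2.blur(current, (k, k))  # stand-in for box_blur(current, k)
                current[known] = store               # pin the original pixels each round
        return current

    # small demo with a reduced schedule (kernels must stay smaller than the image)
    img = np.zeros((64, 64, 3), np.uint8); img[:, :32] = 255
    hole = np.zeros((64, 64), np.uint8); hole[16:48, 16:48] = 255
    filled = toy_fill(img, hole, schedule=((33, 4), (15, 4), (5, 8), (3, 8)))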
@@ -129,98 +139,107 @@ def fooocus_fill(image, mask):
class InpaintWorker:
def __init__(self, image, mask, is_outpaint):
# mask processing
self.mask_raw_soft = morphological_open(mask)
self.mask_raw_fg = (self.mask_raw_soft == 255).astype(np.uint8) * 255
self.mask_raw_bg = (self.mask_raw_soft == 0).astype(np.uint8) * 255
self.mask_raw_trim = 255 - np.maximum(self.mask_raw_fg, self.mask_raw_bg)
# image processing
self.image_raw = fooocus_fill(image, self.mask_raw_fg)
# log all images
# imsave(self.image_raw, 'image_raw.png')
# imsave(self.mask_raw_soft, 'mask_raw_soft.png')
# imsave(self.mask_raw_fg, 'mask_raw_fg.png')
# imsave(self.mask_raw_bg, 'mask_raw_bg.png')
# imsave(self.mask_raw_trim, 'mask_raw_trim.png')
# compute abcd
a, b, c, d = compute_initial_abcd(self.mask_raw_bg < 127)
a, b, c, d = solve_abcd(self.mask_raw_bg, a, b, c, d, outpaint=is_outpaint)
a, b, c, d = compute_initial_abcd(mask > 0)
a, b, c, d = solve_abcd(mask, a, b, c, d, outpaint=is_outpaint)
# interested area
self.interested_area = (a, b, c, d)
self.mask_interested_soft = self.mask_raw_soft[a:b, c:d]
self.mask_interested_fg = self.mask_raw_fg[a:b, c:d]
self.mask_interested_bg = self.mask_raw_bg[a:b, c:d]
self.mask_interested_trim = self.mask_raw_trim[a:b, c:d]
self.image_interested = self.image_raw[a:b, c:d]
self.interested_mask = mask[a:b, c:d]
self.interested_image = image[a:b, c:d]
# resize to make images ready for diffusion
H, W, C = self.image_interested.shape
k = (1024.0 ** 2.0 / float(H * W)) ** 0.5
H, W, C = self.interested_image.shape
k = ((1024.0 ** 2.0) / float(H * W)) ** 0.5
H = int(np.ceil(float(H) * k / 16.0)) * 16
W = int(np.ceil(float(W) * k / 16.0)) * 16
self.image_ready = resample_image(self.image_interested, W, H)
self.mask_ready = resample_image(self.mask_interested_soft, W, H)
self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
self.interested_image = resample_image(self.interested_image, W, H)
self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
# soft pixels
self.mask = morphological_open(mask)
self.image = image
# ending
self.latent = None
self.latent_after_swap = None
self.swapped = False
self.latent_mask = None
self.inpaint_head_feature = None
return
def load_inpaint_guidance(self, latent, mask, model_path):
def load_latent(self,
latent_fill,
latent_inpaint,
latent_mask,
latent_swap=None,
inpaint_head_model_path=None):
global inpaint_head
assert inpaint_head_model_path is not None
self.latent = latent_fill
self.latent_mask = latent_mask
self.latent_after_swap = latent_swap
if inpaint_head is None:
inpaint_head = InpaintHead()
sd = torch.load(model_path, map_location='cpu')
sd = torch.load(inpaint_head_model_path, map_location='cpu')
inpaint_head.load_state_dict(sd)
process_latent_in = pipeline.xl_base_patched.unet.model.process_latent_in
latent = process_latent_in(latent)
B, C, H, W = latent.shape
mask = torch.nn.functional.interpolate(mask, size=(H, W), mode="bilinear")
mask = mask.round()
feed = torch.cat([mask, latent], dim=1)
feed = torch.cat([
latent_mask,
pipeline.xl_base_patched.unet.model.process_latent_in(latent_inpaint)
], dim=1)
inpaint_head.to(device=feed.device, dtype=feed.dtype)
self.inpaint_head_feature = inpaint_head(feed)
return
def load_latent(self, latent, mask, latent_after_swap=None):
self.latent = latent
self.latent_mask = mask
self.latent_after_swap = latent_after_swap
def swap(self):
if self.latent_after_swap is not None:
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
if self.swapped:
return
if self.latent is None:
return
if self.latent_after_swap is None:
return
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
self.swapped = True
return
def unswap(self):
if not self.swapped:
return
if self.latent is None:
return
if self.latent_after_swap is None:
return
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
self.swapped = False
return
def color_correction(self, img):
fg = img.astype(np.float32)
bg = self.image_raw.copy().astype(np.float32)
w = self.mask_raw_soft[:, :, None].astype(np.float32) / 255.0
bg = self.image.copy().astype(np.float32)
w = self.mask[:, :, None].astype(np.float32) / 255.0
y = fg * w + bg * (1 - w)
return y.clip(0, 255).astype(np.uint8)
def post_process(self, img):
a, b, c, d = self.interested_area
content = resample_image(img, d - c, b - a)
result = self.image_raw.copy()
result = self.image.copy()
result[a:b, c:d] = content
result = self.color_correction(result)
return result
def visualize_mask_processing(self):
result = self.image_raw // 4
a, b, c, d = self.interested_area
result[a:b, c:d] += 64
result[self.mask_raw_trim > 127] += 64
result[self.mask_raw_fg > 127] += 128
return [result, self.mask_raw_soft, self.image_ready, self.mask_ready]
return [self.interested_fill, self.interested_mask, self.image, self.mask]

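In the reworked class, the original image and soft mask are kept as self.image / self.mask for compositing (color_correction, post_process), the cropped interested_* arrays feed the encoders, and load_latent() now builds the inpaint head feature itself by concatenating the latent mask with the processed inpaint latent. A toy shape check of that feed, using random tensors in place of the encode_vae_inpaint() outputs (process_latent_in rescales values but not shapes, so it is skipped here):

    import torch

    B, H, W = 1, 128, 128                            # latent grid of a 1024x1024 crop
    latent_inpaint = torch.randn(B, 4, H, W)         # stands in for encode_vae_inpaint()[0]
    latent_mask = torch.ones(B, 1, H, W)             # stands in for encode_vae_inpaint()[1]

    feed = torch.cat([latent_mask, latent_inpaint], dim=1)
    print(feed.shape)                                # torch.Size([1, 5, 128, 128]), the 5-channel InpaintHead input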

@@ -1,3 +1,7 @@
# 2.1.703
* Fixed many previous problems related to inpaint.
# 2.1.702
* Corrected reading empty negative prompt from config (it shouldn't turn into None).