diff --git a/entry.py b/entry.py index 7e230e5..758fef7 100644 --- a/entry.py +++ b/entry.py @@ -1,4 +1,6 @@ import os +import math +import numpy as np import torch import safetensors.torch @@ -7,6 +9,68 @@ from sgm.util import instantiate_from_config from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch(keys, value_dict, N, device="cuda"): + # Hardcoded demo setups; might undergo some changes in the future + + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) + ) + else: + batch[key] = value_dict[key] + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + sampler = EulerAncestralSampler( num_steps=40, discretization_config={ @@ -31,4 +95,23 @@ model.eval() sd = safetensors.torch.load_file('./sd_xl_base_1.0.safetensors') model.load_state_dict(sd, strict=False) +model.conditioner.cuda() + +value_dict = { + "prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024, + "crop_coords_top": 0, "crop_coords_left": 0, "target_height": 1024, "target_width": 1024, "aesthetic_score": 7.5, + "negative_aesthetic_score": 2.0, +} + +batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + 1, +) + +c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc) +model.conditioner.cpu() + a = 0