i
This commit is contained in:
parent
33133615fd
commit
5bf4a08b5e
83
entry.py
83
entry.py
@ -1,4 +1,6 @@
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
@ -7,6 +9,68 @@ from sgm.util import instantiate_from_config
|
||||
|
||||
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
|
||||
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list(set([x.input_key for x in conditioner.embedders]))
|
||||
|
||||
|
||||
def get_batch(keys, value_dict, N, device="cuda"):
|
||||
# Hardcoded demo setups; might undergo some changes in the future
|
||||
|
||||
batch = {}
|
||||
batch_uc = {}
|
||||
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor(
|
||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||
)
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "aesthetic_score":
|
||||
batch["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
batch["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
for key in batch.keys():
|
||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||
batch_uc[key] = torch.clone(batch[key])
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
sampler = EulerAncestralSampler(
|
||||
num_steps=40,
|
||||
discretization_config={
|
||||
@ -31,4 +95,23 @@ model.eval()
|
||||
sd = safetensors.torch.load_file('./sd_xl_base_1.0.safetensors')
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
model.conditioner.cuda()
|
||||
|
||||
value_dict = {
|
||||
"prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024,
|
||||
"crop_coords_top": 0, "crop_coords_left": 0, "target_height": 1024, "target_width": 1024, "aesthetic_score": 7.5,
|
||||
"negative_aesthetic_score": 2.0,
|
||||
}
|
||||
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
1,
|
||||
)
|
||||
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc)
|
||||
model.conditioner.cpu()
|
||||
|
||||
a = 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user