i
This commit is contained in:
		
							parent
							
								
									33133615fd
								
							
						
					
					
						commit
						5bf4a08b5e
					
				
							
								
								
									
										83
									
								
								entry.py
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								entry.py
									
									
									
									
									
								
							| @ -1,4 +1,6 @@ | ||||
| import os | ||||
| import math | ||||
| import numpy as np | ||||
| import torch | ||||
| import safetensors.torch | ||||
| 
 | ||||
| @ -7,6 +9,68 @@ from sgm.util import instantiate_from_config | ||||
| 
 | ||||
| from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler | ||||
| 
 | ||||
| 
 | ||||
| def get_unique_embedder_keys_from_conditioner(conditioner): | ||||
|     return list(set([x.input_key for x in conditioner.embedders])) | ||||
| 
 | ||||
| 
 | ||||
| def get_batch(keys, value_dict, N, device="cuda"): | ||||
|     # Hardcoded demo setups; might undergo some changes in the future | ||||
| 
 | ||||
|     batch = {} | ||||
|     batch_uc = {} | ||||
| 
 | ||||
|     for key in keys: | ||||
|         if key == "txt": | ||||
|             batch["txt"] = ( | ||||
|                 np.repeat([value_dict["prompt"]], repeats=math.prod(N)) | ||||
|                 .reshape(N) | ||||
|                 .tolist() | ||||
|             ) | ||||
|             batch_uc["txt"] = ( | ||||
|                 np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) | ||||
|                 .reshape(N) | ||||
|                 .tolist() | ||||
|             ) | ||||
|         elif key == "original_size_as_tuple": | ||||
|             batch["original_size_as_tuple"] = ( | ||||
|                 torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) | ||||
|                 .to(device) | ||||
|                 .repeat(*N, 1) | ||||
|             ) | ||||
|         elif key == "crop_coords_top_left": | ||||
|             batch["crop_coords_top_left"] = ( | ||||
|                 torch.tensor( | ||||
|                     [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] | ||||
|                 ) | ||||
|                 .to(device) | ||||
|                 .repeat(*N, 1) | ||||
|             ) | ||||
|         elif key == "aesthetic_score": | ||||
|             batch["aesthetic_score"] = ( | ||||
|                 torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) | ||||
|             ) | ||||
|             batch_uc["aesthetic_score"] = ( | ||||
|                 torch.tensor([value_dict["negative_aesthetic_score"]]) | ||||
|                 .to(device) | ||||
|                 .repeat(*N, 1) | ||||
|             ) | ||||
| 
 | ||||
|         elif key == "target_size_as_tuple": | ||||
|             batch["target_size_as_tuple"] = ( | ||||
|                 torch.tensor([value_dict["target_height"], value_dict["target_width"]]) | ||||
|                 .to(device) | ||||
|                 .repeat(*N, 1) | ||||
|             ) | ||||
|         else: | ||||
|             batch[key] = value_dict[key] | ||||
| 
 | ||||
|     for key in batch.keys(): | ||||
|         if key not in batch_uc and isinstance(batch[key], torch.Tensor): | ||||
|             batch_uc[key] = torch.clone(batch[key]) | ||||
|     return batch, batch_uc | ||||
| 
 | ||||
| 
 | ||||
| sampler = EulerAncestralSampler( | ||||
|     num_steps=40, | ||||
|     discretization_config={ | ||||
| @ -31,4 +95,23 @@ model.eval() | ||||
| sd = safetensors.torch.load_file('./sd_xl_base_1.0.safetensors') | ||||
| model.load_state_dict(sd, strict=False) | ||||
| 
 | ||||
| model.conditioner.cuda() | ||||
| 
 | ||||
| value_dict = { | ||||
|     "prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024, | ||||
|     "crop_coords_top": 0, "crop_coords_left": 0, "target_height": 1024, "target_width": 1024, "aesthetic_score": 7.5, | ||||
|     "negative_aesthetic_score": 2.0, | ||||
| } | ||||
| 
 | ||||
| batch, batch_uc = get_batch( | ||||
|     get_unique_embedder_keys_from_conditioner(model.conditioner), | ||||
|     value_dict, | ||||
|     1, | ||||
| ) | ||||
| 
 | ||||
| c, uc = model.conditioner.get_unconditional_conditioning( | ||||
|     batch, | ||||
|     batch_uc=batch_uc) | ||||
| model.conditioner.cpu() | ||||
| 
 | ||||
| a = 0 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user