82 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			82 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import torch.nn as nn
 | |
| import numpy as np
 | |
| from functools import partial
 | |
| 
 | |
| from .util import extract_into_tensor, make_beta_schedule
 | |
| from fcbh.ldm.util import default
 | |
| 
 | |
| 
 | |
| class AbstractLowScaleModel(nn.Module):
 | |
|     # for concatenating a downsampled image to the latent representation
 | |
|     def __init__(self, noise_schedule_config=None):
 | |
|         super(AbstractLowScaleModel, self).__init__()
 | |
|         if noise_schedule_config is not None:
 | |
|             self.register_schedule(**noise_schedule_config)
 | |
| 
 | |
|     def register_schedule(self, beta_schedule="linear", timesteps=1000,
 | |
|                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 | |
|         betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
 | |
|                                    cosine_s=cosine_s)
 | |
|         alphas = 1. - betas
 | |
|         alphas_cumprod = np.cumprod(alphas, axis=0)
 | |
|         alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
 | |
| 
 | |
|         timesteps, = betas.shape
 | |
|         self.num_timesteps = int(timesteps)
 | |
|         self.linear_start = linear_start
 | |
|         self.linear_end = linear_end
 | |
|         assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
 | |
| 
 | |
|         to_torch = partial(torch.tensor, dtype=torch.float32)
 | |
| 
 | |
|         self.register_buffer('betas', to_torch(betas))
 | |
|         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
 | |
|         self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
 | |
| 
 | |
|         # calculations for diffusion q(x_t | x_{t-1}) and others
 | |
|         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
 | |
|         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
 | |
|         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
 | |
|         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
 | |
|         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
 | |
| 
 | |
|     def q_sample(self, x_start, t, noise=None):
 | |
|         noise = default(noise, lambda: torch.randn_like(x_start))
 | |
|         return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
 | |
|                 extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
 | |
| 
 | |
|     def forward(self, x):
 | |
|         return x, None
 | |
| 
 | |
|     def decode(self, x):
 | |
|         return x
 | |
| 
 | |
| 
 | |
| class SimpleImageConcat(AbstractLowScaleModel):
 | |
|     # no noise level conditioning
 | |
|     def __init__(self):
 | |
|         super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
 | |
|         self.max_noise_level = 0
 | |
| 
 | |
|     def forward(self, x):
 | |
|         # fix to constant noise level
 | |
|         return x, torch.zeros(x.shape[0], device=x.device).long()
 | |
| 
 | |
| 
 | |
| class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
 | |
|     def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
 | |
|         super().__init__(noise_schedule_config=noise_schedule_config)
 | |
|         self.max_noise_level = max_noise_level
 | |
| 
 | |
|     def forward(self, x, noise_level=None):
 | |
|         if noise_level is None:
 | |
|             noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
 | |
|         else:
 | |
|             assert isinstance(noise_level, torch.Tensor)
 | |
|         z = self.q_sample(x, noise_level)
 | |
|         return z, noise_level
 | |
| 
 | |
| 
 | |
| 
 |