import logging
import math

import torch
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

from fcbh.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from fcbh.ldm.util import get_obj_from_str, instantiate_from_config
from fcbh.ldm.modules.ema import LitEma

# Module-level logger used throughout this file.
logpy = logging.getLogger(__name__)

class AbstractRegularizer(torch.nn.Module):
    # Minimal base class giving regularizers a shared (z) -> (z, log) contract;
    # it is referenced by the type annotation in AutoencodingEngine below.
    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        raise NotImplementedError()

    def get_trainable_parameters(self) -> Any:
        raise NotImplementedError()


class DiagonalGaussianRegularizer(AbstractRegularizer):
    def __init__(self, sample: bool = True):
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        # This regularizer has no trainable parameters of its own.
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        log = dict()
        # z is interpreted as concatenated mean/logvar halves along dim 1.
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log

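
# A minimal usage sketch; the `_example_regularizer` helper is illustrative,
# not part of any public API. Since DiagonalGaussianDistribution splits z into
# mean/logvar halves along the channel dimension, z needs an even channel count.
def _example_regularizer():
    reg = DiagonalGaussianRegularizer(sample=True)
    z = torch.randn(2, 8, 4, 4)        # 8 channels -> 4 mean + 4 logvar
    z_sampled, log = reg(z)            # z_sampled: (2, 4, 4, 4)
    return z_sampled, log["kl_loss"]   # kl_loss is a scalar tensor
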
class AbstractAutoencoder(torch.nn.Module):
    """
    This is the base class for all autoencoders, including image autoencoders,
    image autoencoders with discriminators, unCLIP models, etc. It is therefore
    fairly general, and specific features (e.g. discriminator training,
    encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        **kwargs,
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.")

    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f"{context}: Restored training weights")
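
    # Illustrative usage (a sketch, assuming a subclass with use_ema enabled):
    #
    #     with model.ema_scope("validation"):
    #         z = model.encode(x)   # runs with the EMA weights copied in
    #     # training weights are restored on exit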

    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()

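
# A minimal sketch of driving instantiate_optimizer_from_config; the helper
# name and the config values are illustrative, not fixed by this module.
def _example_optimizer(model: AbstractAutoencoder):
    cfg = {"target": "torch.optim.AdamW", "params": {"betas": (0.9, 0.999)}}
    # get_obj_from_str resolves torch.optim.AdamW, which is then called with
    # the model parameters, the learning rate, and the extra params.
    return model.instantiate_optimizer_from_config(model.parameters(), 1e-4, cfg)
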
class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        regularizer_config: Dict,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.regularization: AbstractRegularizer = instantiate_from_config(
            regularizer_config
        )

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        z = self.encoder(x)
        if unregularized:
            # Note: this branch always returns a (z, log) tuple,
            # regardless of return_reg_log.
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.decoder(z, **kwargs)
        return x

    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log
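
# An illustrative encode -> regularize -> decode round trip, assuming
# torch.nn.Identity stand-ins for the encoder/decoder configs; the
# `_example_roundtrip` helper is hypothetical, not part of any public API.
def _example_roundtrip():
    engine = AutoencodingEngine(
        encoder_config={"target": "torch.nn.Identity"},
        decoder_config={"target": "torch.nn.Identity"},
        regularizer_config={
            "target": "fcbh.ldm.models.autoencoder.DiagonalGaussianRegularizer"
        },
    )
    x = torch.randn(1, 8, 16, 16)  # even channel count: 4 mean + 4 logvar
    z, dec, reg_log = engine(x)    # z: (1, 4, 16, 16), dec: (1, 4, 16, 16)
    return z, dec, reg_log["kl_loss"]
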
class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        super().__init__(
            encoder_config={
                "target": "fcbh.ldm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "fcbh.ldm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        # With double_z, the encoder emits 2 * z_channels (mean and logvar),
        # so both sides of quant_conv scale their channels by (1 + double_z).
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def get_autoencoder_params(self) -> list:
        # The full training base class this was copied from provides a richer
        # implementation; in this excerpt, all module parameters are returned.
        return list(self.parameters())

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            # Encode in micro-batches of max_batch_size to bound peak memory.
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            # Decode in micro-batches of max_batch_size to bound peak memory.
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec
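
# A hedged helper (illustrative name, not in any public API) showing the
# slicing arithmetic the micro-batched encode/decode above performs.
def _example_microbatch_slices(n_items: int, bs: int) -> list:
    n_batches = int(math.ceil(n_items / bs))
    # e.g. _example_microbatch_slices(10, 4) -> [(0, 4), (4, 8), (8, 10)]
    return [(i * bs, min((i + 1) * bs, n_items)) for i in range(n_batches)]
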
class AutoencoderKL(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            # Accept the legacy key name and forward it as loss_config.
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "fcbh.ldm.models.autoencoder.DiagonalGaussianRegularizer"
                )
            },
            **kwargs,
        )
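
# A hedged construction sketch: an AutoencoderKL with a ddconfig in the style
# commonly used for Stable-Diffusion-like VAEs. All values are illustrative,
# not mandated by this module.
def _example_autoencoder_kl() -> AutoencoderKL:
    ddconfig = {
        "double_z": True,        # encoder outputs mean and logvar
        "z_channels": 4,
        "resolution": 256,
        "in_channels": 3,
        "out_ch": 3,
        "ch": 128,
        "ch_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "attn_resolutions": [],
        "dropout": 0.0,
    }
    return AutoencoderKL(embed_dim=4, ddconfig=ddconfig)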