#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
from typing import Mapping, Text, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from einops import rearrange
from torch.cuda.amp import autocast

from .lpips import LPIPS
from .discriminator import NLayerDiscriminator, NLayer3DDiscriminator

_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Hinge loss for the discriminator.

    This function is borrowed from
    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
    """
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
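

# Illustrative sanity check (added for this write-up, not part of the original module):
# hinge_d_loss vanishes once real logits exceed +1 and fake logits fall below -1, and
# equals 1.0 when both sets of logits sit at zero.
def _hinge_d_loss_example() -> None:
    separated = hinge_d_loss(torch.tensor([1.5, 2.0]), torch.tensor([-1.2, -3.0]))
    at_zero = hinge_d_loss(torch.zeros(4), torch.zeros(4))
    assert torch.isclose(separated, torch.tensor(0.0))
    assert torch.isclose(at_zero, torch.tensor(1.0))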


def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor,
) -> torch.Tensor:
    """Computes the LeCam loss for the given average real and fake logits.

    Args:
        logits_real_mean -> torch.Tensor: The average real logits.
        logits_fake_mean -> torch.Tensor: The average fake logits.
        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.

    Returns:
        lecam_loss -> torch.Tensor: The LeCam loss.
    """
    lecam_loss = torch.mean(torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2))
    lecam_loss += torch.mean(torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2))
    return lecam_loss
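

# Illustrative usage sketch (added for this write-up): how the LeCam penalty pairs with
# an EMA of the discriminator logits, mirroring the update performed later in
# ReconstructionLoss_Single_Stage._forward_discriminator. The default decay and weight
# match the values set in that class; the random logits are placeholders.
def _lecam_usage_example(decay: float = 0.999, weight: float = 0.001):
    ema_real, ema_fake = torch.zeros(1), torch.zeros(1)
    logits_real, logits_fake = torch.randn(8), torch.randn(8)
    lecam = compute_lecam_loss(
        logits_real.mean(), logits_fake.mean(), ema_real, ema_fake) * weight
    # The EMA update uses detached statistics so the penalty does not backprop into them.
    ema_real = ema_real * decay + logits_real.mean().detach() * (1 - decay)
    ema_fake = ema_fake * decay + logits_fake.mean().detach() * (1 - decay)
    return lecam, ema_real, ema_fake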


class PerceptualLoss(torch.nn.Module):
    def __init__(self, dist, model_name: str = "convnext_s"):
        """Initializes the PerceptualLoss class.

        Args:
            dist: Passed through to the LPIPS module.
            model_name: A string, the name of the perceptual loss model to use.

        Raises:
            ValueError: If the model_name does not contain "lpips" or "convnext_s".
        """
        super().__init__()
        if ("lpips" not in model_name) and ("convnext_s" not in model_name):
            raise ValueError(f"Unsupported Perceptual Loss model name {model_name}")

        self.dist = dist
        self.lpips = None
        self.convnext = None
        self.loss_weight_lpips = None
        self.loss_weight_convnext = None

        # Parse the model name. We support names formatted as
        # "lpips-convnext_s-{float_number}-{float_number}", where each
        # {float_number} is the loss weight for the corresponding component.
        # E.g., "lpips-convnext_s-1.0-2.0" computes the perceptual loss with both
        # lpips and convnext_s, and averages the final loss as
        # (1.0 * loss(lpips) + 2.0 * loss(convnext_s)) / (1.0 + 2.0).
        if "lpips" in model_name:
            self.lpips = LPIPS(dist).eval()
        if "convnext_s" in model_name:
            self.convnext = models.convnext_small(
                weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval()
        if "lpips" in model_name and "convnext_s" in model_name:
            loss_config = model_name.split('-')[-2:]
            self.loss_weight_lpips, self.loss_weight_convnext = float(loss_config[0]), float(loss_config[1])
            print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")

        self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None])
        self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None])

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Computes the perceptual loss.

        Args:
            input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1].
            target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1].

        Returns:
            A scalar tensor, the perceptual loss.
        """
        if input.dim() == 5:
            # If the input is 5D, we assume it is a batch of videos and fold the
            # temporal dimension into the batch, so the loss is averaged over time.
            input = rearrange(input, "b t c h w -> (b t) c h w")
            target = rearrange(target, "b t c h w -> (b t) c h w")

        # Always in eval mode.
        self.eval()

        loss = 0.
        num_losses = 0.
        lpips_loss = 0.
        convnext_loss = 0.

        # Computes LPIPS loss, if available.
        if self.lpips is not None:
            lpips_loss = self.lpips(input, target)
            if self.loss_weight_lpips is None:
                loss += lpips_loss
                num_losses += 1
            else:
                num_losses += self.loss_weight_lpips
                loss += self.loss_weight_lpips * lpips_loss

        # Computes ConvNeXt-S loss, if available.
        if self.convnext is not None:
            input = torch.nn.functional.interpolate(
                input, size=224, mode="bilinear", align_corners=False, antialias=True)
            target = torch.nn.functional.interpolate(
                target, size=224, mode="bilinear", align_corners=False, antialias=True)
            pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std)
            pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std)
            convnext_loss = torch.nn.functional.mse_loss(
                pred_input, pred_target, reduction="mean")
            if self.loss_weight_convnext is None:
                num_losses += 1
                loss += convnext_loss
            else:
                num_losses += self.loss_weight_convnext
                loss += self.loss_weight_convnext * convnext_loss

        # Weighted average over the enabled losses.
        loss = loss / num_losses
        return loss
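

# Illustrative usage sketch (added for this write-up, not part of the original module).
# Only the ConvNeXt branch is exercised so that the `dist` handle, which is only
# consumed by the LPIPS branch, can be left as a placeholder None here; instantiating
# the module downloads the ImageNet-pretrained ConvNeXt-S weights on first use.
def _perceptual_loss_example() -> torch.Tensor:
    loss_fn = PerceptualLoss(dist=None, model_name="convnext_s")
    recons = torch.rand(2, 3, 256, 256)   # reconstructions in [0, 1]
    targets = torch.rand(2, 3, 256, 256)  # targets in [0, 1]
    with torch.no_grad():
        return loss_fn(recons, targets)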


class WaveletLoss3D(torch.nn.Module):
    """L1 loss between 3D Haar wavelet coefficients of the inputs and targets."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        # Lazy import so torch_dwt is only required when the wavelet loss is used.
        from torch_dwt.functional import dwt3

        inputs, targets = inputs.float(), targets.float()
        l1_loss = torch.abs(
            dwt3(inputs.contiguous(), "haar") - dwt3(targets.contiguous(), "haar")
        )
        # Average over the number of wavelet filters, reducing that dimension.
        l1_loss = torch.mean(l1_loss, dim=1)
        # Average over all of the filter banks, keeping dimensions.
        l1_loss = torch.mean(l1_loss, dim=-1, keepdim=True)
        l1_loss = torch.mean(l1_loss, dim=-2, keepdim=True)
        l1_loss = torch.mean(l1_loss, dim=-3, keepdim=True)
        return l1_loss
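

# Illustrative shape sketch (added for this write-up, assuming torch_dwt is installed).
# WaveletLoss3D is fed channel-first videos, i.e. (B, C, T, H, W): the caller below in
# ReconstructionLoss_Single_Stage permutes its (B, T, C, H, W) tensors into this layout
# and then reduces the result to a scalar with .mean().
def _wavelet_loss_example() -> torch.Tensor:
    wavelet_loss = WaveletLoss3D()
    videos = torch.rand(1, 3, 8, 32, 32)        # (B, C, T, H, W) in [0, 1]
    recons = torch.rand(1, 3, 8, 32, 32)
    return wavelet_loss(videos, recons).mean()  # scalar, as used in the generator step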


class ReconstructionLoss_Single_Stage(torch.nn.Module):
    def __init__(self, dist, args):
        """Initializes the losses module.

        Args:
            dist: Passed through to the perceptual (LPIPS) loss.
            args: The configuration for the model and everything else.
        """
        super().__init__()
        self.dist = dist
        self.with_condition = False
        self.quantize_mode = 'vae'
        self.discriminator = (
            NLayer3DDiscriminator(with_condition=False).eval()
            if args.use_3d_disc
            else NLayerDiscriminator(with_condition=False).eval()
        )
        self.reconstruction_loss = "l2"
        self.reconstruction_weight = 1.0
        self.quantizer_weight = 1.0
        self.perceptual_loss = PerceptualLoss(dist, "lpips-convnext_s-1.0-0.1").eval()
        self.perceptual_weight = 1.1
        self.discriminator_iter_start = 0
        self.discriminator_factor = 1.0
        self.discriminator_weight = 0.1
        self.lecam_regularization_weight = 0.001
        self.lecam_ema_decay = 0.999
        self.kl_weight = 1e-6
        self.wavelet_loss_weight = 0.5
        self.wavelet_loss = WaveletLoss3D()
        self.logvar = nn.Parameter(torch.ones(size=()) * 0.0, requires_grad=False)

        if self.lecam_regularization_weight > 0.0:
            self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
            self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))

    def forward(self,
                inputs: torch.Tensor,
                reconstructions: torch.Tensor,
                extra_result_dict: Mapping[Text, torch.Tensor],
                global_step: int,
                mode: str = "generator",
                ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        # Both inputs and reconstructions are in range [0, 1].
        inputs = inputs.float()
        reconstructions = reconstructions.float()

        if mode == "generator":
            return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
        elif mode == "discriminator":
            return self._forward_discriminator(inputs, reconstructions, extra_result_dict, global_step)
        else:
            raise ValueError(f"Unsupported mode {mode}")

    def should_discriminator_be_trained(self, global_step: int):
        return global_step >= self.discriminator_iter_start

    def _forward_discriminator(self,
                               inputs: torch.Tensor,
                               reconstructions: torch.Tensor,
                               extra_result_dict: Mapping[Text, torch.Tensor],
                               global_step: int,
                               ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Discriminator training step."""
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        loss_dict = {}

        # Turn the discriminator gradients on.
        for param in self.discriminator.parameters():
            param.requires_grad = True

        condition = extra_result_dict.get("condition", None) if self.with_condition else None
        real_images = inputs.detach().requires_grad_(True)
        logits_real = self.discriminator(real_images, condition)
        logits_fake = self.discriminator(reconstructions.detach(), condition)
        discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)

        # Optional LeCam regularization.
        lecam_loss = torch.zeros((), device=inputs.device)
        if self.lecam_regularization_weight > 0.0:
            lecam_loss = compute_lecam_loss(
                torch.mean(logits_real),
                torch.mean(logits_fake),
                self.ema_real_logits_mean,
                self.ema_fake_logits_mean
            ) * self.lecam_regularization_weight
            self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
            self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)

        discriminator_loss += lecam_loss

        loss_dict = dict(
            discriminator_loss=discriminator_loss.detach(),
            logits_real=logits_real.detach().mean(),
            logits_fake=logits_fake.detach().mean(),
            lecam_loss=lecam_loss.detach(),
        )
        return discriminator_loss, loss_dict

    def _forward_generator(self,
                           inputs: torch.Tensor,
                           reconstructions: torch.Tensor,
                           extra_result_dict: Mapping[Text, torch.Tensor],
                           global_step: int
                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step."""
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()

        # Compute the reconstruction loss.
        if self.reconstruction_loss == "l1":
            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
        elif self.reconstruction_loss == "l2":
            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
        else:
            raise ValueError(f"Unsupported reconstruction_loss {self.reconstruction_loss}")
        reconstruction_loss *= self.reconstruction_weight

        # Compute the wavelet loss (only defined for 5D video inputs).
        if inputs.dim() == 5:
            wavelet_loss = self.wavelet_loss(
                inputs.permute(0, 2, 1, 3, 4), reconstructions.permute(0, 2, 1, 3, 4)).mean()
        else:
            wavelet_loss = torch.zeros((), device=inputs.device)

        # Compute the perceptual loss.
        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # Compute the GAN loss for the generator.
        generator_loss = torch.zeros((), device=inputs.device)
        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Disable discriminator gradients.
            for param in self.discriminator.parameters():
                param.requires_grad = False
            logits_fake = self.discriminator(reconstructions)
            generator_loss = -torch.mean(logits_fake)
        d_weight *= self.discriminator_weight

        assert self.quantize_mode == "vae", "Only vae mode is supported for now"
        # Scale the reconstruction loss by the (fixed) log-variance.
        reconstruction_loss = reconstruction_loss / torch.exp(self.logvar)

        total_loss = (
            reconstruction_loss
            + self.perceptual_weight * perceptual_loss
            + d_weight * discriminator_factor * generator_loss
            + self.wavelet_loss_weight * wavelet_loss
        )
        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
            weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
            discriminator_factor=torch.tensor(discriminator_factor),
            d_weight=d_weight,
            gan_loss=generator_loss.detach(),
            wavelet_loss=(self.wavelet_loss_weight * wavelet_loss).detach(),
        )
        return total_loss, loss_dict
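

# Illustrative end-to-end sketch (added for this write-up): one generator step and one
# discriminator step on random image batches. The SimpleNamespace config and the
# `dist=None` placeholder stand in for whatever the real training script passes; the
# LPIPS branch may expect a proper distributed handle there.
def _reconstruction_loss_example():
    from types import SimpleNamespace

    args = SimpleNamespace(use_3d_disc=False)  # hypothetical minimal config
    loss_module = ReconstructionLoss_Single_Stage(dist=None, args=args)

    images = torch.rand(2, 3, 256, 256)        # inputs in [0, 1]
    recons = torch.rand(2, 3, 256, 256)        # reconstructions in [0, 1]

    gen_loss, gen_dict = loss_module(
        images, recons, extra_result_dict={}, global_step=0, mode="generator")
    disc_loss, disc_dict = loss_module(
        images, recons, extra_result_dict={}, global_step=0, mode="discriminator")
    return gen_dict, disc_dict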