Spaces:
Sleeping
Sleeping
| # | |
| # For licensing see accompanying LICENSE file. | |
| # Copyright (C) 2025 Apple Inc. All Rights Reserved. | |
| # | |
| import functools | |
| import math | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| # Conv2D with same padding | |
class Conv2dSame(nn.Conv2d):
    """2D convolution emulating TensorFlow-style ``padding='same'``.

    Pads the input (asymmetrically when the total is odd) so the output
    spatial size is ``ceil(input / stride)`` for any kernel/stride/dilation.
    """

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding needed along one dimension of size `i`."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        height, width = x.shape[-2], x.shape[-1]
        total_h = self.calc_same_pad(i=height, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        total_w = self.calc_same_pad(i=width, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
        if max(total_h, total_w) > 0:
            # F.pad order is (W_left, W_right, H_top, H_bottom); extra pixel goes right/bottom.
            left, top = total_w // 2, total_h // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return super().forward(x)
class BlurBlock(torch.nn.Module):
    """Anti-aliased 2x spatial downsampling (BlurPool-style).

    Builds a normalized 2D blur filter as the outer product of a 1D tap
    vector and applies it depthwise with stride 2.

    Bug fix: ``forward`` previously hardcoded ``k=4`` when computing the
    'same' padding, which produced wrong padding (and for 5-tap kernels a
    wrong output size) whenever a kernel other than the default 4-tap one
    was supplied. The actual kernel size is now used.
    """

    def __init__(self,
                 kernel: Tuple[int, ...] = (1, 3, 3, 1)
                 ):
        super().__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        kernel = kernel[None, :] * kernel[:, None]  # outer product -> 2D filter
        kernel /= kernel.sum()  # normalize so the blur preserves mean intensity
        kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, kh, kw)
        self.register_buffer("kernel", kernel)

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Total padding along one dimension so the output size is ceil(i / s)."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ic, ih, iw = x.size()[-3:]
        # Use the registered kernel's real size instead of a hardcoded 4.
        kh, kw = self.kernel.shape[-2:]
        pad_h = self.calc_same_pad(i=ih, k=kh, s=2)
        pad_w = self.calc_same_pad(i=iw, k=kw, s=2)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        # Depthwise conv: the same blur filter applied to every channel.
        weight = self.kernel.expand(ic, -1, -1, -1)
        out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
        return out
class SinusoidalTimeEmbedding(torch.nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Maps a 1D batch of timesteps to `(batch, embedding_dim)` features:
    the first half are sines and the second half cosines over frequencies
    that decay geometrically from 1 down to 1/10000.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        assert embedding_dim % 2 == 0, "embedding_dim must be even"

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        half = self.embedding_dim // 2
        # Log-spaced frequency ladder shared by the sin and cos halves.
        exponent = -math.log(10000) / (half - 1)
        freqs = torch.exp(exponent * torch.arange(half, device=timesteps.device))
        angles = timesteps[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class ModulatedConv2dSame(Conv2dSame):
    """'Same'-padded 2D conv with optional FiLM modulation from a time embedding."""

    def __init__(self, in_channels, out_channels, kernel_size, cond_channels=None):
        super().__init__(in_channels, out_channels, kernel_size)
        if cond_channels is not None:
            # Projects the conditioning vector to per-channel (scale, bias).
            # Zero init makes the modulation start as the identity (scale+1 = 1, bias = 0).
            self.film_proj = torch.nn.Linear(cond_channels, 2 * out_channels)
            torch.nn.init.zeros_(self.film_proj.weight)
            torch.nn.init.zeros_(self.film_proj.bias)

    def forward(self, x, temb=None):
        out = super().forward(x)
        if temb is None:
            return out
        film = self.film_proj(temb)[:, :, None, None]  # (B, 2*C, 1, 1)
        scale, bias = film.chunk(2, dim=1)
        return out * (scale + 1) + bias
class NLayerDiscriminator(torch.nn.Module):
    """Multi-stage 2D convolutional discriminator with optional FiLM time conditioning."""

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4,
        with_condition: bool = False,
    ):
        """ Initializes the NLayerDiscriminator.
        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
            blur_resample -> bool: Whether to use blur resampling.
            blur_kernel_size -> int: The blur kernel size.
            with_condition -> bool: Whether to FiLM-condition the conv blocks
                on a timestep embedding passed to forward().
        """
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"
        # Channel multiplier per stage: (1, 1, 2, 4, ...) — width doubles each stage.
        in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
        init_kernel_size = 5
        activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
        self.with_condition = with_condition
        if with_condition:
            # 128-dim sinusoidal timestep embedding projected up to the FiLM width.
            cond_channels = 768
            self.time_emb = SinusoidalTimeEmbedding(128)
            self.time_proj = torch.nn.Sequential(
                torch.nn.Linear(128, cond_channels),
                torch.nn.SiLU(),
                torch.nn.Linear(cond_channels, cond_channels),
            )
        else:
            cond_channels = None
        # Stem: 'same'-padded conv from image channels to the base width.
        self.block_in = torch.nn.Sequential(
            Conv2dSame(
                num_channels,
                hidden_channels,
                kernel_size=init_kernel_size
            ),
            activation(),
        )
        # 1D blur taps for anti-aliased downsampling, keyed by kernel size.
        BLUR_KERNEL_MAP = {
            3: (1,2,1),
            4: (1,3,3,1),
            5: (1,4,6,4,1),
        }
        discriminator_blocks = []
        for i_level in range(num_stages):
            in_channels = hidden_channels * in_channel_mult[i_level]
            out_channels = hidden_channels * in_channel_mult[i_level + 1]
            conv_block = ModulatedConv2dSame(
                in_channels,
                out_channels,
                kernel_size=3,
                cond_channels=cond_channels
            )
            discriminator_blocks.append(conv_block)
            # 2x spatial downsampling, either plain average pooling or blur-pool.
            down_block = torch.nn.Sequential(
                torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]),
                torch.nn.GroupNorm(32, out_channels),
                activation(),
            )
            discriminator_blocks.append(down_block)
        # NOTE: blocks strictly alternate [conv, down, conv, down, ...];
        # forward() relies on this even/odd ordering to feed the time
        # embedding only to the conv blocks.
        self.blocks = torch.nn.ModuleList(discriminator_blocks)
        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))
        # Logit head; out_channels here is the width of the last stage.
        self.to_logits = torch.nn.Sequential(
            Conv2dSame(out_channels, out_channels, 1),
            activation(),
            Conv2dSame(out_channels, 1, kernel_size=5)
        )

    def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
        """ Forward pass.
        Args:
            x -> torch.Tensor: The input tensor (4D image batch, or 5D
                [b t c h w] video batch which is scored frame-wise).
            condition -> torch.Tensor: Optional per-sample timestep values used
                for FiLM conditioning when `with_condition` is True.
        Returns:
            output -> torch.Tensor: The output tensor.
        """
        # Fold the temporal dimension into the batch so video inputs are
        # treated as independent frames.
        if x.dim() == 5:
            x = rearrange(x, 'b t c h w -> (b t) c h w')
        hidden_states = self.block_in(x)
        if condition is not None and self.with_condition:
            # condition is scaled by 1000 before embedding — presumably maps
            # normalized [0, 1] times onto diffusion-style timesteps; confirm
            # against the caller's convention.
            temb = self.time_proj(self.time_emb(condition * 1000.0))
        else:
            temb = None
        for i, block in enumerate(self.blocks):
            if i % 2 == 0:
                hidden_states = block(hidden_states, temb)  # conv_block (FiLM-conditioned)
            else:
                hidden_states = block(hidden_states)  # down_block (no conditioning)
        hidden_states = self.pool(hidden_states)
        return self.to_logits(hidden_states)
| # 3D discriminator | |
class Conv3dSame(nn.Conv3d):
    """3D convolution emulating TensorFlow-style ``padding='same'`` over (T, H, W)."""

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding needed along one dimension of size `i`."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frames, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
        pads = [
            self.calc_same_pad(i=extent, k=self.kernel_size[axis], s=self.stride[axis], d=self.dilation[axis])
            for axis, extent in enumerate((frames, height, width))
        ]
        pt, ph, pw = pads
        if any(p > 0 for p in pads):
            # F.pad pairs run from the last dim inward:
            # (W_left, W_right, H_top, H_bottom, T_front, T_back).
            x = F.pad(
                x,
                [pw // 2, pw - pw // 2,
                 ph // 2, ph - ph // 2,
                 pt // 2, pt - pt // 2],
            )
        return super().forward(x)
class ModulatedConv3dSame(Conv3dSame):
    """'Same'-padded 3D conv with optional FiLM modulation from a time embedding."""

    def __init__(self, in_channels, out_channels, kernel_size, cond_channels=None):
        super().__init__(in_channels, out_channels, kernel_size)
        if cond_channels is not None:
            # Projects the conditioning vector to per-channel (scale, bias).
            # Zero init makes the modulation start as the identity.
            self.film_proj = torch.nn.Linear(cond_channels, 2 * out_channels)
            torch.nn.init.zeros_(self.film_proj.weight)
            torch.nn.init.zeros_(self.film_proj.bias)

    def forward(self, x, temb=None):
        out = super().forward(x)  # (B, C, T, H, W)
        if temb is None:
            return out
        film = self.film_proj(temb)[:, :, None, None, None]  # (B, 2*C, 1, 1, 1)
        scale, bias = film.chunk(2, dim=1)
        return out * (scale + 1) + bias
class BlurBlock3D(nn.Module):
    """Anti-aliased downsampling for video tensors shaped (B, C, T, H, W).

    The registered filter has temporal depth 1, so the blur acts on the
    spatial dimensions only; a temporal stride > 1 simply subsamples frames.
    """

    def __init__(self, kernel=(1, 3, 3, 1), stride=(1, 2, 2)):
        super().__init__()
        self.stride = stride
        taps = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        blur2d = taps[None, :] * taps[:, None]  # outer product -> 2D filter
        blur2d /= blur2d.sum()  # normalize so the blur preserves mean intensity
        # Shape (1, 1, 1, kh, kw): depth-1 kernel, spatial-only blur.
        self.register_buffer("kernel", blur2d[None, None, None, :, :])

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Total padding along one dimension so the output size is ceil(i / s)."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels = x.shape[1]
        t, h, w = x.shape[-3:]
        kd, kh, kw = self.kernel.shape[-3:]
        sd, sh, sw = self.stride
        pad_h = self.calc_same_pad(h, kh, sh)
        pad_w = self.calc_same_pad(w, kw, sw)
        # No temporal padding unless time is actually strided.
        pad_d = self.calc_same_pad(t, kd, sd) if sd != 1 else 0
        if max(pad_d, pad_h, pad_w) > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                          pad_h // 2, pad_h - pad_h // 2,
                          pad_d // 2, pad_d - pad_d // 2])
        # Depthwise 3D conv: the same spatial blur applied to every channel.
        weight = self.kernel.expand(channels, 1, -1, -1, -1)
        return F.conv3d(x, weight, stride=self.stride, groups=channels)
class NLayer3DDiscriminator(torch.nn.Module):
    """Multi-stage 3D convolutional video discriminator with optional FiLM time conditioning."""

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4,
        with_condition: bool = False,
    ):
        """ Initializes the NLayer3DDiscriminator.
        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
            blur_resample -> bool: Whether to use blur resampling.
            blur_kernel_size -> int: The blur kernel size.
            with_condition -> bool: Whether to FiLM-condition the conv blocks
                on a timestep embedding passed to forward().
        """
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"
        # Channel multiplier per stage: (1, 1, 2, 4, ...) — width doubles each stage.
        in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
        init_kernel_size = 5
        activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
        self.with_condition = with_condition
        if with_condition:
            # 128-dim sinusoidal timestep embedding projected up to the FiLM width.
            cond_channels = 768
            self.time_emb = SinusoidalTimeEmbedding(128)
            self.time_proj = torch.nn.Sequential(
                torch.nn.Linear(128, cond_channels),
                torch.nn.SiLU(),
                torch.nn.Linear(cond_channels, cond_channels),
            )
        else:
            cond_channels = None
        # Stem: 'same'-padded 3D conv from input channels to the base width.
        self.block_in = torch.nn.Sequential(
            Conv3dSame(
                num_channels,
                hidden_channels,
                kernel_size=init_kernel_size
            ),
            activation(),
        )
        # 1D blur taps for anti-aliased downsampling, keyed by kernel size.
        BLUR_KERNEL_MAP = {
            3: (1,2,1),
            4: (1,3,3,1),
            5: (1,4,6,4,1),
        }
        # The first ~third of the stages keep the temporal resolution; the
        # remaining stages downsample time as well as space.
        num_downsample_temp_stage = int(num_stages * 1/3)
        downsample_temp = [False] * num_downsample_temp_stage + [True] * (num_stages - num_downsample_temp_stage)
        discriminator_blocks = []
        for i_level in range(num_stages):
            in_channels = hidden_channels * in_channel_mult[i_level]
            out_channels = hidden_channels * in_channel_mult[i_level + 1]
            conv_block = ModulatedConv3dSame(
                in_channels,
                out_channels,
                kernel_size=3,
                cond_channels=cond_channels
            )
            discriminator_blocks.append(conv_block)
            # Downsampling block: stride (2,2,2) when this stage also halves
            # time, otherwise spatial-only (1,2,2); blur-pool or avg-pool.
            down_block = torch.nn.Sequential(
                torch.nn.AvgPool3d(kernel_size=2, stride=(2, 2, 2) if downsample_temp[i_level] else (1, 2, 2)) if not blur_resample else BlurBlock3D(BLUR_KERNEL_MAP[blur_kernel_size], stride=(2, 2, 2) if downsample_temp[i_level] else (1, 2, 2)),
                torch.nn.GroupNorm(32, out_channels),
                activation(),
            )
            discriminator_blocks.append(down_block)
        # NOTE: blocks strictly alternate [conv, down, conv, down, ...];
        # forward() relies on this even/odd ordering to feed the time
        # embedding only to the conv blocks.
        self.blocks = torch.nn.ModuleList(discriminator_blocks)
        self.pool = torch.nn.AdaptiveMaxPool3d((2, 16, 16))
        # Logit head; out_channels here is the width of the last stage.
        self.to_logits = torch.nn.Sequential(
            Conv3dSame(out_channels, out_channels, 1),
            activation(),
            Conv3dSame(out_channels, 1, kernel_size=5)
        )

    def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
        """ Forward pass.
        Args:
            x -> torch.Tensor: The input tensor of shape [b t c h w].
            condition -> torch.Tensor: Optional per-sample timestep values used
                for FiLM conditioning when `with_condition` is True.
        Returns:
            output -> torch.Tensor: The output tensor.
        """
        # Move time next to batch -> channels-first 3D layout (b, c, t, h, w).
        x = rearrange(x, 'b t c h w -> b c t h w')
        hidden_states = self.block_in(x)
        if condition is not None and self.with_condition:
            # condition is scaled by 1000 before embedding — presumably maps
            # normalized [0, 1] times onto diffusion-style timesteps; confirm
            # against the caller's convention.
            temb = self.time_proj(self.time_emb(condition * 1000.0))
        else:
            temb = None
        for i, block in enumerate(self.blocks):
            if i % 2 == 0:
                hidden_states = block(hidden_states, temb)  # conv_block (FiLM-conditioned)
            else:
                hidden_states = block(hidden_states)  # down_block (no conditioning)
        hidden_states = self.pool(hidden_states)
        return self.to_logits(hidden_states)