# starflow/misc/discriminator.py
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import functools
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# 2D convolution with TensorFlow-style "same" padding.
class Conv2dSame(nn.Conv2d):
    """Conv2d that zero-pads its input so every spatial axis keeps
    ``ceil(in / stride)`` resolution, regardless of kernel size/dilation."""

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding needed along one axis of length ``i`` for kernel ``k``,
        stride ``s`` and dilation ``d`` (never negative)."""
        span = (k - 1) * d + 1  # receptive field of a single kernel application
        return max((math.ceil(i / s) - 1) * s + span - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        height, width = x.shape[-2], x.shape[-1]
        total_h = self.calc_same_pad(i=height, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        total_w = self.calc_same_pad(i=width, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
        if total_h or total_w:
            # F.pad order for 4D input is (left, right, top, bottom); the extra
            # pixel of an odd total goes to the right/bottom side.
            left, top = total_w // 2, total_h // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return super().forward(x)
class BlurBlock(torch.nn.Module):
    """Anti-aliased 2x spatial downsampling (BlurPool): blur with a fixed
    separable kernel, then subsample via a stride-2 depthwise convolution."""

    def __init__(self,
                 kernel: Tuple[int] = (1, 3, 3, 1)
                 ):
        """Args:
            kernel: 1D binomial tap weights; the 2D blur kernel is their outer
                product, normalized to sum to 1 so the blur preserves the mean.
        """
        super().__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        kernel = kernel[None, :] * kernel[:, None]  # separable taps -> 2D kernel
        kernel /= kernel.sum()
        kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, kh, kw)
        self.register_buffer("kernel", kernel)

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Total "same" padding for an axis of length i (kernel k, stride s)."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ic, ih, iw = x.size()[-3:]
        # BUGFIX: the padding was previously computed with a hardcoded kernel
        # size of 4, which yields too little padding for the 5-tap kernel
        # (and the wrong amount in general for any non-4 tap count). Derive
        # the size from the registered kernel instead, matching BlurBlock3D.
        kh, kw = self.kernel.shape[-2:]
        pad_h = self.calc_same_pad(i=ih, k=kh, s=2)
        pad_w = self.calc_same_pad(i=iw, k=kw, s=2)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        # Depthwise conv: every channel is blurred independently with the
        # same kernel (groups == channel count).
        weight = self.kernel.expand(ic, -1, -1, -1)
        out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
        return out
class SinusoidalTimeEmbedding(torch.nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Maps a batch of scalars to ``embedding_dim``-dimensional vectors whose
    first half holds sines and second half cosines of geometrically spaced
    frequencies (base 10000).
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        assert embedding_dim % 2 == 0, "embedding_dim must be even"

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        half = self.embedding_dim // 2
        # Frequencies decay exponentially from 1 down to 1/10000.
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=timesteps.device))
        args = timesteps[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)
class ModulatedConv2dSame(Conv2dSame):
    """Same-padded 2D conv followed by optional FiLM conditioning.

    When ``cond_channels`` is given, a linear layer maps the conditioning
    vector to per-channel (scale, bias) pairs. The projection is
    zero-initialized, so at init the block behaves exactly like a plain
    convolution (scale offset 0, bias 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, cond_channels=None):
        super().__init__(in_channels, out_channels, kernel_size)
        if cond_channels is not None:
            self.film_proj = torch.nn.Linear(cond_channels, 2 * out_channels)
            # Identity modulation at initialization.
            torch.nn.init.zeros_(self.film_proj.weight)
            torch.nn.init.zeros_(self.film_proj.bias)

    def forward(self, x, temb=None):
        out = super().forward(x)
        if temb is None:
            return out
        film = self.film_proj(temb)[:, :, None, None]  # (B, 2C, 1, 1)
        scale, bias = film.chunk(2, dim=1)
        return out * (scale + 1) + bias
class NLayerDiscriminator(torch.nn.Module):
    """PatchGAN-style 2D discriminator with optional FiLM time conditioning.

    Args:
        num_channels: Channels of the input image.
        hidden_channels: Width of the first stage; doubles every stage after.
        num_stages: Number of (modulated conv, downsample) stage pairs.
        blur_resample: If True, downsample with an anti-aliasing BlurBlock
            instead of average pooling.
        blur_kernel_size: Blur tap count in {3, 4, 5} (only used when
            ``blur_resample`` is True).
        with_condition: If True, a per-sample scalar condition is embedded
            sinusoidally and FiLM-modulates every conv stage.
    """

    # 1D tap weights for the anti-aliasing blur, keyed by kernel size.
    _BLUR_TAPS = {3: (1, 2, 1), 4: (1, 3, 3, 1), 5: (1, 4, 6, 4, 1)}

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4,
        with_condition: bool = False,
    ):
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"
        # Channel widths entering each stage: hidden, hidden, 2*hidden, 4*hidden, ...
        widths = [hidden_channels] + [hidden_channels * 2 ** s for s in range(num_stages)]
        act = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
        self.with_condition = with_condition
        if with_condition:
            cond_channels = 768
            self.time_emb = SinusoidalTimeEmbedding(128)
            self.time_proj = torch.nn.Sequential(
                torch.nn.Linear(128, cond_channels),
                torch.nn.SiLU(),
                torch.nn.Linear(cond_channels, cond_channels),
            )
        else:
            cond_channels = None
        # Stem: 5x5 same-padded conv + LeakyReLU.
        self.block_in = torch.nn.Sequential(
            Conv2dSame(num_channels, hidden_channels, kernel_size=5),
            act(),
        )
        blocks = []
        for stage in range(num_stages):
            cin, cout = widths[stage], widths[stage + 1]
            blocks.append(ModulatedConv2dSame(cin, cout, kernel_size=3, cond_channels=cond_channels))
            if blur_resample:
                downsample = BlurBlock(self._BLUR_TAPS[blur_kernel_size])
            else:
                downsample = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            blocks.append(torch.nn.Sequential(downsample, torch.nn.GroupNorm(32, cout), act()))
        self.blocks = torch.nn.ModuleList(blocks)
        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))
        final_width = widths[-1]
        self.to_logits = torch.nn.Sequential(
            Conv2dSame(final_width, final_width, 1),
            act(),
            Conv2dSame(final_width, 1, kernel_size=5),
        )

    def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
        """Score an image batch (or a video batch folded into frames).

        Args:
            x: (B, C, H, W), or (B, T, C, H, W) which is flattened to
                (B*T, C, H, W) frames.
            condition: optional per-sample scalar (e.g. a diffusion time);
                only used when built with ``with_condition=True``.
        Returns:
            (N, 1, 16, 16) patch logits, N = B or B*T.
        """
        if x.dim() == 5:
            x = x.flatten(0, 1)  # b t c h w -> (b t) c h w
        h = self.block_in(x)
        temb = None
        if self.with_condition and condition is not None:
            # Scale the condition into the sinusoidal embedding's range.
            temb = self.time_proj(self.time_emb(condition * 1000.0))
        # self.blocks alternates [modulated conv, downsample]; walk in pairs.
        for conv_block, down_block in zip(self.blocks[::2], self.blocks[1::2]):
            h = down_block(conv_block(h, temb))
        return self.to_logits(self.pool(h))
# 3D discriminator
class Conv3dSame(nn.Conv3d):
    """Conv3d with TensorFlow-style "same" zero padding: each of the
    (T, H, W) axes keeps ``ceil(in / stride)`` resolution."""

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        """Total padding for one axis of length i (kernel k, stride s, dilation d)."""
        span = (k - 1) * d + 1  # receptive field of one kernel application
        return max((math.ceil(i / s) - 1) * s + span - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        depth, height, width = x.shape[-3:]  # frames, rows, cols
        pads = [
            self.calc_same_pad(i=n, k=self.kernel_size[axis], s=self.stride[axis], d=self.dilation[axis])
            for axis, n in enumerate((depth, height, width))
        ]
        if any(pads):
            pt, ph, pw = pads
            # F.pad lists the last dimension first: (W, H, T) pairs, with the
            # odd pixel of each total going to the trailing side.
            x = F.pad(
                x,
                [pw // 2, pw - pw // 2,
                 ph // 2, ph - ph // 2,
                 pt // 2, pt - pt // 2],
            )
        return super().forward(x)
class ModulatedConv3dSame(Conv3dSame):
    """Same-padded 3D conv with optional FiLM conditioning.

    With ``cond_channels`` set, a zero-initialized linear projection turns the
    conditioning vector into per-channel (scale, bias) pairs; at init the
    layer therefore behaves exactly like an unconditioned convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, cond_channels=None):
        super().__init__(in_channels, out_channels, kernel_size)
        if cond_channels is not None:
            self.film_proj = torch.nn.Linear(cond_channels, 2 * out_channels)
            # Identity modulation at initialization.
            torch.nn.init.zeros_(self.film_proj.weight)
            torch.nn.init.zeros_(self.film_proj.bias)

    def forward(self, x, temb=None):
        out = super().forward(x)  # (B, C, T, H, W)
        if temb is None:
            return out
        film = self.film_proj(temb)[:, :, None, None, None]  # (B, 2C, 1, 1, 1)
        scale, bias = film.chunk(2, dim=1)
        return out * (scale + 1) + bias
class BlurBlock3D(nn.Module):
    """Anti-aliased 3D downsampling (BlurPool).

    The blur kernel is the outer product of the given 1D taps applied to the
    spatial axes only (the conv kernel has depth 1); ``stride`` controls which
    axes are subsampled — spatial only by default.
    """

    def __init__(self, kernel=(1, 3, 3, 1), stride=(1, 2, 2)):
        super().__init__()
        self.stride = stride
        taps = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        blur = taps[None, :] * taps[:, None]
        blur = blur / blur.sum()  # normalize so the blur preserves the mean
        # Shape (1, 1, 1, kh, kw): depth-1 kernel -> spatial-only blur.
        self.register_buffer("kernel", blur[None, None, None])

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Total "same" padding for an axis of length i (kernel k, stride s)."""
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels = x.shape[1]
        frames, rows, cols = x.shape[-3:]
        kd, kh, kw = self.kernel.shape[-3:]
        sd, sh, sw = self.stride
        # Spatial axes always get "same" padding; the temporal axis only when
        # it is actually strided (kd is 1, so the temporal pad is usually 0).
        pad_h = self.calc_same_pad(rows, kh, sh)
        pad_w = self.calc_same_pad(cols, kw, sw)
        pad_d = self.calc_same_pad(frames, kd, sd) if sd != 1 else 0
        if pad_h > 0 or pad_w > 0 or pad_d > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                          pad_h // 2, pad_h - pad_h // 2,
                          pad_d // 2, pad_d - pad_d // 2])
        # Depthwise conv: every channel is blurred with the same kernel.
        weight = self.kernel.expand(channels, 1, -1, -1, -1)
        return F.conv3d(x, weight, stride=self.stride, groups=channels)
class NLayer3DDiscriminator(torch.nn.Module):
    """PatchGAN-style 3D (video) discriminator with optional FiLM time
    conditioning.

    Args:
        num_channels: Channels of each input frame.
        hidden_channels: Width of the first stage; doubles every stage after.
        num_stages: Number of (modulated conv, downsample) stage pairs.
        blur_resample: If True, downsample with BlurBlock3D instead of
            average pooling.
        blur_kernel_size: Blur tap count in {3, 4, 5} (only used when
            ``blur_resample`` is True).
        with_condition: If True, a per-sample scalar condition is embedded
            sinusoidally and FiLM-modulates every conv stage.
    """

    # 1D tap weights for the anti-aliasing blur, keyed by kernel size.
    _BLUR_TAPS = {3: (1, 2, 1), 4: (1, 3, 3, 1), 5: (1, 4, 6, 4, 1)}

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 128,
        num_stages: int = 3,
        blur_resample: bool = True,
        blur_kernel_size: int = 4,
        with_condition: bool = False,
    ):
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"
        # Channel widths entering each stage: hidden, hidden, 2*hidden, ...
        widths = [hidden_channels] + [hidden_channels * 2 ** s for s in range(num_stages)]
        act = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
        self.with_condition = with_condition
        if with_condition:
            cond_channels = 768
            self.time_emb = SinusoidalTimeEmbedding(128)
            self.time_proj = torch.nn.Sequential(
                torch.nn.Linear(128, cond_channels),
                torch.nn.SiLU(),
                torch.nn.Linear(cond_channels, cond_channels),
            )
        else:
            cond_channels = None
        # Stem: 5x5x5 same-padded conv + LeakyReLU.
        self.block_in = torch.nn.Sequential(
            Conv3dSame(num_channels, hidden_channels, kernel_size=5),
            act(),
        )
        # The first third of the stages keep the temporal resolution; the
        # remaining stages downsample time as well as space.
        keep_temporal = int(num_stages * 1 / 3)
        blocks = []
        for stage in range(num_stages):
            cin, cout = widths[stage], widths[stage + 1]
            blocks.append(ModulatedConv3dSame(cin, cout, kernel_size=3, cond_channels=cond_channels))
            stride = (1, 2, 2) if stage < keep_temporal else (2, 2, 2)
            if blur_resample:
                downsample = BlurBlock3D(self._BLUR_TAPS[blur_kernel_size], stride=stride)
            else:
                downsample = torch.nn.AvgPool3d(kernel_size=2, stride=stride)
            blocks.append(torch.nn.Sequential(downsample, torch.nn.GroupNorm(32, cout), act()))
        self.blocks = torch.nn.ModuleList(blocks)
        self.pool = torch.nn.AdaptiveMaxPool3d((2, 16, 16))
        final_width = widths[-1]
        self.to_logits = torch.nn.Sequential(
            Conv3dSame(final_width, final_width, 1),
            act(),
            Conv3dSame(final_width, 1, kernel_size=5),
        )

    def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
        """Score a video clip.

        Args:
            x: (B, T, C, H, W) clip.
            condition: optional per-sample scalar; only used when built with
                ``with_condition=True``.
        Returns:
            (B, 1, 2, 16, 16) patch logits.
        """
        # b t c h w -> b c t h w (equivalent to the einops rearrange).
        x = x.permute(0, 2, 1, 3, 4)
        h = self.block_in(x)
        temb = None
        if self.with_condition and condition is not None:
            # Scale the condition into the sinusoidal embedding's range.
            temb = self.time_proj(self.time_emb(condition * 1000.0))
        # self.blocks alternates [modulated conv, downsample]; walk in pairs.
        for conv_block, down_block in zip(self.blocks[::2], self.blocks[1::2]):
            h = down_block(conv_block(h, temb))
        return self.to_logits(self.pool(h))