|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .layers import ResidualBlock, AttnBlock |
|
|
from .utils import get_named_beta_schedule |
|
|
|
|
|
def sinusoidal_embedding(n, d):
    """
    Build a fixed sinusoidal embedding table.

    n: iteration steps (number of rows, one per time step),
    d: time embedding dimension (number of columns)

    Returns a tensor of shape (n, d) in the default floating dtype. For row
    ``pos`` and column pair ``(2 * i, 2 * i + 1)``:

        table[pos, 2 * i]     = sin(pos / 10000 ** (2 * i / d))
        table[pos, 2 * i + 1] = cos(pos / 10000 ** (2 * i / d))

    The sin/cos split is made across embedding columns. Splitting across
    rows (time steps) with indices such as ``1 - arange(0, n, 2)`` produces
    negative, wrapping row indices and leaves step 0 as an all-zero vector.
    """
    float_dtype = torch.get_default_dtype()

    # Exponent 2*i for every column, duplicated so that both members of a
    # (sin, cos) pair share one frequency: [0, 0, 2, 2, 4, 4, ...][:d].
    pair_exponent = torch.arange(0, d, 2, dtype=float_dtype).repeat_interleave(2)[:d]
    inv_freq = 10000.0 ** (-pair_exponent / d)  # shape (d,)

    positions = torch.arange(n, dtype=float_dtype).unsqueeze(1)  # shape (n, 1)
    angles = positions * inv_freq  # broadcast to shape (n, d)

    embedding = torch.empty(n, d, dtype=float_dtype)
    embedding[:, 0::2] = torch.sin(angles[:, 0::2])
    embedding[:, 1::2] = torch.cos(angles[:, 1::2])

    return embedding
|
|
|
|
|
def _make_te(dim_in, dim_out):
    """
    Build the small MLP applied to the time embedding.

    dim_in: size of the incoming embedding vector
    dim_out: size of the hidden and output vectors

    The stack is Linear(dim_in, dim_out) -> SiLU -> Linear(dim_out, dim_out).
    """
    layers = [
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
class UNet_with_time(nn.Module):
    """
    Time- and condition-aware U-Net used as the noise-prediction backbone.

    The network is assembled from ``ResidualBlock`` and ``AttnBlock`` layers:
    a pixel-unshuffle stem, a down path, a residual/attention/residual
    bottleneck, an up path that consumes skip features, and a pixel-shuffle
    plus 3x3 convolution head.

    Fields read from ``config``: input_frame, output_frame, n_steps,
    time_emb_dim, cond_nc, chs_mult, n_res_blocks, base_chs, use_attn_list.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        input_frame = config.input_frame
        output_frame = config.output_frame
        n_steps = config.n_steps
        time_emb_dim = config.time_emb_dim
        cond_nc = config.cond_nc
        chs_mult = config.chs_mult
        n_res_blocks = config.n_res_blocks
        base_chs = config.base_chs
        use_attn_list = config.use_attn_list

        layer_depth = len(chs_mult)
        assert len(use_attn_list) == layer_depth, "length of use_attn_list should be the same as chs_mult"
        assert input_frame >= output_frame, "input_frame should be larger than or equal to output_frame"

        def make_res(c_in, c_out, down=False, up=False):
            # Every residual block shares the same conditioning/time widths.
            return ResidualBlock(c_in, c_out, cond_nc, time_emb_dim, down_flag=down, up_flag=up)

        # PixelUnshuffle(2) multiplies the channel count by 4.
        stem_chs = input_frame * 4
        self.filter_list = [base_chs * mult for mult in chs_mult]

        # ---- time embedding: frozen lookup table followed by an MLP ----
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.time_embed_fc = _make_te(time_emb_dim, time_emb_dim)

        # ---- stem ----
        self.input_layer = nn.PixelUnshuffle(downscale_factor=2)

        # ---- down path ----
        # Per level: n_res_blocks width-preserving blocks, optional attention,
        # then one block built with down_flag=True that changes the width.
        self.down_blocks = nn.ModuleList()
        in_c = stem_chs
        for level_chs, with_attn in zip(self.filter_list, use_attn_list):
            stage = [make_res(in_c, in_c) for _ in range(n_res_blocks)]
            if with_attn:
                stage.append(AttnBlock(in_c, 4))
            stage.append(make_res(in_c, level_chs, down=True))
            self.down_blocks.extend(stage)
            in_c = level_chs

        # ---- bottleneck ----
        self.mid_block1 = make_res(in_c, in_c)
        self.mid_attn = AttnBlock(in_c, 4)
        self.mid_block2 = make_res(in_c, in_c)

        # ---- up path ----
        # Output widths are the down-path input widths, walked in reverse.
        # Every residual block here receives [features, skip] concatenated on
        # the channel axis, hence the doubled input width.
        self.up_blocks = nn.ModuleList()
        self.filter_list = [stem_chs] + self.filter_list[:-1]
        up_plan = list(zip(self.filter_list, use_attn_list))
        for level_chs, with_attn in reversed(up_plan):
            stage = [make_res(in_c * 2, level_chs, up=True)]
            if with_attn:
                stage.append(AttnBlock(level_chs))
            stage.extend(make_res(level_chs * 2, level_chs) for _ in range(n_res_blocks))
            self.up_blocks.extend(stage)
            in_c = level_chs

        # ---- head ----
        self.out_up = nn.PixelShuffle(upscale_factor=2)
        self.out_conv = nn.Conv2d(input_frame, output_frame, 3, padding=1)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        t: (b,), time step
        cond: (b, cond_nc, h, w), conditional input

        Returns the output of the final 3x3 convolution.
        """
        t_emb = self.time_embed_fc(self.time_embed(t))

        x = self.input_layer(x)

        # Down path: the output of every residual block is kept as a skip.
        skips = []
        for layer in self.down_blocks:
            if isinstance(layer, ResidualBlock):
                x = layer(x, cond, t_emb)
                skips.append(x)
                continue
            if not isinstance(layer, AttnBlock):
                raise ValueError("Wrong layer type in down_blocks")
            x = layer(x)

        # Bottleneck.
        x = self.mid_block1(x, cond, t_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, cond, t_emb)

        # Up path: each residual block pops the most recent skip.
        for layer in self.up_blocks:
            if isinstance(layer, ResidualBlock):
                merged = torch.cat([x, skips.pop()], dim=1)
                x = layer(merged, cond, t_emb)
                continue
            if not isinstance(layer, AttnBlock):
                raise ValueError("Wrong layer type in up_blocks")
            x = layer(x)

        return self.out_conv(self.out_up(x))
|
|
|
|
|
class DDPM(nn.Module):
    """
    Diffusion wrapper around a noise-prediction backbone.

    Registers the buffers ``betas``, ``alphas`` (``1 - betas``) and
    ``alpha_bars`` (cumulative product of ``alphas``), and provides forward
    noising (``add_noise``), noise prediction (``forward`` / ``denoise``) and
    two samplers (``sample_ddpm`` and ``sample_ddim``). ``sample`` is an alias
    of ``sample_ddpm``.
    """

    def __init__(self, backbone, output_shape, n_steps=1000, min_beta=1e-4, max_beta=0.02, device='cuda'):
        """
        backbone: module called as ``backbone(x, t, cond)``
        output_shape: dim(C, H, W)
        n_steps: number of diffusion steps
        min_beta, max_beta: passed to the "linear" schedule helper
        device: stored on ``self.device``; not otherwise used by this class
        """
        super().__init__()
        self.device = device
        self.backbone_model = backbone
        self.output_shape = output_shape
        self.n_steps = n_steps

        betas = get_named_beta_schedule("linear", n_steps, min_beta, max_beta)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Buffers follow the module across .to(...) and live in state_dict.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        cond: (b, cond_nc, h, w), conditional input
        t: (b,), time step

        Returns whatever the backbone returns for ``(x, t, cond)``.
        """
        return self.backbone_model(x, t, cond)

    @torch.no_grad()
    def add_noise(self, x0, t, eta=None):
        """
        Diffuse ``x0`` to step ``t``:
        ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta``.

        x0: (b, ...), original data; any rank >= 1 (e.g. (b, c, h, w))
        t: (b,), time step (0 <= t < n_steps)
        eta: optional noise with the shape of x0; standard normal noise is
            drawn on x0's device when omitted
        """
        batch = x0.shape[0]
        if eta is None:
            eta = torch.randn(x0.shape, device=x0.device)

        # One coefficient per sample, broadcast over all trailing dims.
        coeff_shape = (batch,) + (1,) * (x0.dim() - 1)
        alpha_bar = self.alpha_bars[t]
        signal_scale = alpha_bar.sqrt().reshape(coeff_shape)
        noise_scale = (1 - alpha_bar).sqrt().reshape(coeff_shape)

        return signal_scale * x0 + noise_scale * eta

    def denoise(self, xt, t, cond):
        """
        xt: (b, in_c, h, w), noisy input (concatenated with some data)
        cond: (b, cond_nc, h, w), conditional input
        t: (b,), time step (0 <= t < n_steps)

        Returns the module's forward output, used by the samplers as the
        predicted noise.
        """
        return self(xt, t, cond)

    @staticmethod
    def _with_input_cond(x, input_cond):
        """Concatenate ``input_cond`` after ``x`` on dim 1 when it is given."""
        if input_cond is None:
            return x
        return torch.cat((x, input_cond), dim=1)

    @staticmethod
    def _snapshot(x):
        """Map ``x`` from [-1, 1] to [0, 1] (clamped) and return a NumPy copy."""
        return ((x.detach() + 1) / 2).clamp(0, 1).cpu().numpy()

    def _build_progress_iter(self, iterable, total, mode: str):
        """
        Internal helper to create a progress iterator based on verbose mode.

        Returns ``(iterator, effective_mode)``. ``mode`` is lower-cased and a
        falsy value becomes "none". For "tqdm" the iterable is wrapped in a
        tqdm bar; when tqdm cannot be imported the plain iterable is returned
        with mode "none". Any other mode is passed through unchanged.
        """
        mode = (mode or "none").lower()
        if mode != "tqdm":
            return iterable, mode

        try:
            from tqdm import tqdm
        except ImportError:
            # tqdm is optional: fall back to silent iteration.
            return iterable, "none"
        return tqdm(iterable, total=total, desc="DDPM sampling", leave=False), mode

    @torch.no_grad()
    def sample_ddpm(self, cond, input_cond=None, verbose: str = "none", store_intermediate: bool = False):
        """
        Ancestral sampling over all ``n_steps`` steps, from n_steps - 1 to 0.

        cond: (b, cond_nc, h, w), conditional input
        input_cond: optional tensor concatenated after the current sample on
            dim 1 before every backbone call
        verbose: "none", "text", or "tqdm" for progress display
        store_intermediate: when True, also return a list of NumPy snapshots
            (rescaled to [0, 1]) taken every 50 iterations and at t == 0

        Returns ``x`` or ``(x, frames)``. The backbone is switched to eval
        mode and is left in eval mode afterwards.
        """
        self.backbone_model.eval()

        batch = cond.shape[0]
        device = cond.device

        x = torch.randn(batch, *self.output_shape, device=device)

        progress_iter, mode = self._build_progress_iter(
            reversed(range(self.n_steps)), self.n_steps, verbose
        )
        use_text = mode == "text"
        text_interval = max(1, self.n_steps // 10)

        frames = []
        try:
            for idx, t in enumerate(progress_iter):
                time_tensor = torch.full((batch,), t, device=device, dtype=torch.long)
                eta_theta = self.denoise(self._with_input_cond(x, input_cond), time_tensor, cond)

                alpha_t = self.alphas[t]
                alpha_t_bar = self.alpha_bars[t]

                # Posterior mean:
                # (x - (1 - a_t) / sqrt(1 - abar_t) * eps) / sqrt(a_t)
                a = 1 / alpha_t.sqrt()
                b = ((1 - alpha_t) / (1 - alpha_t_bar).sqrt()) * eta_theta
                x = a * (x - b)

                if t > 0:
                    # Fresh noise with sigma_t = sqrt(beta_t); none on the last step.
                    z = torch.randn(batch, *self.output_shape, device=device)
                    x = x + self.betas[t].sqrt() * z

                # Snapshots need a device-to-host copy: only pay when requested.
                if store_intermediate and ((idx % 50 == 0) or (t == 0)):
                    frames.append(self._snapshot(x))

                if use_text and (idx + 1) % text_interval == 0:
                    print(f"DDPM sampling {idx + 1}/{self.n_steps}", flush=True)
        finally:
            # Close the bar even when the loop raises.
            if mode == "tqdm" and hasattr(progress_iter, "close"):
                progress_iter.close()

        if store_intermediate:
            return x, frames
        return x

    @torch.no_grad()
    def sample_ddim(self, cond, input_cond=None, ddim_steps: int = 100, eta: float = 0.2, verbose: str = "none", store_intermediate: bool = False):
        """
        Deterministic/stochastic DDIM sampling.

        cond: (b, cond_nc, h, w)
        input_cond: optional conditional input concatenated with the predicted frames
        ddim_steps: number of steps to sample (<= n_steps); clamped to
            [1, n_steps]
        eta: 0 for deterministic DDIM, >0 adds noise controlled by eta
        verbose: "none", "text", or "tqdm" for progress display
        store_intermediate: when True, also return a list of NumPy snapshots
            (rescaled to [0, 1]) taken every 25 iterations and at t == 0

        Returns ``x`` or ``(x, frames)``. The backbone is switched to eval
        mode and is left in eval mode afterwards.

        NOTE(review): the step grid is ``linspace(0, n_steps - 1, ddim_steps)``,
        so ``ddim_steps == 1`` visits only t == 0 -- confirm this is intended.
        """
        self.backbone_model.eval()

        batch = cond.shape[0]
        device = cond.device
        ddim_steps = max(1, min(ddim_steps, self.n_steps))

        # Evenly spaced steps; truncation to long can collide, so de-duplicate.
        ddim_timesteps = torch.linspace(0, self.n_steps - 1, steps=ddim_steps, device=device).long()
        ddim_timesteps = torch.unique(ddim_timesteps, sorted=True)
        ddim_t_reverse = list(reversed(ddim_timesteps.tolist()))
        n_iters = len(ddim_t_reverse)

        x = torch.randn(batch, *self.output_shape, device=device)

        progress_iter, mode = self._build_progress_iter(ddim_t_reverse, n_iters, verbose)
        use_text = mode == "text"
        text_interval = max(1, n_iters // 10)

        frames = []
        try:
            for idx, t in enumerate(progress_iter):
                time_tensor = torch.full((batch,), t, device=device, dtype=torch.long)
                eps = self.denoise(self._with_input_cond(x, input_cond), time_tensor, cond)

                alpha_bar_t = self.alpha_bars[t]
                sqrt_alpha_bar_t = alpha_bar_t.sqrt()
                sqrt_one_minus_alpha_bar_t = (1 - alpha_bar_t).sqrt()

                # Clean-sample estimate implied by the predicted noise.
                x0_pred = (x - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t

                if idx + 1 < n_iters:
                    alpha_bar_prev = self.alpha_bars[ddim_t_reverse[idx + 1]]
                else:
                    # Last iteration: step to the noise-free end point.
                    alpha_bar_prev = torch.ones_like(alpha_bar_t, device=device)

                # Noise is injected only for eta > 0 and never on the last step.
                stochastic = eta > 0 and bool(alpha_bar_prev < 1)
                if stochastic:
                    sigma_t = eta * torch.sqrt(
                        (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
                    )
                else:
                    sigma_t = 0.0
                sigma_t = torch.as_tensor(sigma_t, device=device, dtype=x.dtype)

                # Clamp guards the sqrt against tiny negative round-off.
                c_t = torch.sqrt(torch.clamp(1 - alpha_bar_prev - sigma_t ** 2, min=0.0))
                x = alpha_bar_prev.sqrt() * x0_pred + c_t * eps
                if stochastic:
                    x = x + sigma_t * torch.randn_like(x)

                # Snapshots need a device-to-host copy: only pay when requested.
                if store_intermediate and ((idx % 25 == 0) or (t == 0)):
                    frames.append(self._snapshot(x))

                if use_text and (idx + 1) % text_interval == 0:
                    print(f"DDIM sampling {idx + 1}/{n_iters}", flush=True)
        finally:
            # Close the bar even when the loop raises.
            if mode == "tqdm" and hasattr(progress_iter, "close"):
                progress_iter.close()

        if store_intermediate:
            return x, frames
        return x

    # Default sampler: existing callers of ``sample`` get ancestral DDPM sampling.
    sample = sample_ddpm
|
|
|