# starflow/transformer_flow.py
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import copy
import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Tuple
from misc.pe import VisionRotaryEmbeddingFast, apply_rope, get_positions
from misc import print
from functools import partial
from einops import rearrange, repeat
from torch.utils.checkpoint import checkpoint
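# Inverse softplus of 1: softplus(INV_SOFTPLUS_1) == 1, so a zero (or near-zero)
# scale logit maps to a unit scale in the affine couplings below.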
INV_SOFTPLUS_1 = 0.541324854612918
def modulate(x, shift, scale):
if shift is None:
return x * (1 + scale)
return x * (1 + scale) + shift
def stable_neg_log_softplus(x):
return torch.where(
        x > 20,           # softplus(x) ≈ x, so log(softplus(x)) ≈ log(x)
        -x.log(),         # hence -log(softplus(x)) ≈ -log(x)
-F.softplus(x).log()
)
class KVCache:
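    """Per-attention-block key/value cache for autoregressive sampling.

    Each block stores a tensor of shape (2, batch, kv_heads, seq_len, head_dim)
    (index 0 holds keys, index 1 holds values) together with a write pointer.
    The helpers support extending the time dimension, duplicating the batch for
    classifier-free guidance, dropping the negative half again, and rewinding
    the write pointer for Jacobi-style re-iteration.
    """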
def __init__(self):
self._is_empty = True
self.prefix_cache = None
self.meta_data = {}
def initialize(self, num_blocks, *size):
self._is_empty = False
self.num_blocks = num_blocks
self.size = size
self.kv_caches = [torch.zeros(2, *size) for _ in range(num_blocks)]
self.kv_index = [0] * num_blocks
def register_prefix_cache(self, prefix_cache):
self.prefix_cache = prefix_cache
@property
def is_empty(self):
return self._is_empty
@property
def is_full(self):
if self.is_empty:
return False
return all(index == self.size[2] for index in self.kv_index)
def delete(self):
if not self.is_empty:
self._is_empty = True
del self.kv_caches
del self.kv_index
def to(self, device, dtype=torch.bfloat16):
for i in range(self.num_blocks):
self.kv_caches[i] = self.kv_caches[i].to(device=device, dtype=dtype)
def extend_length(self, length):
assert not self.is_empty, "KVCache is empty, cannot extend length"
self.size = (self.size[0], self.size[1], self.size[2] + length, self.size[3])
for i in range(self.num_blocks):
pad = self.kv_caches[i].new_zeros((2, *self.size))
pad[:, :, :, :self.kv_caches[i].size(3)] = self.kv_caches[i]
self.kv_caches[i] = pad
def expand_batch(self, ratio=2):
self.size = (self.size[0] * ratio, *self.size[1:])
for i in range(self.num_blocks):
self.kv_caches[i] = torch.cat([self.kv_caches[i] for _ in range(ratio)], dim=1)
def remove_negative_cache(self):
self.size = (self.size[0] // 2, *self.size[1:])
for i in range(self.num_blocks):
self.kv_caches[i] = self.kv_caches[i].chunk(2, dim=1)[0]
def backward_in_time(self, l):
for i in range(self.num_blocks):
self.kv_index[i] = max(0, self.kv_index[i] - l)
def reset_kv_index(self):
for i in range(self.num_blocks):
self.kv_index[i] = 0
def __call__(self, block_idx, k, v):
assert block_idx < self.num_blocks, f'block_idx {block_idx} out of range {self.num_blocks}'
# write cache
l = k.size(2)
kv_index = self.kv_index[block_idx]
if kv_index + l > self.size[2]:
raise NotImplementedError("Overflow mode is not implemented")
self.kv_caches[block_idx][0][:, :, kv_index: kv_index+l] = k
self.kv_caches[block_idx][1][:, :, kv_index: kv_index+l] = v
self.kv_index[block_idx] = kv_index + l
# read cache
kv_index = self.kv_index[block_idx]
return self.kv_caches[block_idx][0][:, :, :kv_index], self.kv_caches[block_idx][1][:, :, :kv_index]
class Permutation(torch.nn.Module):
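    """Base class for the sequence permutations applied between flow blocks.

    `forward` flattens the spatial (and temporal) dimensions into a token sequence,
    applies `permute`, and restores the original shape on the inverse pass.
    """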
def __init__(self, seq_length: int):
super().__init__()
self.seq_length = seq_length
self.input_shape = None
def forward(self, x: torch.Tensor | List[torch.Tensor], dim: int = 1, inverse: bool = False):
if not inverse:
self.input_shape = x.shape
x = rearrange(x, 'b t h w c -> b (t h w) c' if x.dim() == 5 else 'b h w c -> b (h w) c')
x = self.permute(x, dim, self.input_shape, inverse=False)
else:
x = self.permute(x, dim, self.input_shape, inverse=True)
x = x.reshape(-1, *self.input_shape[1:])
return x
def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
raise NotImplementedError('Overload me')
class PermutationIdentity(Permutation):
def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
return x.clone()
class PermutationFlip(Permutation):
def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
return x.flip(dims=[dim])
class PermutationFlipInBlock(Permutation):
def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
assert shape is not None, "shape must be provided for PermutationFlipInBlock"
if len(shape) == 5:
assert dim == 1, "dim must be 1 for 5D tensor in PermutationFlipInBlock"
            # flip tokens within each temporal block (one frame of h*w tokens); the frame order itself is preserved
x = x.view(x.size(0), shape[1], -1, x.size(-1)).flip(dims=[2]).view_as(x)
else:
x = x.flip(dims=[dim])
return x
class RMSNorm(torch.nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
self.weight = torch.nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = self._norm(x.float())
if self.add_unit_offset:
output = output * (1 + self.weight.float())
else:
output = output * self.weight.float()
return output.type_as(x)
class Attention(torch.nn.Module):
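    """Multi-head attention with optional GQA, QK-norm, rotary embeddings and KV caching."""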
def __init__(self, in_channels: int, head_channels: int, norm_type: str = "layer_norm",
num_heads=None, num_kv_heads=None, use_qk_norm=False,
use_post_norm=False, use_bias=True, hf_style_rope=False, non_causal=False):
super().__init__()
if norm_type == "layer_norm":
self.norm = torch.nn.LayerNorm(in_channels)
elif norm_type == "rms_norm":
self.norm = RMSNorm(in_channels)
else:
self.norm = torch.nn.Identity()
self.head_channels = head_channels
self.num_heads = num_heads if num_heads is not None else in_channels // head_channels
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else self.num_heads # GQA
self.q_size = self.num_heads * head_channels
self.kv_size = self.num_kv_heads * head_channels
self.qkv = torch.nn.Linear(in_channels, self.q_size + 2 * self.kv_size, bias=use_bias)
self.proj = torch.nn.Linear(self.q_size, in_channels, bias=use_bias)
self.query_norm = (RMSNorm(self.head_channels) if use_qk_norm else None)
self.key_norm = (RMSNorm(self.head_channels) if use_qk_norm else None)
self.post_norm = (RMSNorm(in_channels) if use_post_norm else None)
self.sqrt_scale = head_channels ** (-0.25)
self.hf_style_rope = hf_style_rope
self.non_causal = non_causal
def apply_rope(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
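        # HF-style checkpoints store rotary pairs as two contiguous halves of the head dim,
        # while the shared `apply_rope` helper works on interleaved pairs; reorder before
        # and after applying RoPE so the same helper can serve both conventions.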
if self.hf_style_rope:
return rearrange(apply_rope(rearrange(x, '... (u d) -> ... (d u)', u=2), freqs_cis), '... (d u) -> ... (u d)', u=2)
return apply_rope(x, freqs_cis)
def prepare_for_attention(self, x: torch.Tensor, freqs_cis=None, kv_cache=None):
B, T, _ = x.size()
q, k, v = self.qkv(self.norm(x)).split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(B, T, self.num_heads, self.head_channels).transpose(1, 2) # (b, h, t, d)
k = k.view(B, T, self.num_kv_heads, self.head_channels).transpose(1, 2) # (b, h, t, d)
v = v.view(B, T, self.num_kv_heads, self.head_channels).transpose(1, 2) # (b, h, t, d)
if self.query_norm is not None and self.key_norm is not None:
q, k = self.query_norm(q), self.key_norm(k)
if kv_cache is not None:
k, v = kv_cache(k, v)
if freqs_cis is not None:
lq, lk = q.size(2), k.size(2)
q, k = self.apply_rope(q, freqs_cis[lk-lq:lk]), self.apply_rope(k, freqs_cis[:lk])
if self.num_kv_heads != self.num_heads: # GQA (b, h, t, d)
k = torch.repeat_interleave(k, self.num_heads // self.num_kv_heads, dim=1)
v = torch.repeat_interleave(v, self.num_heads // self.num_kv_heads, dim=1)
return q.to(x.dtype), k.to(x.dtype), v.to(x.dtype)
def output_after_attention(self, x: torch.Tensor):
B, _, T, _ = x.shape
x = x.transpose(1, 2).reshape(B, T, self.q_size)
x = self.proj(x)
if self.post_norm is not None:
x = self.post_norm(x)
return x
def apply_attention(self, q, k, v, mask=None, temp=1.0):
scale = self.sqrt_scale**2 / temp
is_causal = not self.non_causal
if is_causal and q.size(2) < k.size(2) and mask is None:
prefix_len = k.size(2) - q.size(2)
mask = torch.tril(torch.ones(q.size(2), k.size(2), device=q.device, dtype=torch.bool), diagonal=prefix_len)
if mask is not None:
mask = mask.bool()
is_causal = False
# spda
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, is_causal=is_causal, scale=scale)
return x
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None, temp: float = 1.0, freqs_cis=None, kv_cache=None,
) -> torch.Tensor:
q, k, v = self.prepare_for_attention(x, freqs_cis, kv_cache)
x = self.apply_attention(q, k, v, mask, temp)
x = self.output_after_attention(x)
return x
class MLP(torch.nn.Module):
def __init__(self, channels: int, expansion: float, use_swiglu=False, norm_type="layer_norm", use_post_norm=False, use_bias=True):
super().__init__()
if norm_type == "layer_norm":
self.norm = torch.nn.LayerNorm(channels)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels)
else:
self.norm = torch.nn.Identity()
self.post_norm = (RMSNorm(channels) if use_post_norm else None)
self.use_swiglu = use_swiglu
intermediate_channels = int(channels * expansion)
if use_swiglu:
self.gate_proj = torch.nn.Linear(channels, intermediate_channels, bias=use_bias)
self.up_proj = torch.nn.Linear(channels, intermediate_channels, bias=use_bias)
self.down_proj = torch.nn.Linear(intermediate_channels, channels, bias=use_bias)
else:
self.main = torch.nn.Sequential(
torch.nn.Linear(channels, intermediate_channels, bias=use_bias),
torch.nn.GELU(), torch.nn.Linear(intermediate_channels, channels, bias=use_bias)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_swiglu:
x = self.norm(x)
x = self.down_proj(F.gelu(self.gate_proj(x), approximate='tanh') * self.up_proj(x))
else:
x = self.main(self.norm(x))
return self.post_norm(x) if self.post_norm is not None else x
class AttentionBlock(torch.nn.Module):
def __init__(self, channels: int, head_channels: int, expansion: float = 4, use_adaln: bool = False,
use_swiglu=False, norm_type="layer_norm", num_heads=None, num_kv_heads=None,
use_qk_norm=False, use_post_norm=False, use_bias=True, hf_style_rope=False, non_causal=False):
super().__init__()
if use_adaln:
self.adaLN_modulation = torch.nn.Sequential(
torch.nn.SiLU(),
torch.nn.Linear(channels, 4 * channels, bias=True),
)
self.norm1 = torch.nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = torch.nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
torch.nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
torch.nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
# Hard-coded norm_type=="none" for adaLN
norm_type = 'none'
else:
self.adaLN_modulation = None
self.attention = Attention(channels, head_channels, norm_type, num_heads, num_kv_heads, use_qk_norm, use_post_norm, use_bias, hf_style_rope, non_causal)
self.mlp = MLP(channels, expansion, use_swiglu, norm_type, use_post_norm, use_bias)
def forward(
self, x: torch.Tensor, y: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None,
attn_temp: float = 1.0, c=None, freqs_cis=None, kv_cache=None,
checkpoint_attn: bool = False, checkpoint_mlp: bool = False
) -> torch.Tensor:
assert (x is not None) or (y is not None), "x or y must be provided"
z = torch.cat([y, x], 1) if (x is not None) and (y is not None) else x if x is not None else y
if self.adaLN_modulation is not None and c is not None:
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(c).chunk(4, dim=-1)
z = z + self._forward_attention(z, attn_mask, attn_temp, freqs_cis, kv_cache, checkpoint_attn, shift_msa, scale_msa)
z = z + self._forward_mlp(z, checkpoint_mlp, shift_mlp, scale_mlp)
else:
z = z + self._forward_attention(z, attn_mask, attn_temp, freqs_cis, kv_cache, checkpoint_attn)
z = z + self._forward_mlp(z, checkpoint_mlp)
x, y = (z[:, y.size(1):], z[:, :y.size(1)]) if (x is not None) and (y is not None) \
else (z, None) if x is not None else (None, z)
return x, y
def _forward_attention(self, z, attn_mask, attn_temp, freqs_cis, kv_cache, checkpoint_attn, shift=None, scale=None):
def attn_fn(z_in):
if shift is not None and scale is not None:
z_in = modulate(self.norm1(z_in), shift, scale)
return self.attention(z_in, attn_mask, attn_temp, freqs_cis, kv_cache)
return checkpoint(attn_fn, z, use_reentrant=False) if checkpoint_attn and self.training else attn_fn(z)
def _forward_mlp(self, z, checkpoint_mlp, shift=None, scale=None):
def mlp_fn(z_in):
if shift is not None and scale is not None:
z_in = modulate(self.norm2(z_in), shift, scale)
return self.mlp(z_in)
return checkpoint(mlp_fn, z, use_reentrant=False) if checkpoint_mlp and self.training else mlp_fn(z)
class MetaBlock(torch.nn.Module):
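    """Causal transformer flow block implementing one autoregressive affine coupling.

    The forward pass predicts a per-token scale `xa` and shift `xb` from the preceding
    tokens and maps x -> (x - xb) / xa with log-determinant -log(xa); `reverse` and
    `jacobi` invert this mapping token-by-token or via block-wise fixed-point iteration.
    """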
attn_mask: torch.Tensor
def __init__(
self,
in_channels: int,
channels: int,
img_size: int,
permutation: Permutation,
pt_seq_len: int | None = None,
num_layers: int = 1,
head_dim: int = 64,
num_heads: None | int = None,
num_kv_heads: None | int = None,
txt_size: int = 0,
txt_dim: int = 0,
expansion: float = 4,
use_rope: bool = False,
use_sos: bool = False,
use_softplus: bool = False,
use_swiglu: bool = False,
        use_qk_norm: bool = False,
        use_post_norm: bool = False,
        use_final_norm: bool = False,
        use_bias: bool = True,
        use_proj_txt: bool = True,
        hf_style_rope: bool = False,
        norm_type: str = "layer_norm",
        use_mm_attn: bool = False,
        use_checkpoint: int = 0,
        use_checkpoint_mlp: int | None = None,
        soft_clip: float = 0,
        local_attn_window: int | None = None,
):
super().__init__()
out_channels = in_channels * 2
self.proj_in = torch.nn.Linear(in_channels, channels)
self.proj_out = torch.nn.Linear(channels, out_channels)
if use_sos:
self.sos_embed = torch.nn.Parameter(torch.randn(1, 1, in_channels))
torch.nn.init.constant_(self.proj_out.weight, 0)
self.txt_size = txt_size
self.img_size = img_size
self.txt_dim = txt_dim
self.pt_seq_len = pt_seq_len or img_size
# KV cache configurations
num_kv_heads = num_kv_heads or (num_heads or channels // head_dim)
self.kv_cache_size = [num_kv_heads, head_dim]
if not use_rope:
self.pos_embed = torch.nn.Parameter(torch.randn(img_size ** 2, channels) * 1e-2)
else:
self.pos_embed = None
if txt_dim > 0:
self.proj_txt = torch.nn.Linear(txt_dim, channels) if use_proj_txt else torch.nn.Identity()
assert use_proj_txt or (txt_dim == channels), 'text dimension must equal channels when not using projection'
self.attn_blocks = torch.nn.ModuleList(
[AttentionBlock(channels, head_dim, expansion, False, use_swiglu,
norm_type, num_heads, num_kv_heads, use_qk_norm, use_post_norm, use_bias, hf_style_rope)
for _ in range(num_layers)])
self.use_final_norm = use_final_norm
if use_final_norm:
self.final_norm = RMSNorm(channels)
self.use_softplus = use_softplus
self.permutation = permutation
self.use_checkpoint = use_checkpoint
self.use_checkpoint_mlp = use_checkpoint_mlp
self.use_sos = use_sos
self.soft_clip = soft_clip
self.local_attn_window = local_attn_window
self.block_masks = {} # for local attention
        # ---- DEPRECATED: kept only for checkpoint compatibility; do not pass this mask, so that flash attention stays enabled ---- #
self.register_buffer('attn_mask', torch.tril(torch.ones(pt_seq_len ** 2 + txt_size, pt_seq_len ** 2 + txt_size)))
def get_freqs_cis(self, x, y, rope):
# get the input shape
h, w = x.size(-3), x.size(-2)
d = x.size(1) if x.dim() == 5 else 0
txt_size = y.size(1) if self.txt_size > 0 and y is not None else 0
if not rope.is_1d: # prepare 2D RoPE
if self.txt_size > 0 or d > 0: # prepare 3D RoPE
if self.txt_dim > 0: # text is conditioned
pos = get_positions(h, w, txt_size, rope.pt_seq_len, d, mode='3d')
else: # text is not conditioned
pos = get_positions(h, w, 0, rope.pt_seq_len, d, mode='3d')
else:
pos = get_positions(h, w, 0, rope.pt_seq_len, mode='2d')
else: # prepare 1D RoPE
pos = get_positions(h, w, txt_size, rope.pt_seq_len, mode='1d')
return rope(pos.type_as(x))
def get_sos_embed(self, x):
sos_embed = self.sos_embed.expand(x.size(0), -1, -1)
return sos_embed
def get_prepared(self, x):
# input, output, freqs_cis
x_in = x.clone()
if self.use_sos: # add SOS token, predict the first token sos->x_in[0]
x = torch.cat([self.get_sos_embed(x), x[:, :-1]], dim=1)
return x_in, x
def get_proj_in(self, x):
x = self.proj_in(x)
return x
def get_proj_out(self, x):
x = self.proj_out(x)
if hasattr(self, "soft_clip") and self.soft_clip > 0:
x = self.soft_clip * torch.tanh(x / self.soft_clip)
return x
def get_local_window_mask(self, x, y):
_, T, H, W, _ = x.shape
L = y.size(1) if y is not None else 0
B = H * W
N = T * B
S = L + N
G = self.local_attn_window
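        # Causal mask restricted to a local temporal window: each query may attend to all
        # text tokens (k < L) and to image tokens from the most recent G frames (including its own).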
def mask(q, k):
return (k <= q) & ((k < L) | ((k - L) // B > (q - L) // B - G))
return mask(torch.arange(S, device=x.device)[:, None], torch.arange(S, device=x.device)[None, :])
def initialize_kv_cache(self, kv_cache, x, freqs_cis, reuse_kv_cache=False):
if self.local_attn_window is not None and self.local_attn_window > 0:
video_frame_size = x.size(-3) * x.size(-2)
kv_cache_length = self.local_attn_window * video_frame_size
kv_cache_length += self.txt_size if self.txt_dim > 0 else 0
kv_cache.meta_data.update(
{"frame_size": video_frame_size, "txt_size": self.txt_size + 1 if self.txt_dim > 0 else 0})
else:
kv_cache_length = freqs_cis.size(0)
kv_cache_size = (x.size(0), self.kv_cache_size[0], kv_cache_length, self.kv_cache_size[1])
if kv_cache.is_empty:
kv_cache.initialize(len(self.attn_blocks), *kv_cache_size)
kv_cache.to(x.device, x.dtype)
else:
target_size = kv_cache_size[-2]
if reuse_kv_cache:
target_size = target_size - kv_cache.kv_index[0]
kv_cache.extend_length(target_size)
return kv_cache
def forward(self, x: torch.Tensor | List[torch.Tensor], y: torch.Tensor | None = None, rope=None, kv_cache=None, guidance=None):
freqs_cis = self.get_freqs_cis(x, y, rope) if rope is not None else None
attn_mask = None
if kv_cache is not None:
kv_cache = self.initialize_kv_cache(kv_cache, x, freqs_cis)
x = self.permutation(x)
pos_embed = self.permutation(self.pos_embed, dim=0) if self.pos_embed is not None else None
# prepare input
x_in, x = self.get_prepared(x)
if kv_cache is not None:
kv_cache.register_prefix_cache(x_in)
# input projection
x = self.get_proj_in(x)
if pos_embed is not None:
x = x + pos_embed
# conditioning
if self.txt_dim > 0:
y = self.proj_txt(y)
else:
y = None
# main block
for it, block in enumerate(self.attn_blocks):
_kv_cache = partial(kv_cache, it) if kv_cache is not None else None
# Frequency-based checkpointing strategy:
# - Checkpoint attention every use_checkpoint blocks (if use_checkpoint > 0)
# - Checkpoint MLP every use_checkpoint_mlp blocks (if provided), otherwise every use_checkpoint blocks
checkpoint_attn = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
if self.use_checkpoint_mlp is not None:
checkpoint_mlp = self.training and self.use_checkpoint_mlp > 0 and ((it + 1) % self.use_checkpoint_mlp == 0)
else:
checkpoint_mlp = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
x, y = block(x, y, attn_mask, 1.0, None, freqs_cis, _kv_cache,
checkpoint_attn=checkpoint_attn,
checkpoint_mlp=checkpoint_mlp)
# final norm
if self.use_final_norm:
            x, y = self.final_norm(x), (self.final_norm(y) if y is not None else None)
x = self.get_proj_out(x)
if not self.use_sos: # no SOS token, we need to shift the sequence
x = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
xa, xb = x.chunk(2, dim=-1)
# Store original dtype for output conversion
original_dtype = xa.dtype
# Convert to fp32 for numerical stability
xa, xb, x_in = xa.float(), xb.float(), x_in.float()
if not self.use_softplus:
xa = xa.exp()
else:
xa = F.softplus(xa + INV_SOFTPLUS_1)
if guidance is not None and guidance > 0:
xb, xa = self.guidance(xa, xb, guidance, 1.0, 'ab')
# NOTE: this "scale" is in fact 1/sigma, not sigma
x = self.permutation((x_in - xb) / xa, inverse=True)
logdet = -torch.log(xa) # keep all the dimensions
# Convert back to original precision
x = x.to(original_dtype)
return x, y, logdet
def guidance(self, za, zb, guidance, r=1.0, guide_what='ab'):
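        # Classifier-free guidance on the coupling parameters: the batch is laid out as
        # [conditional, unconditional] halves; their (shift, scale) pairs are combined so
        # the implied density follows (1 + w) * logP_cond - w * logP_uncond.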
za, za_u = [torch.cat([a, a]) for a in za.chunk(2, dim=0)]
zb, zb_u = [torch.cat([a, a]) for a in zb.chunk(2, dim=0)]
g = r * guidance
def logits_guided(mu_c, sigma_c, mu_u, sigma_u, w):
# inspired from: (1+w) * logP_cond - w * logP_uncond
# sigma_c = torch.minimum(sigma_c, sigma_u)
s = (sigma_c / sigma_u).clip(max=1.0).square()
sigma_eff = sigma_c / (1 + w - w * s).sqrt()
mu_eff = ((1 + w) * mu_c - (w * s) * mu_u) / (1 + w - w * s)
return mu_eff, sigma_eff
def original_guidance(mu_c, sigma_c, mu_u, sigma_u, w):
if 'a' in guide_what:
sigma_c = sigma_c + g * (sigma_c - sigma_u)
if 'b' in guide_what:
mu_c = mu_c + g * (mu_c - mu_u)
return mu_c, sigma_c
#zb, za = original_guidance(zb, za, zb_u, za_u, guidance)
zb, za = logits_guided(zb, za, zb_u, za_u, guidance)
return zb, za
def reverse_step(
self, x: torch.Tensor, t: int, kv_cache: KVCache,
pos_embed: torch.Tensor | None = None, y: torch.Tensor | None = None,
attn_temp: float = 1.0, freqs_cis=None
) -> tuple[torch.Tensor, torch.Tensor]:
# Store original dtype for sampling tensor
original_dtype = x.dtype
if self.use_sos: # get i-th patch but keep the sequence dimension
x_in = self.get_sos_embed(x[:, :1]) if t == 0 else x[:, t - 1 : t]
else:
x_in = x[:, t : t + 1]
# Convert to model's dtype for neural network computation
if hasattr(self.proj_in, 'weight'):
target_dtype = self.proj_in.weight.dtype
x_in = x_in.to(target_dtype)
x = self.get_proj_in(x_in)
# if positional embedding
if pos_embed is not None:
x = x + pos_embed[t: t+1]
# main block
for i, block in enumerate(self.attn_blocks):
x, _ = block(x, None, attn_temp=attn_temp, freqs_cis=freqs_cis, kv_cache=partial(kv_cache, i))
# final norm
if self.use_final_norm:
x = self.final_norm(x)
x = self.get_proj_out(x)
xa, xb = x.chunk(2, dim=-1)
# Convert back to original dtype for sampling computations
return xa.to(original_dtype), xb.to(original_dtype)
def reverse_step_condition(self, y, kv_cache, pos_embed=None, attn_temp: float = 1.0, freqs_cis=None):
# Convert to model's dtype for neural network computation
if hasattr(self.proj_txt, 'weight'):
target_dtype = self.proj_txt.weight.dtype
y = y.to(target_dtype)
y = self.proj_txt(y)
for i, block in enumerate(self.attn_blocks):
_, y = block(None, y, attn_temp=attn_temp, freqs_cis=freqs_cis, kv_cache=partial(kv_cache, i))
return y
def reverse(
self,
z: torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
guide_what: str = 'ab',
attn_temp: float = 1.0,
annealed_guidance: bool = False,
rope=None,
verbose=False,
        kv_cache: KVCache = KVCache(),
**unused_kwargs
) -> torch.Tensor:
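        """Invert the coupling sequentially: starting from Gaussian noise `z`, generate one
        token per step via x_t = z_t * scale_t + shift_t, where (scale, shift) are predicted
        from the previously generated tokens using the KV cache."""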
# Ensure sampling tensors are in float32 for numerical stability
original_dtype = z.dtype
z = z.float()
freqs_cis = self.get_freqs_cis(z, y, rope) if rope is not None else None
if guidance > 0:
z = torch.cat([z, z], 0)
# kv cache
reuse_kv_cache = kv_cache.prefix_cache is not None and kv_cache.kv_index[0] > 0
kv_cache = self.initialize_kv_cache(kv_cache, z, freqs_cis, reuse_kv_cache)
# permute the input
z = self.permutation(z)
pos_embed = self.permutation(self.pos_embed, dim=0) if self.pos_embed is not None else None
# run additional text condition, results will be used in KV cache.
if self.txt_dim > 0:
if not reuse_kv_cache:
self.reverse_step_condition(y, kv_cache, pos_embed, attn_temp, freqs_cis)
txt_size = y.size(1) if self.txt_dim > 0 else 0
# run the reverse process
x = z.clone()
if reuse_kv_cache:
x[:, :kv_cache.prefix_cache.size(1)] = kv_cache.prefix_cache # fill the prefix cache
T = x.size(1) - 1 if not self.use_sos else x.size(1)
for t in tqdm.trange(T, disable=not verbose, desc='Sub-flow Sampling', leave=False):
if reuse_kv_cache and kv_cache.kv_index[0] > t + txt_size:
continue
za, zb = self.reverse_step(x, t, kv_cache, pos_embed, y, attn_temp, freqs_cis)
# Ensure sampling computations stay in float32
za, zb = za.float(), zb.float()
if not self.use_softplus:
za, zb = za.exp().squeeze(1), zb.squeeze(1)
else:
za, zb = F.softplus(za + INV_SOFTPLUS_1).squeeze(1), zb.squeeze(1)
if guidance > 0 and guide_what:
r = (t + 1) / T if annealed_guidance else 1.0
zb, za = self.guidance(za, zb, guidance, r, guide_what)
if self.use_sos:
x[:, t] = z[:, t] * za + zb
else:
x[:, t + 1] = z[:, t + 1] * za + zb
if guidance > 0:
x = x.chunk(2, dim=0)[0]
kv_cache.remove_negative_cache() # remove the second half of the cache
x = self.permutation(x, inverse=True)
# Convert back to original dtype if needed
return x.to(original_dtype)
def jacobi(self,
z: torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
rope=None,
kv_cache=None,
verbose=False,
jacobi_block_size: int = 32,
jacobi_max_iter: int = 32,
jacobi_th: float = 0.001,
context_length: int = None,
**unused_kwargs) -> torch.Tensor:
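        """Block-wise Jacobi decoding: instead of strictly sequential sampling, iterate
        x <- shift + scale * z on chunks of `jacobi_block_size` tokens until the relative
        change drops below `jacobi_th` (or `jacobi_max_iter` is reached), rewinding the
        KV cache between iterations."""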
assert self.use_sos, "Jacobi iteration requires SOS token to be used"
assert self.pos_embed is None, "Jacobi iteration does not support positional embedding"
# Ensure sampling tensors are in float32 for numerical stability
original_dtype = z.dtype
z = z.float()
freqs_cis = self.get_freqs_cis(z, y, rope) if rope is not None else None
if guidance > 0:
z = torch.cat([z, z], 0)
# kv cache
reuse_kv_cache = kv_cache.prefix_cache is not None and kv_cache.kv_index[0] > 0
kv_cache = self.initialize_kv_cache(kv_cache, z, freqs_cis, reuse_kv_cache)
video_length = z.size(1) if z.dim() == 5 else 1
# permute the input
z = self.permutation(z)
# prepare input
x_full = torch.cat([self.get_sos_embed(z), z.clone()], dim=1)
if reuse_kv_cache:
x_full[:, 1: kv_cache.prefix_cache.size(1) + 1] = kv_cache.prefix_cache # fill the prefix cache
# conditioning
if self.txt_dim > 0:
if not reuse_kv_cache:
self.reverse_step_condition(y, kv_cache, freqs_cis=freqs_cis)
txt_size = y.size(1) if self.txt_dim > 0 else 0
video_frame_size = z.size(1) // video_length
start_idx = 0
if reuse_kv_cache:
start_idx = kv_cache.kv_index[0] - txt_size # start from the last cached index
prog_bar = tqdm.tqdm(total=z.size(1), disable=not verbose, desc='Block-wise Jacobi Iteration', leave=False)
prog_bar.update(start_idx)
local_attn_window = self.local_attn_window * video_frame_size if self.local_attn_window is not None else None
target_frame_size = z.size(1) if local_attn_window is None else min(z.size(1), local_attn_window)
context_size = None if local_attn_window is None else context_length * video_frame_size
while target_frame_size <= z.size(1):
while start_idx < target_frame_size:
chunk_size = jacobi_block_size if start_idx <= video_frame_size else jacobi_block_size * 4
local_done = torch.zeros((), dtype=torch.bool, device=x_full.device)
for i in tqdm.tqdm(range(jacobi_max_iter), disable=True, desc='Jacobi Iteration', leave=False):
if start_idx + chunk_size >= target_frame_size:
chunk_size = target_frame_size - start_idx
if i == 0 and start_idx > video_frame_size: # optional to use past frame to initialize the current frame
x = x_full[:, start_idx - video_frame_size: start_idx + chunk_size - video_frame_size]
else:
x = x_full[:, start_idx: start_idx + chunk_size]
# main forward - convert to model dtype for neural network computation
if hasattr(self.proj_in, 'weight'):
target_dtype = self.proj_in.weight.dtype
x = x.to(target_dtype)
x = self.get_proj_in(x)
for it, block in enumerate(self.attn_blocks):
_kv_cache = partial(kv_cache, it) if kv_cache is not None else None
x = block(x, None, freqs_cis=freqs_cis, kv_cache=_kv_cache)[0]
if self.use_final_norm:
x = self.final_norm(x)
x = self.get_proj_out(x)
xa, xb = x.chunk(2, dim=-1)
# Convert back to float32 for sampling computations
xa, xb = xa.float(), xb.float()
if not self.use_softplus:
xa = xa.exp()
else:
xa = F.softplus(xa + INV_SOFTPLUS_1)
if guidance > 0:
xb, xa = self.guidance(xa, xb, guidance, 1.0, 'ab')
# compute the Jacobi Iteration - all in float32
new_x = xb + xa * z[:, start_idx: start_idx+chunk_size]
diff = ((new_x - x_full[:, start_idx+1: start_idx+1+chunk_size]) ** 2).mean() / (new_x ** 2).mean()
x_full[:, start_idx+1: start_idx+1+chunk_size] = new_x
if diff < jacobi_th or i == jacobi_max_iter - 1: # do not clean the cache on the last iteration
local_done.fill_(1)
                        global_done = local_done.clone()
                        # only synchronize the early-exit flag across ranks when torch.distributed is initialized
                        if torch.distributed.is_available() and torch.distributed.is_initialized():
                            torch.distributed.all_reduce(global_done, op=torch.distributed.ReduceOp.MIN)
                        if int(global_done.item()) == 1:
break
kv_cache.backward_in_time(chunk_size)
start_idx += chunk_size
prog_bar.update(chunk_size)
if target_frame_size >= z.size(1):
break
target_frame_size += local_attn_window - context_size if local_attn_window is not None else video_frame_size
target_frame_size = min(target_frame_size, z.size(1))
# re-encode the context with attention blocks
print(f're-encoding the context {start_idx+1-context_size}:{start_idx+1}')
kv_cache.reset_kv_index()
if self.txt_dim > 0:
self.reverse_step_condition(y, kv_cache, freqs_cis=freqs_cis)
x_context = x_full[:, start_idx+1-context_size: start_idx+1]
x_context_in, x_context = self.get_prepared(x_context)
x_context = self.get_proj_in(x_context)
for it, block in enumerate(self.attn_blocks):
_kv_cache = partial(kv_cache, it) if kv_cache is not None else None
x_context = block(x_context, None, freqs_cis=freqs_cis, kv_cache=_kv_cache)[0]
x = x_full[:, 1:]
if guidance > 0:
            x = x.chunk(2, dim=0)[0]  # keep only the conditional half (drop the CFG negative batch)
x = self.permutation(x, inverse=True)
# Convert back to original dtype if needed
return x.to(original_dtype)
class IdentityBlock(MetaBlock):
def __init__(self, *args, **kwargs):
super(MetaBlock, self).__init__()
def forward(self, x, y=None, rope=None, **unused):
return x, y, x.new_zeros(x.size(0))
def reverse(self,
z: torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
guide_what: str = 'ab',
attn_temp: float = 1.0,
annealed_guidance: bool = False,
rope=None,
verbose=False,
kv_cache: KVCache=KVCache(), **unused):
# Preserve original dtype
return z
def jacobi(self,
z: torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
rope=None,
kv_cache=None,
verbose=False,
jacobi_block_size: int = 64,
jacobi_th: float = 0.005, **unused_kwargs) -> torch.Tensor:
return z
class NonCausalBlock(MetaBlock):
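    """Bidirectional (non-causal) transformer block used as the learnable self-denoiser:
    it processes all tokens jointly, optionally with a block-causal local window, and
    produces a dense prediction instead of an autoregressive affine coupling."""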
def __init__(
self,
in_channels: int,
channels: int,
img_size: int,
pt_seq_len: int | None = None,
num_layers: int = 8,
head_dim: int = 64,
num_heads: None | int = None,
num_kv_heads: None | int = None,
txt_size: int = 0,
txt_dim: int = 0,
expansion: float = 4,
use_rope: bool = False,
use_swiglu: bool = False,
        use_qk_norm: bool = False,
        use_post_norm: bool = False,
        use_final_norm: bool = False,
        use_bias: bool = True,
        hf_style_rope: bool = False,
        norm_type: str = "layer_norm",
        use_checkpoint: int = 0,
        use_checkpoint_mlp: int | None = None,
        block_causal: int = 0,
        window: int | None = None,
**unused_kwargs,
):
super(MetaBlock, self).__init__()
out_channels = in_channels
self.proj_in = torch.nn.Linear(in_channels, channels)
self.proj_out = torch.nn.Linear(channels, out_channels)
torch.nn.init.constant_(self.proj_out.weight, 0)
self.txt_size = txt_size
self.img_size = img_size
self.txt_dim = txt_dim
self.pt_seq_len = pt_seq_len or img_size
self.block_causal = block_causal
self.window = window
# KV cache configurations
num_kv_heads = num_kv_heads or (num_heads or channels // head_dim)
self.kv_cache_size = [num_kv_heads, head_dim]
if txt_dim > 0:
self.proj_txt = torch.nn.Linear(txt_dim, channels)
self.attn_blocks = torch.nn.ModuleList(
[AttentionBlock(channels, head_dim, expansion, False, use_swiglu, norm_type, num_heads, num_kv_heads,
use_qk_norm, use_post_norm, use_bias, hf_style_rope, non_causal=True) for _ in range(num_layers)])
self.use_final_norm = use_final_norm
if use_final_norm:
self.final_norm = RMSNorm(channels)
self.use_checkpoint = use_checkpoint
self.use_checkpoint_mlp = use_checkpoint_mlp
self.block_masks = {} # for local attention
def get_local_window_mask(self, x, y):
_, T, H, W, _ = x.shape
L = y.size(1) if y is not None else 0
B = H * W
N = T * B
S = L + N
A = self.block_causal
G = self.window if self.window is not None else 10000
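        # Block-causal mask: text tokens are always visible; an image query in frame f may
        # attend to image tokens in frames [f + A - 1 - G, f + A - 1], i.e. a window of at
        # most G + 1 frames ending A - 1 frames ahead of its own.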
def mask(q, k):
return (k < L) | (
((k - L) // B >= (q - L) // B + A - 1 - G) &
((k - L) // B <= torch.relu(q - L) // B + A - 1)
)
return mask(torch.arange(S, device=x.device)[:, None], torch.arange(S, device=x.device)[None, :])
def forward(self, x, y, rope, **unused):
freqs_cis = self.get_freqs_cis(x, y, rope) if rope is not None else None
if self.block_causal > 0 and x.dim() == 5:
attn_mask = self.get_local_window_mask(x, y if self.txt_dim > 0 else None)
else:
attn_mask = None
if x.dim() == 5: # video input
N, H, W, x = x.size(1), x.size(2), x.size(3), rearrange(x, 'b t h w c -> b (t h w) c') # flatten x
else:
N, H, W, x = 0, x.size(1), x.size(2), rearrange(x, 'b h w c -> b (h w) c') # flatten x
x = self.get_proj_in(x)
y = self.proj_txt(y) if self.txt_dim > 0 else None
for it, block in enumerate(self.attn_blocks):
# Frequency-based checkpointing strategy:
# - Checkpoint attention every use_checkpoint blocks (if use_checkpoint > 0)
# - Checkpoint MLP every use_checkpoint_mlp blocks (if provided), otherwise every use_checkpoint blocks
checkpoint_attn = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
if self.use_checkpoint_mlp is not None:
checkpoint_mlp = self.training and self.use_checkpoint_mlp > 0 and ((it + 1) % self.use_checkpoint_mlp == 0)
else:
checkpoint_mlp = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
x, y = block(x, y, attn_mask, 1.0, None, freqs_cis,
checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp)
if self.use_final_norm:
x = self.final_norm(x)
x = self.get_proj_out(x)
if N > 0:
x = rearrange(x, 'b (t h w) d -> b t h w d', t=N, h=H, w=W)
else:
x = rearrange(x, 'b (h w) d -> b h w d', h=H, w=W)
return x
class Model(torch.nn.Module):
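    """STARFlow/TarFlow-style normalizing flow: a stack of MetaBlock couplings with
    alternating sequence permutations, an optionally text-conditioned top block, and an
    optional non-causal self-denoiser. `forward` returns latents and log-determinants for
    maximum-likelihood training; `reverse` runs block-by-block autoregressive sampling."""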
def __init__(
self,
in_channels: int,
img_size: int,
patch_size: int,
channels: int,
num_blocks: int,
layers_per_block: List[int],
head_dim: int = 64,
num_heads: None | int = None,
num_kv_heads: None | int = None,
rope: bool = False,
pt_seq_len: None | int = None,
sos: bool = False,
txt_size: int = 0,
txt_dim: int = 0,
cond_top_only: bool = False,
use_softplus: bool = False,
use_swiglu: bool = False,
use_bias: bool = True,
use_qk_norm: bool = False,
use_post_norm: bool = False,
use_final_norm: bool = False,
hf_style_rope: bool = False,
norm_type: str = "layer_norm",
use_checkpoint: int = 0,
        use_checkpoint_mlp: int | None = None,
use_pretrained_lm: str | None = None,
use_mm_attn: bool = False,
soft_clip: float = 0,
seq_order: str = "R2L",
learnable_self_denoiser: bool = False,
conditional_denoiser: bool = False,
temporal_causal: int = 0,
        top_block_channels: int | None = None,   # If specified, the top block uses a different channel width
        shallow_block_local: bool = False,       # If True, shallow blocks only attend within a single frame
        denoiser_window: int | None = None,      # If specified, use local attention in the denoiser with this window size
        local_attn_window: int | None = None,    # If specified, use local attention in all blocks with this window size
**unused_kwargs,
):
super().__init__()
self.img_size = img_size
self.in_channels = in_channels
self.patch_size = patch_size
self.pt_seq_len = pt_seq_len or img_size // patch_size
self.num_patches = self.pt_seq_len ** 2
self.use_rope = rope
self.use_sos = sos
self.use_softplus = use_softplus
self.cond_top_only = cond_top_only
self.seq_order = seq_order
self.temporal_causal = temporal_causal
self.top_block_channels = top_block_channels or channels
self.shallow_block_local = shallow_block_local
self.expansion_init_std = 0.02
assert (not local_attn_window) or shallow_block_local, 'local_attn_window requires shallow_block_local'
assert (not shallow_block_local) or self.cond_top_only, 'shallow_block_local requires cond_top_only'
assert (not self.cond_top_only) or (txt_size > 0), 'cond_top_only requires txt_size > 0'
assert (seq_order == 'L2R') or (temporal_causal == 0), 'seq_order must be L2R if temporal causal is True'
permutations = [PermutationIdentity(self.num_patches), PermutationFlip(self.num_patches)] if temporal_causal == 0 else \
[PermutationIdentity(self.num_patches), PermutationFlipInBlock(self.num_patches)]
blocks = []
if len(layers_per_block) == 1:
layers_per_block = [layers_per_block[0]] * num_blocks
base_kwargs = dict(
in_channels=in_channels * patch_size**2,
channels=channels,
img_size=img_size // patch_size,
pt_seq_len=self.pt_seq_len,
txt_size=txt_size,
use_rope=self.use_rope, hf_style_rope=hf_style_rope, use_sos=self.use_sos,
use_softplus=self.use_softplus,
use_swiglu=use_swiglu, use_qk_norm=use_qk_norm,
use_post_norm=use_post_norm, use_final_norm=use_final_norm,
use_bias=use_bias, norm_type=norm_type, num_heads=num_heads,
num_kv_heads=num_kv_heads, head_dim=head_dim,
use_checkpoint=use_checkpoint,
use_checkpoint_mlp=use_checkpoint_mlp,
soft_clip=soft_clip,
)
# bottom blocks
for i in range(num_blocks-1):
permutation = permutations[i % 2] if seq_order == 'R2L' else permutations[(i+1) % 2]
Block = IdentityBlock if layers_per_block[i] == 0 else MetaBlock
blocks.append(Block(permutation=permutation, num_layers=layers_per_block[i], txt_dim=0 if cond_top_only else txt_dim, **base_kwargs))
# top block
gen_kwargs = copy.deepcopy(base_kwargs)
if self.top_block_channels != channels:
gen_kwargs['channels'] = self.top_block_channels
if num_heads is None:
gen_kwargs['num_heads'] = self.top_block_channels // head_dim
if use_pretrained_lm is not None:
gen_kwargs.update(eval(f"{use_pretrained_lm}_kwargs"))
if use_mm_attn:
gen_kwargs.update({"use_mm_attn": True}) # only top block will receive this
else:
gen_kwargs.update({"num_layers": layers_per_block[-1]})
permutation = permutations[(num_blocks-1) % 2] if seq_order == 'R2L' else permutations[(num_blocks) % 2]
top_block = MetaBlock(permutation=permutation, txt_dim=txt_dim, local_attn_window=local_attn_window, **gen_kwargs)
blocks.append(top_block)
# put together
self.blocks = torch.nn.ModuleList(blocks)
# Self-denoiser
if learnable_self_denoiser:
self.learnable_self_denoiser = NonCausalBlock(
num_layers=8, block_causal=temporal_causal, window=denoiser_window,
txt_dim=0 if not conditional_denoiser else txt_dim,
**base_kwargs)
# setup rotary embeddings
if self.use_rope:
self.feat_rope = VisionRotaryEmbeddingFast(
dim=base_kwargs['head_dim'] // 2, pt_seq_len=base_kwargs['pt_seq_len'], latent_len=txt_size)
if use_pretrained_lm is not None: # using standard 1D RoPE
self.feat_rope_gen = VisionRotaryEmbeddingFast(
dim=gen_kwargs['head_dim'] // 2, pt_seq_len=gen_kwargs['pt_seq_len'], no_buffer=True, is_1d=True)
else:
self.feat_rope_gen = VisionRotaryEmbeddingFast(
dim=gen_kwargs['head_dim'] // 2, pt_seq_len=gen_kwargs['pt_seq_len'], latent_len=txt_size, no_buffer=True)
else:
self.feat_rope = self.feat_rope_gen = None
# ----- DEPRECATED: not useful -------
self.register_buffer('var', torch.ones(self.num_patches, in_channels * patch_size**2))
def patchify(self, x: List[torch.Tensor] | torch.Tensor, p: int | None = None) -> torch.Tensor:
"""Convert an image (N,C',H,W) to a sequence of patches (N,T,C')"""
if len(x.shape) < 4:
return x # no need patchify
H, W = x.shape[-2], x.shape[-1]
p = self.patch_size * p if p is not None else self.patch_size
assert H % p == 0 and W % p == 0, "H and W must be divisible by patch_size"
x = rearrange(x, '... c (h p1) (w p2) -> ... h w (p1 p2 c)', p1=p, p2=p)
return x
def unpatchify(self, x: List[torch.Tensor] | torch.Tensor, p: int | None = None) -> torch.Tensor:
"""Convert a sequence of patches (N,T,C) to an image (N,C',H,W)"""
if len(x.shape) < 4:
return x # no need unpatchify
p = self.patch_size * p if p is not None else self.patch_size
H, W = x.shape[-3], x.shape[-2]
return rearrange(x, '... h w (p1 p2 c) -> ... c (h p1) (w p2)', h=H, w=W, p1=p, p2=p)
def get_loss(self,
z: torch.Tensor | List[torch.Tensor],
logdets: torch.Tensor | List[torch.Tensor],
weights: torch.Tensor | None = None,
drop_first=False) -> dict[str, torch.Tensor]:
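        """Negative log-likelihood under a standard Gaussian base distribution:
        0.5 * ||z||^2 averaged per dimension, minus the mean log-determinant summed over blocks."""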
if drop_first:
z, logdets = z[:, 1:], [logdet[:, 1:] for logdet in logdets]
loss_z = 0.5 * z.pow(2).mean(dim=tuple(range(1, z.dim())))
loss_logdet = -sum([logdet.mean(dim=tuple(range(1, logdet.dim()))) for logdet in logdets])
loss = loss_z + loss_logdet
if weights is not None:
loss = loss * weights
loss = loss.mean()
return {'loss': loss, 'loss_z': loss_z.detach().mean(), 'loss_logdet': loss_logdet.detach().mean()}
def forward(
self, x: torch.Tensor, y: torch.Tensor | None = None,
reverse=False, kv_caches=None, denoiser=False, context=False, **kwargs
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
if context:
return self.forward_context(x, y, kv_caches=kv_caches, **kwargs)
if reverse: # inference mode
return self.reverse(x, y, kv_caches=kv_caches, **kwargs)
if denoiser: # forward with self-denoiser
x = self.patchify(x)
x = self.learnable_self_denoiser(x, y, self.feat_rope, **kwargs)
return self.unpatchify(x)
logdets, outputs = [], []
guidance = kwargs.get('guidance', 0)
# Bottom blocks
x = self.patchify(x)
outputs += [x]
for it, block in enumerate(self.blocks[:-1]):
if self.shallow_block_local and x.dim() == 5: # video input
x = rearrange(x, 'b t h w c -> (b t) 1 h w c')
x, _, logdet = block(x, y.chunk(2, dim=0)[0] if self.cond_top_only and guidance > 0 else y,
self.feat_rope, kv_cache=kv_caches[-(it+1)] if kv_caches is not None else None)
if self.shallow_block_local and x.dim() == 5: # video input
x = rearrange(x, '(b t) 1 h w c -> b t h w c', b=outputs[0].size(0), t=outputs[0].size(1))
logdet = rearrange(logdet, '(b t) l c -> b t l c', b=outputs[0].size(0), t=outputs[0].size(1))
logdets += [logdet]
outputs += x if isinstance(x, list) else [x]
# Top block
x, y, logdet = self.blocks[-1](x, y, self.feat_rope_gen,
kv_cache=kv_caches[0] if kv_caches is not None else None,
guidance=guidance)
outputs += [x]
x = self.unpatchify(x)
logdets += [logdet]
return x, y, outputs, logdets
def forward_context(self, x: torch.Tensor, y: torch.Tensor | None = None, kv_caches: List[KVCache] | None = None, **kwargs):
if kv_caches is None:
kv_caches = [KVCache() for _ in range(len(self.blocks))]
use_cfg = (x.size(0) * 2 == y.size(0)) if (y is not None and self.cond_top_only) else False
if use_cfg:
x = torch.cat([x, x], 0) # duplicate for classifier-free guidance generation
self.forward(x, y, kv_caches=kv_caches, **kwargs) # run once to fill the cache
if use_cfg:
for kv in kv_caches[1:]:
kv.remove_negative_cache() # remove negative cache except for the first block
kv.prefix_cache = kv.prefix_cache.chunk(2, dim=0)[0] if kv.prefix_cache is not None else None
return kv_caches
def reverse_deep(self,
x: List[torch.Tensor] | torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
verbose: bool = False,
kv_caches: List[KVCache] | None = None,
jacobi: bool = False,
need_caches: bool = False,
seq: List[torch.Tensor] = [],
**sampling_kwargs,):
x = self.patchify(x)
x = (self.blocks[-1].jacobi if jacobi else self.blocks[-1].reverse)(
x, y, guidance, rope=self.feat_rope_gen, kv_cache=kv_caches[0], verbose=verbose, **sampling_kwargs)
x = self.unpatchify(x)
if not need_caches:
kv_caches[0].delete()
seq.append(x)
return x
def reverse_shallow(self,
x: List[torch.Tensor] | torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
verbose: bool = False,
kv_caches: List[KVCache] | None = None,
jacobi: bool = False,
need_caches: bool = False,
seq: List[torch.Tensor] = [],
**sampling_kwargs,):
x = self.patchify(x)
for it, block in enumerate(reversed(self.blocks[:-1])):
if self.shallow_block_local and x.dim() == 5: # video input
x = rearrange(x, 'b t h w c -> (b t) 1 h w c')
kv_caches[it+1]._is_empty = True
kv_caches[it+1].prefix_cache = None
x = (block.jacobi if jacobi else block.reverse)(
x, y, guidance, rope=self.feat_rope, kv_cache=kv_caches[it+1], verbose=verbose, **sampling_kwargs)
if self.shallow_block_local and x.dim() == 5: # video input
x = rearrange(x, '(b t) 1 h w c -> b t h w c', b=seq[0].size(0), t=seq[0].size(1))
seq.append(self.unpatchify(x))
if not need_caches:
kv_caches[it+1].delete()
x = self.unpatchify(x)
return x
def reverse(
self,
x: List[torch.Tensor] | torch.Tensor,
y: torch.Tensor | None = None,
guidance: float = 0,
guide_top: int | None = None,
return_sequence: bool = False,
verbose: bool = False,
kv_caches: List[KVCache] | None = None,
jacobi: bool = False,
**sampling_kwargs,
) -> torch.Tensor | list[torch.Tensor]:
seq, need_caches, kv_caches = [x], (kv_caches is not None), kv_caches or [KVCache() for _ in range(len(self.blocks))]
# run the deep block first
x = self.reverse_deep(x, y, guidance, verbose, kv_caches, jacobi, need_caches, seq, **sampling_kwargs)
# remove guidance if bottom is unconditional
if (guide_top is not None or self.cond_top_only) and guidance > 0:
guidance, y = 0, y.chunk(2, dim=0)[0]
# run the shallow blocks
x = self.reverse_shallow(x, y, guidance, verbose, kv_caches, jacobi, need_caches, seq, **sampling_kwargs)
return seq if return_sequence else x
#################################################################################
# TarFlow Configs #
#################################################################################
def TarFlow_XL_1(**kwargs):
return Model(num_blocks=6, layers_per_block=[2,2,2,2,10,10],
channels=2048, patch_size=1, head_dim=64, rope=1, **kwargs)
def TarFlow_XL_2(**kwargs):
return Model(num_blocks=6, layers_per_block=[2,2,2,2,10,10],
channels=2048, patch_size=2, head_dim=64, rope=1, **kwargs)
def TarFlow_XXL_1(**kwargs):
return Model(num_blocks=6, layers_per_block=[2,2,2,2,13,13],
channels=3072, patch_size=1, head_dim=64, rope=1, **kwargs)
def TarFlow_XLv2_1(**kwargs): # 1.4B
return Model(num_blocks=6, layers_per_block=[2,2,2,2,2,18],
channels=2048, patch_size=1, head_dim=64, rope=1, **kwargs)
def TarFlow_XXLv2_1(**kwargs): # 4B
return Model(num_blocks=6, layers_per_block=[2,2,2,2,2,24],
channels=3072, patch_size=1, head_dim=64, rope=1, **kwargs)
def TarFlow_Gemma2B(**kwargs): # 2B
return Model(num_blocks=6, layers_per_block=[2,2,2,2,2,26],
channels=2304, patch_size=1, rope=1,
use_rope=True, hf_style_rope=True, use_adaln=False,
use_swiglu=True, use_qk_norm=False, use_post_norm=True,
use_final_norm=True, use_bias=False, norm_type="rms_norm",
num_heads=8, num_kv_heads=4, head_dim=256, **kwargs)
# Pre-trained model configs
pre_model_configs = {
"TarFlow_XL_1": TarFlow_XL_1,
"TarFlow_XLv2_1": TarFlow_XLv2_1,
"TarFlow_XL_2": TarFlow_XL_2,
"TarFlow_XXL_1": TarFlow_XXL_1,
"TarFlow_XXLv2_1": TarFlow_XXLv2_1,
}
#################################################################################
# Pretrained LLMs #
#################################################################################
gemma3_4b_kwargs = dict(
use_rope=True, hf_style_rope=True, use_adaln=False,
use_swiglu=True, use_qk_norm=True, use_post_norm=True,
use_final_norm=True, use_bias=False, norm_type="rms_norm",
num_heads=8, num_kv_heads=4, head_dim=256, channels=2560,
num_layers=34, use_proj_txt=False)
gemma3_1b_kwargs = dict(
use_rope=True, hf_style_rope=True, use_adaln=False,
use_swiglu=True, use_qk_norm=True, use_post_norm=True,
use_final_norm=True, use_bias=False, norm_type="rms_norm",
num_heads=4, num_kv_heads=1, head_dim=256, channels=1152, expansion=6,
num_layers=26, use_proj_txt=False)
gemma2_2b_kwargs = dict(
use_rope=True, hf_style_rope=True, use_adaln=False,
use_swiglu=True, use_qk_norm=False, use_post_norm=True,
use_final_norm=True, use_bias=False, norm_type="rms_norm",
num_heads=8, num_kv_heads=4, head_dim=256, channels=2304,
num_layers=26, use_proj_txt=False)
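# ------------------------------------------------------------------------------- #
# Usage sketch (illustrative only; the argument values and the exact training call
# below are assumptions, not part of the original training/inference scripts):
#
#   model = TarFlow_XL_1(in_channels=16, img_size=32, txt_size=0, txt_dim=0, sos=True)
#   x = torch.randn(2, 16, 32, 32)                       # latent images (B, C, H, W)
#   z, _, outputs, logdets = model(x)                    # forward pass through all blocks
#   loss = model.get_loss(outputs[-1], logdets)['loss']  # maximum-likelihood objective
#   samples = model(torch.randn_like(x), reverse=True)   # autoregressive sampling
# ------------------------------------------------------------------------------- #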