#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Model setup utilities for STARFlow.
Includes: transformer setup, VAE setup, text encoders.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import pathlib
import os
import numpy as np
from collections import OrderedDict
from typing import Optional, Tuple, Union
from einops import rearrange
from transformer_flow import pre_model_configs, Model
from diffusers.models import AutoencoderKL, AutoencoderKLWan
from diffusers import DiTPipeline
from misc.wan_vae2 import video_vae2 as AutoencoderKLWan2
from transformers import AutoTokenizer, AutoModel, AutoConfig, T5Tokenizer, T5EncoderModel
# ==== Model Setup Functions ====
def setup_transformer(args, dist, **other_kwargs):
"""Setup transformer model with given arguments."""
common_kwargs = dict(
in_channels=args.channel_size,
img_size=args.img_size,
txt_size=args.txt_size,
sos=args.sos, # sos_token
cond_top_only=args.cond_top_only,
use_softplus=args.use_softplus,
use_pretrained_lm=args.use_pretrained_lm,
use_mm_attn=args.use_mm_attn,
use_final_norm=args.use_final_norm,
soft_clip=args.soft_clip,
seq_order=args.seq_order,
learnable_self_denoiser=args.learnable_self_denoiser,
conditional_denoiser=args.conditional_denoiser,
noise_embed_denoiser=args.noise_embed_denoiser,
temporal_causal=args.temporal_causal,
shallow_block_local=args.shallow_block_local,
denoiser_window=args.denoiser_window,
local_attn_window=args.local_attn_window,
top_block_channels=getattr(args, 'top_block_channels', None),
)
common_kwargs.update(other_kwargs)
if getattr(args, "model_type", None) is not None:
model = pre_model_configs[args.model_type](**common_kwargs)
else:
# generic model initialization
model = Model(
patch_size=args.patch_size,
channels=args.channels,
num_blocks=args.blocks if len(args.layers_per_block) == 1 else len(args.layers_per_block),
layers_per_block=args.layers_per_block,
rope=args.rope,
pt_seq_len=args.pt_seq_len,
head_dim=args.head_dim,
num_heads=args.num_heads,
num_kv_heads=args.num_kv_heads,
use_swiglu=args.use_swiglu,
use_bias=args.use_bias,
use_qk_norm=args.use_qk_norm,
use_post_norm=args.use_post_norm,
norm_type=args.norm_type,
**common_kwargs)
if args.use_pretrained_lm: # Note: pretrained model download removed
model_name = args.use_pretrained_lm
assert model_name in ['gemma3_4b', 'gemma2_2b', 'gemma3_1b'], f'{model_name} not supported'
# Note: Pretrained LM weights are no longer automatically downloaded
# Users should provide their own pretrained weights if needed
local_path = pathlib.Path(args.logdir) / model_name / 'gemma_meta_block.pth'
if local_path.exists():
model.blocks[-1].load_state_dict(torch.load(local_path, map_location='cpu'), strict=False)
print(f'Load top block with pretrained LLM weights from {model_name}')
else:
print(f"Warning: Pretrained LM weights for {model_name} not found at {local_path}")
print("Please provide pretrained weights manually or disable use_pretrained_lm")
return model
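
# Example (illustrative sketch, not from the original repo): driving `setup_transformer`
# with an argparse-style namespace. Every value below is a placeholder assumption;
# real settings come from the STARFlow training configs, and `model_type` must name
# an entry of `pre_model_configs`.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       channel_size=16, img_size=32, txt_size=128, sos=0, cond_top_only=False,
#       use_softplus=False, use_pretrained_lm=None, use_mm_attn=False,
#       use_final_norm=True, soft_clip=False, seq_order='row', temporal_causal=False,
#       learnable_self_denoiser=False, conditional_denoiser=False, noise_embed_denoiser=False,
#       shallow_block_local=False, denoiser_window=None, local_attn_window=None,
#       logdir='/tmp/logs', model_type='<preset-name>',
#   )
#   model = setup_transformer(args, dist=None)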
class VAE(nn.Module):
def __init__(self, model_name, dist, adapter=None):
super().__init__()
self.model_name = model_name
self.video_vae = False
self.dist = dist
        # "name:extra" splits off an optional extra argument (used as a temporal scale for Wan2.2)
        model_name, extra = model_name.split(':') if ':' in model_name else (model_name, None)
if 'Wan-AI/Wan2.1' in model_name:
self.vae = AutoencoderKLWan.from_pretrained(model_name, subfolder="vae", torch_dtype=torch.bfloat16)
self.latents_std = self.vae.config.latents_std
self.latents_mean = self.vae.config.latents_mean
self.downsample_factor = 2 ** (len(self.vae.config.dim_mult) - 1)
self.temporal_downsample_factor = 2 ** sum(self.vae.config.temperal_downsample)
self.video_vae = True # this is a Video VAE
elif 'Wan-AI/Wan2.2' in model_name:
            # Wan2.2 has no diffusers integration, so fetch the raw VAE weights to a local temp path.
            filename = "/tmp/Wan2.2_VAE.pth"
            if not os.path.exists(filename):
                if dist.local_rank == 0:
                    print("Downloading Wan2.2 VAE weights...")
                    os.system(f"wget https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B/resolve/main/Wan2.2_VAE.pth -O {filename}")
                dist.barrier()  # remaining ranks wait here until local rank 0 finishes the download
self.vae = AutoencoderKLWan2(pretrained_path=filename)
self.downsample_factor = 16
self.video_vae = True
self.latents_std = self.vae.std
self.latents_mean = self.vae.mean
self.temporal_downsample_factor = 4
            self.temporal_scale = float(extra) if extra is not None else 1.0
else:
if 'sd-vae' in model_name or 'sdxl-vae' in model_name:
self.vae = AutoencoderKL.from_pretrained(model_name)
self.scaling_factor = self.vae.config.scaling_factor
else:
self.vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=torch.bfloat16)
self.scaling_factor = self.vae.config.scaling_factor
self.downsample_factor = 2 ** (len(self.vae.config.down_block_types) - 1)
self.temporal_downsample_factor = 1 # this is an Image VAE, no temporal downsample
        self.use_adapter = adapter is not None
        if self.use_adapter:  # the adapter is a pretrained DiT pipeline used to denoise latents before decoding
self.dit_pipe = DiTPipeline.from_pretrained(adapter, torch_dtype=torch.bfloat16)
def to(self, device):
if self.use_adapter:
self.dit_pipe.to(device)
return super().to(device)
def _encode(self, x):
return self.vae.encode(x)
def _decode(self, z):
return self.vae.decode(z)
def encode(self, x):
if self.video_vae: # video VAE
if 'Wan-AI/Wan2.2' in self.model_name:
if x.dim() == 5:
z = rearrange(self.vae.sample(rearrange(x, 'b t c h w -> b c t h w'), self.vae.scale), 'b c t h w -> b t c h w')
if self.temporal_scale != 1:
z[:, 1:] = z[:, 1:] * self.temporal_scale # scale the temporal latent
else:
z = rearrange(self.vae.sample(rearrange(x, 'b c h w -> b c 1 h w'), self.vae.scale), 'b c 1 h w -> b c h w')
else:
if x.dim() == 5:
z = rearrange(self._encode(rearrange(x, 'b t c h w -> b c t h w')).latent_dist.sample(), 'b c t h w -> b t c h w')
else:
z = rearrange(self._encode(rearrange(x, 'b c h w -> b c 1 h w')).latent_dist.sample(), 'b c 1 h w -> b c h w')
                # Wan2.1 latents are normalized channel-wise with dataset statistics.
                shape = [1, 1, -1, 1, 1] if z.dim() == 5 else [1, -1, 1, 1]
                scale = torch.tensor(self.latents_std, device=x.device).view(*shape)
                shift = torch.tensor(self.latents_mean, device=x.device).view(*shape)
                z = (z - shift) / scale
else: # image VAE
if x.dim() == 5:
z = rearrange(self._encode(rearrange(x, 'b t c h w -> (b t) c h w')).latent_dist.sample(), '(b t) c h w -> b t c h w', t=x.shape[1])
else:
z = self._encode(x).latent_dist.sample()
z = z * self.scaling_factor
return z
def decode(self, z, total_steps=100, noise_std=0.3):
if self.use_adapter:
z = self.adapter_denoise(z, total_steps, noise_std)
if self.video_vae: # video VAE
if 'Wan-AI/Wan2.2' in self.model_name:
if z.dim() == 5:
if self.temporal_scale != 1:
z = z.clone()
z[:, 1:] = z[:, 1:] / self.temporal_scale
x = rearrange(self.vae.decode(rearrange(z, 'b t c h w -> b c t h w'), self.vae.scale), 'b c t h w -> b t c h w')
else:
x = rearrange(self.vae.decode(rearrange(z, 'b c h w -> b c 1 h w'), self.vae.scale), 'b c 1 h w -> b c h w')
else:
shape = [1, 1, -1, 1, 1] if z.dim() == 5 else [1, -1, 1, 1]
scale = torch.tensor(self.latents_std, device=z.device).view(*shape)
shift = torch.tensor(self.latents_mean, device=z.device).view(*shape)
z = z * scale + shift
if z.dim() == 5:
x = rearrange(self._decode(rearrange(z, 'b t c h w -> b c t h w')).sample, 'b c t h w -> b t c h w')
else:
x = rearrange(self._decode(rearrange(z, 'b c h w -> b c 1 h w')).sample, 'b c 1 h w -> b c h w')
else:
z = z / self.scaling_factor
if z.dim() == 5: # (b, t, c, h, w)
x = rearrange(self._decode(rearrange(z, 'b t c h w -> (b t) c h w')).sample, '(b t) c h w -> b t c h w', t=z.shape[1])
else:
x = self._decode(z).sample
return x
@torch.no_grad()
def adapter_denoise(self, z, total_steps=100, noise_std=0.3):
self.dit_pipe.scheduler.set_timesteps(total_steps)
timesteps = self.dit_pipe.scheduler.timesteps
        one = torch.ones(z.shape[0], device=z.device)
        # Find the diffusion step whose signal level matches the assumed latent noise level.
        target_alpha2 = 1 / (1 + noise_std ** 2)
        target_t = (torch.abs(self.dit_pipe.scheduler.alphas_cumprod - target_alpha2)).argmin().item()
        z = z * np.sqrt(target_alpha2)  # rescale latents to the signal level expected at the target step
for it in range(len(timesteps)):
if timesteps[it] > target_t: continue
noise_pred = self.dit_pipe.transformer(z, one * timesteps[it], class_labels=one.long() * 1000).sample
model_output = torch.split(noise_pred, self.dit_pipe.transformer.config.in_channels, dim=1)[0]
z = self.dit_pipe.scheduler.step(model_output, timesteps[it], z).prev_sample
return z
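
# Example (sketch): round-tripping an image batch through the VAE wrapper, using a
# public SD VAE checkpoint as an illustration. Shapes follow the (b, c, h, w) image
# convention; video inputs use (b, t, c, h, w) as handled in encode()/decode().
#
#   vae = VAE('stabilityai/sd-vae-ft-ema', dist=None).to('cuda')
#   x = torch.randn(2, 3, 256, 256, device='cuda')   # images scaled to [-1, 1]
#   z = vae.encode(x)                                # latents, here (2, 4, 32, 32)
#   x_rec = vae.decode(z)                            # reconstruction, (2, 3, 256, 256)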
def setup_vae(args, dist, device='cuda'):
"""Setup VAE model with given arguments."""
print(f'Loading VAE {args.vae}...')
# setup VAE
vae = VAE(args.vae, dist=dist, adapter=getattr(args, "vae_adapter", None)).to(device)
# (optional) load pretrained VAE
if getattr(args, "finetuned_vae", None) is not None and args.finetuned_vae != 'none':
vae_task_id = args.finetuned_vae
        local_folder = pathlib.Path(args.logdir) / 'vae'
local_folder.mkdir(parents=True, exist_ok=True)
# Try to load from local path first
if vae_task_id == "px82zaheuu":
local_path = local_folder / "pytorch_model.bin"
if local_path.exists():
finetuned_vae_state = torch.load(local_path, map_location="cpu", weights_only=False)
renamed_state = OrderedDict()
for key in finetuned_vae_state:
new_key = key.replace("encoder.0", "encoder").replace("encoder.1", "quant_conv").replace("decoder.0", "post_quant_conv").replace("decoder.1", "decoder")
renamed_state[new_key] = finetuned_vae_state[key]
vae.vae.load_state_dict(renamed_state)
print(f'Loaded finetuned VAE {vae_task_id}')
else:
print(f"Warning: Finetuned VAE weights for {vae_task_id} not found at {local_path}")
print("Please provide finetuned VAE weights manually or set finetuned_vae to 'none'")
else:
# Try to load general task weights
local_path = local_folder / f"{vae_task_id}.pth"
if local_path.exists():
vae.load_state_dict(torch.load(local_path, map_location='cpu', weights_only=False))
print(f'Loaded finetuned VAE {vae_task_id}')
else:
print(f"Warning: Finetuned VAE weights for {vae_task_id} not found at {local_path}")
print("Please provide finetuned VAE weights manually or set finetuned_vae to 'none'")
return vae
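
# Example (sketch): a minimal namespace for `setup_vae`. The checkpoint name is
# illustrative; `logdir` is only consulted when `finetuned_vae` names a task id.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(vae='stabilityai/sd-vae-ft-ema', finetuned_vae='none',
#                          logdir=pathlib.Path('/tmp/logs'))
#   vae = setup_vae(args, dist=None, device='cuda')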
# ==== Text Encoder Classes and Setup ====
class LookupTableTokenizer:
"""Simple lookup table tokenizer for label-based datasets."""
def __init__(self, vocab_file):
from .common import read_tsv
self.vocab = {l[0]: i for i, l in enumerate(read_tsv(f'configs/dataset/{vocab_file}'))}
self.empty_id = len(self.vocab)
def __len__(self):
return len(self.vocab)
def __call__(self, text):
return {'input_ids': torch.tensor([[self.vocab.get(t, self.empty_id)] for t in text], dtype=torch.long)}
class LabelEmbedder(nn.Module):
"""Simple label embedder for classification-style conditioning."""
def __init__(self, num_classes):
super().__init__()
self.num_classes = num_classes
        # Mimic a HF config object so downstream code can read `.config.hidden_size`.
        self.config = type('Config', (), {'hidden_size': num_classes + 1})()
        # Frozen identity matrix: labels are embedded as one-hot vectors.
        self.Embedding = nn.Parameter(torch.eye(num_classes + 1), requires_grad=False)
def forward(self, y):
return F.embedding(y, self.Embedding)
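
# Example (sketch): label conditioning maps caption strings to class ids through the
# vocab TSV, then to frozen one-hot embeddings. The vocab filename is a hypothetical
# placeholder for a file under configs/dataset/.
#
#   tokenizer = LookupTableTokenizer('labels.vocab')
#   ids = tokenizer(['tench', 'goldfish'])['input_ids']   # shape (2, 1)
#   embedder = LabelEmbedder(len(tokenizer))
#   onehot = embedder(ids)                                # shape (2, 1, num_classes + 1)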
class TextEmbedder(nn.Module):
"""Text embedder for large language models like Gemma."""
def __init__(self, config):
super().__init__()
        if hasattr(config, "text_config"):  # Gemma3 multimodal config wraps a text_config
            self.config = config.text_config
            # Keep only text-token ids; ids from image_token_index upward are multimodal special tokens.
            self.vocab_size = config.image_token_index
else:
self.config = config
self.vocab_size = config.vocab_size
self.text_token_embedder = nn.Embedding(
self.vocab_size, self.config.hidden_size)
self.text_token_embedder.weight.requires_grad = False
self.normalizer = float(self.config.hidden_size) ** 0.5
def forward(self, x):
x = self.text_token_embedder(x)
return (x * self.normalizer).to(x.dtype)
@torch.no_grad()
def sample(
self,
hidden_states: torch.Tensor,
temperatures: Union[float, None] = 1.0,
top_ps: float = 0.95,
top_ks: int = 64,
embedding_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = hidden_states.device
batch_size = hidden_states.shape[0]
temperatures = None if not temperatures else torch.FloatTensor(
[temperatures] * batch_size).to(device)
top_ps = torch.FloatTensor([top_ps] * batch_size).to(device)
top_ks = torch.LongTensor([top_ks] * batch_size).to(device)
# Select the last element for each sequence.
hidden_states = hidden_states[:, -1]
embedding = self.text_token_embedder.weight
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
if hasattr(self.config, 'final_logit_softcapping') and self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits
        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))
        # Apply top-p (nucleus) and top-k filtering on the sorted probabilities.
        probs = F.softmax(logits, dim=-1)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        top_ps_mask = (torch.cumsum(probs_sort, dim=-1) - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0.0, probs_sort)
        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device).expand(probs_idx.shape[0], -1)
        probs_sort = torch.where(top_ks_mask >= top_ks.unsqueeze(dim=1), 0.0, probs_sort)
        probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
        # Restore vocabulary order, then sample.
        probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        return next_tokens, logits
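
# Example (sketch): greedy vs. stochastic decoding with `sample`. In actual use the
# hidden states come from the flow model's top block; random tensors are used here
# only to show shapes, and the Gemma checkpoint just supplies a config.
#
#   embedder = TextEmbedder(AutoConfig.from_pretrained('google/gemma-2-2b-it'))
#   hidden = torch.randn(2, 10, embedder.config.hidden_size)
#   greedy_ids, _ = embedder.sample(hidden, temperatures=None)   # argmax decoding
#   sampled_ids, _ = embedder.sample(hidden, temperatures=0.9)   # filtered sampling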
def setup_encoder(args, dist, device='cuda'):
"""Setup text encoder based on arguments."""
assert args.txt_size > 0, 'txt_size must be set'
print(f'Loading text encoder {args.text}...')
if args.text.endswith('.vocab'): # caption -> label
tokenizer = LookupTableTokenizer(args.text)
        text_encoder = LabelEmbedder(len(tokenizer)).to(device)
block_name = 'Embedding'
elif args.text == 't5xxl':
tokenizer = T5Tokenizer.from_pretrained("THUDM/CogView3-Plus-3B", subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained("THUDM/CogView3-Plus-3B",
subfolder="text_encoder", torch_dtype=torch.bfloat16).to(device)
block_name = 'T5Block'
elif args.text == 't5xl' or args.text.startswith('google'):
tokenizer = AutoTokenizer.from_pretrained(args.text)
text_encoder = AutoModel.from_pretrained(args.text, add_cross_attention=False).encoder.to(device)
block_name = 'T5Block'
elif args.text == "gemma" or args.text.startswith("Alpha-VLLM"):
tokenizer = AutoTokenizer.from_pretrained(args.text, subfolder="tokenizer")
text_encoder = AutoModel.from_pretrained(args.text, subfolder="text_encoder", torch_dtype=torch.bfloat16).to(device)
block_name = 'GemmaDecoderLayer'
elif args.text in ["gemma3_4b", "gemma3_1b", "gemma2_2b"]: # NOTE: special text embedder
model_name = args.text
repo_name = {"gemma3_4b": "google/gemma-3-4b-it",
"gemma3_1b": "google/gemma-3-1b-it",
"gemma2_2b": "google/gemma-2-2b-it"}[model_name]
tokenizer = AutoTokenizer.from_pretrained(repo_name)
config = AutoConfig.from_pretrained(repo_name)
text_encoder = TextEmbedder(config).to(device)
block_name = "Embedding"
# Try to load embedding layer
local_path = pathlib.Path(args.logdir) / model_name
local_path.mkdir(parents=True, exist_ok=True)
local_path = local_path / 'gemma_text_embed.pth'
if local_path.exists():
text_encoder.load_state_dict(torch.load(local_path, map_location='cpu'))
print(f'Loaded text encoder weights for {model_name}')
else:
print(f"Warning: Text encoder weights for {model_name} not found at {local_path}")
print("Please provide text encoder weights manually or use a different text encoder")
else:
raise NotImplementedError(f'Unknown text encoder {args.text}')
text_encoder.base_block_name = block_name
return tokenizer, text_encoder
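
# Example (sketch): encoding captions with the t5xxl branch. `txt_size` here is an
# illustrative max token length; the returned encoder produces the text embeddings
# consumed by the flow model.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(text='t5xxl', txt_size=128)
#   tokenizer, text_encoder = setup_encoder(args, dist=None, device='cuda')
#   tokens = tokenizer(['a photo of a cat'], padding='max_length',
#                      max_length=args.txt_size, truncation=True,
#                      return_tensors='pt').to('cuda')
#   with torch.no_grad():
#       emb = text_encoder(tokens.input_ids).last_hidden_state   # (1, 128, hidden_size)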