Instructions to use Overworld/Waypoint-1-Small with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Overworld/Waypoint-1-Small with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Overworld/Waypoint-1-Small", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright (C) 2025 Hugging Face Team and Overworld | |
| # | |
| # This program is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU General Public License as published by | |
| # the Free Software Foundation, either version 3 of the License, or | |
| # (at your option) any later version. | |
| # | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU General Public License for more details. | |
| # | |
| # You should have received a copy of the GNU General Public License | |
| # along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| """WorldModel transformer for frame generation.""" | |
| from typing import Optional, List | |
| import math | |
| import einops as eo | |
| import torch | |
| from torch import nn, Tensor | |
| import torch.nn.functional as F | |
| from tensordict import TensorDict | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from .attn import Attn, MergedQKVAttn, CrossAttention | |
| from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm | |
| from .quantize import quantize_model | |
| from .cache import CachedDenoiseStepEmb, CachedCondHead | |
def patch_cached_noise_conditioning(model) -> None:
    """Swap the noise-conditioning modules of ``model`` for cached wrappers.

    ``model.denoise_step_emb`` is replaced by a ``CachedDenoiseStepEmb`` built
    from the current embedding module and ``model.config.scheduler_sigmas``.
    The ``cond_head`` of every block in ``model.transformer.blocks`` is then
    replaced by a ``CachedCondHead`` that wraps the old head and shares that
    same ``CachedDenoiseStepEmb`` instance. The model is modified in place.

    Call AFTER: model.to(device="cuda", dtype=torch.bfloat16).eval()
    """
    step_emb_cache = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = step_emb_cache

    # Every block's head is wrapped against the one shared step-embedding cache.
    for block in model.transformer.blocks:
        wrapped_head = CachedCondHead(block.cond_head, step_emb_cache)
        block.cond_head = wrapped_head
def patch_Attn_merge_qkv(model) -> None:
    """Replace every plain ``Attn`` submodule of ``model`` with ``MergedQKVAttn``.

    Each matching module is passed, together with ``model.config``, to the
    ``MergedQKVAttn`` constructor and the result is installed under the same
    dotted name. Modules that already are ``MergedQKVAttn`` instances are
    skipped. The model is modified in place.
    """
    # Collect the targets before touching the tree: set_submodule mutates the
    # structure that named_modules() walks.
    targets = [
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, Attn) and not isinstance(module, MergedQKVAttn)
    ]
    for path, module in targets:
        model.set_submodule(path, MergedQKVAttn(module, model.config))
def patch_MLPFusion_split(model) -> None:
    """Replace every ``MLPFusion`` submodule of ``model`` with ``SplitMLPFusion``.

    Each matching module is handed to the ``SplitMLPFusion`` constructor and
    the result is installed under the same dotted name. Modules that already
    are ``SplitMLPFusion`` instances are skipped. The model is modified in
    place.
    """
    # Collect the targets before touching the tree: set_submodule mutates the
    # structure that named_modules() walks.
    targets = [
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, MLPFusion) and not isinstance(module, SplitMLPFusion)
    ]
    for path, module in targets:
        model.set_submodule(path, SplitMLPFusion(module))
def _apply_inference_patches(model) -> None:
    """Apply all inference-time module patches to ``model``, in place.

    Runs, in this order: cached noise conditioning, merged-QKV attention,
    split MLP fusion.
    """
    patches = (
        patch_cached_noise_conditioning,
        patch_Attn_merge_qkv,
        patch_MLPFusion_split,
    )
    for patch in patches:
        patch(model)
class CFG(nn.Module):
    """Conditioning dropout backed by a learned null embedding.

    Depending on mode, the input embedding is passed through, replaced by the
    null embedding for randomly chosen batch rows, or replaced for the whole
    batch. (Presumably used for classifier-free guidance, going by the name.)
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        # Probability of swapping a batch row for the null embedding.
        self.dropout = dropout
        # Learned stand-in for "no conditioning", broadcast over batch and length.
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: Optional[bool] = None
    ) -> torch.Tensor:
        """
        x: [B, L, D]
        is_conditioned:
            - None: training-style random dropout
            - bool: whole batch conditioned / unconditioned at sampling

        In training mode ``is_conditioned`` is ignored and random dropout is
        always used.
        """
        batch, length, _ = x.shape
        null = self.null_emb.expand(batch, length, -1)

        # Sampling-time switch: only in eval mode with an explicit flag.
        if not self.training and is_conditioned is not None:
            return x if is_conditioned else null

        # Training-style dropout (also used when the flag is unspecified).
        if self.dropout == 0.0:
            return x
        drop_rows = torch.rand(batch, 1, 1, device=x.device) < self.dropout  # [B,1,1]
        return torch.where(drop_rows, null, x)
class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons + scroll) into model dimension."""

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        # Input width: one feature per button, plus mouse velocity (x, y)
        # and scroll sign.
        in_features = n_buttons + 3
        hidden_features = d_model * mlp_ratio
        self.mlp = MLP(in_features, hidden_features, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        """Concatenate the three inputs on the last axis and run the MLP.

        ``mouse`` must be rank 3; ``button`` and ``scroll`` must be
        concatenable with it along the last dimension.
        """
        assert mouse.dim() == 3
        features = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(features)
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens by applying an MLP to cat([x, cond]).

    The concatenation is never materialised: the first layer's weight is split
    into a token half and a conditioning half, and the conditioning term is
    broadcast over the tokens of its group. Only the ``weight`` tensors of
    ``mlp.fc1`` and ``mlp.fc2`` are used here; no bias is applied.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.mlp = MLP(2 * d_model, d_model, d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` (one vector per group) into the tokens ``x``.

        ``x`` is rank 3 and its middle axis is split into ``cond.shape[1]``
        equal groups; the result has the same shape as ``x``.
        """
        batch, _, dim = x.shape
        n_groups = cond.shape[1]

        # Column halves of fc1: first acts on tokens, second on conditioning.
        w_tokens, w_cond = self.mlp.fc1.weight.chunk(2, dim=1)

        grouped = x.view(batch, n_groups, -1, dim)
        token_term = F.linear(grouped, w_tokens)
        # One conditioning vector per group, broadcast over that group's tokens.
        cond_term = F.linear(cond, w_cond).unsqueeze(2)

        hidden = F.silu(token_term + cond_term)
        fused = F.linear(hidden, self.mlp.fc2.weight)
        return fused.flatten(1, 2)
class SplitMLPFusion(nn.Module):
    """Packed MLPFusion -> split linears (no cat, quant-friendly).

    Holds three bias-free ``nn.Linear`` layers whose weights are copied from
    the source module: the two column halves of ``src.mlp.fc1`` and the whole
    of ``src.mlp.fc2``.
    """

    def __init__(self, src: "MLPFusion"):
        super().__init__()
        packed_fc1, packed_fc2 = src.mlp.fc1, src.mlp.fc2
        dim = packed_fc2.in_features
        # New layers are created on the source's device and dtype.
        placement = {
            "device": packed_fc2.weight.device,
            "dtype": packed_fc2.weight.dtype,
        }
        self.fc1_x = nn.Linear(dim, dim, bias=False, **placement)
        self.fc1_c = nn.Linear(dim, dim, bias=False, **placement)
        self.fc2 = nn.Linear(dim, dim, bias=False, **placement)

        with torch.no_grad():
            w_tokens, w_cond = packed_fc1.weight.chunk(2, dim=1)
            transfers = (
                (self.fc1_x, w_tokens),
                (self.fc1_c, w_cond),
                (self.fc2, packed_fc2.weight),
            )
            for layer, weight in transfers:
                layer.weight.copy_(weight)

        # Mirror the source module's train/eval state.
        self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` (one vector per group) into the tokens ``x``.

        ``x`` is rank 3 and its middle axis is split into ``cond.shape[1]``
        equal groups; the result has the same shape as ``x``.
        """
        batch, _, dim = x.shape
        n_groups = cond.shape[1]
        grouped = x.reshape(batch, n_groups, -1, dim)
        # Conditioning term is broadcast over the tokens of each group.
        hidden = F.silu(self.fc1_x(grouped) + self.fc1_c(cond).unsqueeze(2))
        return self.fc2(hidden).flatten(1, 2)
class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> n_cond bias-free Linear projections.

    The forward pass returns a tuple with one tensor per projection.
    """

    n_cond = 6

    def __init__(self, d_model: int, noise_conditioning: str = "wan"):
        super().__init__()
        # The learned input bias exists only for "wan" noise conditioning.
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        projections = [
            nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)
        ]
        self.cond_proj = nn.ModuleList(projections)

    def forward(self, cond):
        """Return ``n_cond`` projections of ``SiLU(cond [+ bias_in])`` as a tuple."""
        if self.bias_in is not None:
            cond = cond + self.bias_in
        activated = F.silu(cond)
        return tuple(proj(activated) for proj in self.cond_proj)
class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP.

    Prompt cross-attention and controller MLP-fusion are only instantiated on
    layers whose index is a multiple of the respective conditioning period.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        layer_idx: int,
        prompt_conditioning: Optional[str],
        prompt_conditioning_period: int,
        prompt_embedding_dim: int,
        ctrl_conditioning_period: int,
        noise_conditioning: str,
        config,
    ):
        super().__init__()
        self.config = config

        self.attn = Attn(config, layer_idx)
        self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        self.cond_head = CondHead(d_model, noise_conditioning)

        # Prompt cross-attention: needs prompt conditioning enabled AND a
        # layer index on the prompt period.
        if (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        ):
            self.prompt_cross_attn = CrossAttention(config, prompt_embedding_dim)
        else:
            self.prompt_cross_attn = None

        # Controller fusion: depends on the layer index only.
        if layer_idx % ctrl_conditioning_period == 0:
            self.ctrl_mlpfusion = MLPFusion(d_model)
        else:
            self.ctrl_mlpfusion = None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
        """
        0) Causal Frame Attention
        1) Frame->CTX Cross Attention (only if this layer has it)
        2) Controller MLPFusion (only if this layer has it)
        3) MLP

        ``ctx`` is indexed with the keys "prompt_emb", "prompt_pad_mask" and
        "ctrl_emb". Returns the updated ``x`` together with the ``v`` value
        handed back by the attention module.
        """
        attn_s, attn_b, attn_g, mlp_s, mlp_b, mlp_g = self.cond_head(cond)

        # Self / Causal Attention, gated, with residual
        attn_out, v = self.attn(
            ada_rmsnorm(x, attn_s, attn_b), pos_ids, v, kv_cache=kv_cache
        )
        x = ada_gate(attn_out, attn_g) + x

        # Cross Attention Prompt Conditioning
        if self.prompt_cross_attn is not None:
            prompt_out = self.prompt_cross_attn(
                rms_norm(x),
                context=rms_norm(ctx["prompt_emb"]),
                context_pad_mask=ctx["prompt_pad_mask"],
            )
            x = prompt_out + x

        # MLPFusion Controller Conditioning
        if self.ctrl_mlpfusion is not None:
            ctrl_out = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"]))
            x = ctrl_out + x

        # MLP, gated, with residual
        mlp_out = self.mlp(ada_rmsnorm(x, mlp_s, mlp_b))
        x = ada_gate(mlp_out, mlp_g) + x
        return x, v
class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks with shared parameters.

    Every block after the first reuses the first block's ``attn.rope`` module.
    For "dit_air" / "wan" noise conditioning, the ``cond_head`` projection
    weights of every block are also tied to those of the first block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        layers = []
        for idx in range(config.n_layers):
            layers.append(
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
            )
        self.blocks = nn.ModuleList(layers)

        lead = self.blocks[0]

        # Tie the conditioning projections of later blocks to the first block.
        if config.noise_conditioning in ("dit_air", "wan"):
            shared_proj = lead.cond_head.cond_proj
            for blk in self.blocks[1:]:
                for own, shared in zip(blk.cond_head.cond_proj, shared_proj):
                    own.weight = shared.weight

        # Shared RoPE buffers
        # NOTE(review): treated as unconditional (not part of the branch
        # above) - confirm against the upstream file's indentation.
        shared_rope = lead.attn.rope
        for blk in self.blocks[1:]:
            blk.attn.rope = shared_rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        """Run ``x`` through every block in order and return the final ``x``.

        The ``v`` value returned by each block is passed to the next one,
        starting from ``None``; it is not returned.
        """
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
        return x
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    _keep_in_fp32_modules = ["denoise_step_emb", "rope"]

    # register_to_config stores the constructor arguments on ``self.config``;
    # the body below (and WorldDiT) read ``self.config.*``, so the decorator
    # is required for construction to work.
    @register_to_config
    def __init__(
        self,
        # Model architecture
        d_model: int = 2560,
        n_heads: int = 40,
        n_kv_heads: Optional[int] = 20,
        n_layers: int = 22,
        mlp_ratio: int = 5,
        channels: int = 16,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 512,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = -1,
        value_residual: bool = False,
        gated_attn: bool = True,
        n_buttons: int = 256,
        ctrl_conditioning: Optional[str] = "mlp_fusion",
        ctrl_conditioning_period: int = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: Optional[str] = "cross_attention",
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        scheduler_sigmas: Optional[List[float]] = [
            1.0,
            0.9483006596565247,
            0.8379597067832947,
            0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
    ):
        """Build the embeddings, the transformer stack and the patch layers.

        Arguments read directly in this constructor:
            d_model: width of the token embedding.
            mlp_ratio: hidden-size multiplier of the controller embedding MLP.
            channels: number of latent channels entering ``patchify``.
            width: row length used to derive the per-token y/x positions.
            patch: (ph, pw) patch size; kernel and stride of ``patchify``.
            tokens_per_frame: number of tokens in one frame.
            n_buttons: number of button features in the controller input.

        Every other argument is not used in this constructor body; the
        transformer stack receives ``self.config`` as a whole.
        """
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        # NOTE: these CFG modules are created here but are not called from
        # forward() in this class.
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)

        self.patch = tuple(patch)
        C, D = channels, d_model
        # Non-overlapping patches: kernel size == stride == patch size.
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
        self.out_norm = AdaLN(d_model)

        # Cached 1-frame pos_ids (non-persistent buffers). The time buffer is
        # left uninitialised here and is overwritten on every forward() call.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        prompt_emb: Optional[Tensor] = None,
        prompt_pad_mask: Optional[Tensor] = None,
        mouse: Optional[Tensor] = None,
        button: Optional[Tensor] = None,
        scroll: Optional[Tensor] = None,
        kv_cache=None,
    ):
        """
        Denoise one latent frame.

        Args:
            x: [B, N, C, H, W] - latent frames (only B == 1, N == 1 is supported)
            sigma: [B, N] - noise levels
            frame_timestamp: [B, N] - frame indices
            prompt_emb: [B, P, D] - prompt embeddings
            prompt_pad_mask: [B, P] - padding mask for prompts
            mouse: [B, N, 2] - mouse velocity (required)
            button: [B, N, n_buttons] - button states (required)
            scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1) (required)
            kv_cache: StaticKVCache instance

        Returns:
            Tensor of shape [B, N, C, H, W].

        Raises:
            AssertionError: if H or W is not divisible by the patch size, if
                the patch grid does not match ``config.tokens_per_frame``, if
                B or N is not 1, or if a controller input is missing.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )
        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )

        # Broadcast the single frame timestamp into the cached time buffer
        # (in-place write, so the buffer object itself is reused across calls).
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )

        cond = self.denoise_step_emb(sigma)  # [B, N, d]

        # All three controller tensors are concatenated by ctrl_emb, so none
        # of them may be omitted.
        assert (
            mouse is not None and button is not None and scroll is not None
        ), "mouse, button and scroll are all required"
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        # Patchify: [B, N, C, H, W] -> [B, N * Hp * Wp, D]
        D = self.unpatchify.in_features
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")

        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))

        # Unpatchify: [B, N * Hp * Wp, C * ph * pw] -> [B, N, C, H, W]
        x = eo.rearrange(
            self.unpatchify(x),
            "b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
            n=N,
            hp=Hp,
            wp=Wp,
            ph=ph,
            pw=pw,
        )
        return x

    def quantize(self, quant_type: str):
        """Quantize this model in place via ``quantize_model``."""
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        """Apply the module-level inference patches to this model in place."""
        _apply_inference_patches(self)