#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Training utilities for STARFlow.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed
import torch.distributed.checkpoint as dcp
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
from torch.distributed._tensor import DeviceMesh
from torch.distributed.device_mesh import init_device_mesh

import datetime
import math
import os
import random
import numpy as np
import contextlib
import typing as t
from typing import Any, Dict, List, Union, Optional
from collections import defaultdict, OrderedDict
from fnmatch import fnmatch


# ==== Learning Rate Schedule ====

class CosineLRSchedule(torch.nn.Module):
    """Linear warmup from `start_lr` to `max_lr`, then cosine decay to `min_lr`."""

    counter: torch.Tensor

    def __init__(self, optimizer, warmup_steps: int, total_steps: int, min_lr: float, max_lr: float):
        super().__init__()
        self.register_buffer('counter', torch.zeros(()))
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.start_lr = min(min_lr, 1e-6)
        self.max_lr = max_lr
        self.set_lr(min_lr)

    def set_lr(self, lr: float) -> float:
        if self.min_lr <= lr <= self.max_lr:
            for pg in self.optimizer.param_groups:
                pg['lr'] = lr
        # Return the lr actually in effect (unchanged if `lr` fell outside the valid range).
        return self.optimizer.param_groups[0]['lr']

    def step(self) -> float:
        with torch.no_grad():
            counter = self.counter.add_(1).item()
        if counter <= self.warmup_steps:
            new_lr = self.start_lr + counter / self.warmup_steps * (self.max_lr - self.start_lr)
            return self.set_lr(new_lr)
        t = (counter - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        new_lr = self.min_lr + 0.5 * (1 + math.cos(math.pi * t)) * (self.max_lr - self.min_lr)
        return self.set_lr(new_lr)


# ==== Distributed Training ====

class Distributed:
    timeout: float = 72000

    def __init__(self):
        if os.environ.get('MASTER_PORT'):
            # When running with torchrun
            self.rank = int(os.environ['RANK'])
            self.local_rank = int(os.environ['LOCAL_RANK'])
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = True
            torch.distributed.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=self.world_size,
                timeout=datetime.timedelta(seconds=self.timeout),
                rank=self.rank,
            )
        else:
            # When running with plain python for debugging
            self.rank, self.local_rank, self.world_size = 0, 0, 1
            self.distributed = False

        # Only set the CUDA device if CUDA is available
        if torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
        self.barrier()

    def barrier(self) -> None:
        if self.distributed:
            torch.distributed.barrier()

    def gather_concat(self, x: torch.Tensor) -> torch.Tensor:
        if not self.distributed:
            return x
        x_list = [torch.empty_like(x) for _ in range(self.world_size)]
        torch.distributed.all_gather(x_list, x)
        return torch.cat(x_list)

    def reduce(self, x):
        if not self.distributed:
            return x
        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
        return x

    def __del__(self):
        if self.distributed:
            torch.distributed.destroy_process_group()


def get_local_rank() -> int:
    if os.environ.get('MASTER_PORT'):
        # When running with torchrun
        return int(os.environ['LOCAL_RANK'])
    return 0


def get_device_mesh(dp_size: int, tp_size: int = 1) -> DeviceMesh:
    """Create a DeviceMesh from the data- and tensor-parallel sizes."""
    # By default TP=1 is used for simplicity.
    mesh_shape = (dp_size, tp_size)
    names = ("dp", "tp")
    return init_device_mesh("cuda", mesh_shape=mesh_shape, mesh_dim_names=names)
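

# Illustrative note (the 8-rank size below is an example, not a default used by
# STARFlow): get_device_mesh(8) builds an (8, 1) mesh with dimension names
# ("dp", "tp"), and indexing it with "dp" selects the 1-D data-parallel
# sub-mesh that parallelize_model() below hands to fully_shard.
#
#   mesh = get_device_mesh(dp_size=8)   # DeviceMesh over ("dp", "tp")
#   dp_mesh = mesh["dp"]                # sharding axis for FSDP2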


def wrap_matching_layers(
    model: nn.Module,
    layer_patterns: t.List[str],
    wrapper_fn: t.Callable[[nn.Module], nn.Module],
):
    """
    Recursively wrap submodules in the order the patterns appear in `layer_patterns`.
    For each pattern (in order), we do one pass over the model and wrap all matches.
    """
    def _wrap_single_pattern(mod: nn.Module, pattern: str):
        """
        Recurse over `mod`, wrapping submodules whose class name matches `pattern`.
        The traversal is post-order, so children are wrapped before their parent.
        """
        for child_name, child_module in list(mod.named_children()):
            # Wrap grandchildren first.
            _wrap_single_pattern(child_module, pattern)
            # Check if the child's class name matches the pattern.
            if fnmatch(child_module.__class__.__name__, pattern):
                # Replace the child in the parent.
                wrapped = wrapper_fn(child_module)
                setattr(mod, child_name, wrapped)

    # Do one pass per pattern, in order.
    for pattern in layer_patterns:
        _wrap_single_pattern(model, pattern)


def parallelize_model(args, model: nn.Module, dist: Distributed, device='cuda',
                      block_names=['AttentionBlock']) -> t.Tuple[nn.Module, nn.Module]:
    if not getattr(args, "fsdp", False):
        # Use standard DDP
        model = model.to(device=device)
        if dist.distributed:
            print("Using DDP")
            model_ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[dist.local_rank])
        else:
            model_ddp = model  # compatible with DDP
        return model, model_ddp

    # Instantiate the mixed precision policy
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
        cast_forward_inputs=True,
    )
    print(f"Using FSDP2 with: {mp_policy}")

    # Build the data-parallel mesh that FSDP shards over
    dp_mesh = get_device_mesh(dist.world_size)["dp"]

    # Configure core FSDP parameters
    fsdp_config = {"mp_policy": mp_policy, "mesh": dp_mesh, "reshard_after_forward": True}

    # Wrap the specified layer patterns with FSDP
    wrap_matching_layers(model, block_names, lambda m: fully_shard(m, **fsdp_config))

    # Then wrap the full model (remaining modules are captured by this)
    model = fully_shard(model, **fsdp_config)
    model = model.to(device=device)
    return model, model  # for compatibility with DDP


def save_model(args, dist, model, model_ckpt_file):
    states = model.state_dict()
    if not getattr(args, "fsdp", False):
        # Save DDP checkpoints
        if dist.local_rank == 0:
            torch.save(states, model_ckpt_file)
    else:
        # Save FSDP checkpoints
        dcp.save(states, checkpoint_id=str(model_ckpt_file))


def save_optimizer(args, dist, optimizer, lr_schedule, opt_ckpt_file):
    optim_states, lr_states = optimizer.state_dict(), lr_schedule.state_dict()
    if not getattr(args, "fsdp", False):
        # Save DDP checkpoints
        if dist.local_rank == 0:
            torch.save({"optimizer": optim_states, "lr_schedule": lr_states}, opt_ckpt_file)
    else:
        filename = str(opt_ckpt_file)
        dcp.save(optim_states, checkpoint_id=f"{filename}/optimizer")
        torch.save(lr_states, f"{filename}/lr_schedule.bin")  # lr_schedule is not sharded by FSDP


@contextlib.contextmanager
def _fsdp2_no_sync(module, sync):
    # FSDP2 API: toggle gradient synchronization for the wrapped module
    module.set_requires_gradient_sync(sync, recurse=True)
    try:
        yield
    finally:
        module.set_requires_gradient_sync(True, recurse=True)


def sync_ctx(model, sync=True):
    if hasattr(model, 'set_requires_gradient_sync'):
        return _fsdp2_no_sync(model, sync)
    elif not sync and hasattr(model, 'no_sync'):
        return model.no_sync()
    return contextlib.nullcontext()


# ==== Utility Functions ====

def set_random_seed(seed: int) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
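

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training entry
# point): it exercises CosineLRSchedule with a toy model and optimizer so the
# warmup/decay behaviour can be inspected on CPU. The model, optimizer, and
# step counts below are placeholders, not values used by STARFlow.
if __name__ == "__main__":
    _toy_model = nn.Linear(8, 8)
    _toy_opt = torch.optim.AdamW(_toy_model.parameters(), lr=1e-4)
    _sched = CosineLRSchedule(
        _toy_opt, warmup_steps=10, total_steps=100, min_lr=1e-6, max_lr=1e-4
    )
    for _step in range(100):
        _lr = _sched.step()
        if _step % 20 == 0:
            print(f"step {_step:3d}: lr = {_lr:.2e}")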