#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Training utilities for STARFlow.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed
import torch.distributed.checkpoint as dcp
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
from torch.distributed._tensor import DeviceMesh
from torch.distributed.device_mesh import init_device_mesh
import datetime
import math
import os
import random
import numpy as np
import contextlib
import typing as t
from typing import Any, Dict, List, Union, Optional
from collections import defaultdict, OrderedDict
from fnmatch import fnmatch
# ==== Learning Rate Schedule ====
class CosineLRSchedule(torch.nn.Module):
    """Cosine learning-rate schedule with a linear warmup phase."""

    counter: torch.Tensor

    def __init__(self, optimizer, warmup_steps: int, total_steps: int, min_lr: float, max_lr: float):
        super().__init__()
        self.register_buffer('counter', torch.zeros(()))
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.start_lr = min(min_lr, 1e-6)
        self.max_lr = max_lr
        self.set_lr(min_lr)

    def set_lr(self, lr: float) -> float:
        # Only apply learning rates inside [min_lr, max_lr]; always return the current value.
        if self.min_lr <= lr <= self.max_lr:
            for pg in self.optimizer.param_groups:
                pg['lr'] = lr
        return self.optimizer.param_groups[0]['lr']

    def step(self) -> float:
        with torch.no_grad():
            counter = self.counter.add_(1).item()
        if counter <= self.warmup_steps:
            # Linear warmup from start_lr to max_lr.
            new_lr = self.start_lr + counter / self.warmup_steps * (self.max_lr - self.start_lr)
            return self.set_lr(new_lr)
        # Cosine decay from max_lr down to min_lr over the remaining steps.
        t = (counter - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        new_lr = self.min_lr + 0.5 * (1 + math.cos(math.pi * t)) * (self.max_lr - self.min_lr)
        return self.set_lr(new_lr)
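
# Illustrative usage of CosineLRSchedule (a minimal sketch; the optimizer, step counts and
# learning rates below are placeholders, not values from the STARFlow training recipe):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   lr_schedule = CosineLRSchedule(optimizer, warmup_steps=1000, total_steps=100000,
#                                  min_lr=1e-5, max_lr=1e-4)
#   for batch in loader:                      # hypothetical training loop
#       loss = compute_loss(model, batch)     # hypothetical loss function
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       lr_schedule.step()                    # advance the schedule once per optimizer step
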
# ==== Distributed Training ====
class Distributed:
    """Thin wrapper around torch.distributed for single- and multi-process runs."""

    timeout: float = 72000

    def __init__(self):
        if os.environ.get('MASTER_PORT'):  # When running with torchrun
            self.rank = int(os.environ['RANK'])
            self.local_rank = int(os.environ['LOCAL_RANK'])
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = True
            torch.distributed.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=self.world_size,
                timeout=datetime.timedelta(seconds=self.timeout),
                rank=self.rank,
            )
        else:  # When running with plain python, e.g. for debugging
            self.rank, self.local_rank, self.world_size = 0, 0, 1
            self.distributed = False
        # Only set the CUDA device if CUDA is available.
        if torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
        self.barrier()

    def barrier(self) -> None:
        if self.distributed:
            torch.distributed.barrier()

    def gather_concat(self, x: torch.Tensor) -> torch.Tensor:
        if not self.distributed:
            return x
        x_list = [torch.empty_like(x) for _ in range(self.world_size)]
        torch.distributed.all_gather(x_list, x)
        return torch.cat(x_list)

    def reduce(self, x: torch.Tensor) -> torch.Tensor:
        if not self.distributed:
            return x
        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
        return x

    def __del__(self):
        if self.distributed:
            torch.distributed.destroy_process_group()
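
# Illustrative usage of Distributed (a minimal sketch; the metric tensor is a placeholder):
#
#   dist = Distributed()                   # initializes NCCL when launched with torchrun
#   loss = torch.tensor([0.5], device='cuda')
#   total = dist.reduce(loss.clone())      # in-place all-reduce sum across ranks
#   every = dist.gather_concat(loss)       # concatenation of this tensor from every rank
#   dist.barrier()                         # all collectives are no-ops in single-process runs
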
def get_local_rank() -> int:
    if os.environ.get('MASTER_PORT'):  # When running with torchrun
        return int(os.environ['LOCAL_RANK'])
    return 0

def get_device_mesh(dp_size: int, tp_size: int = 1) -> DeviceMesh:
    """Create a DeviceMesh from the data-parallel and tensor-parallel sizes."""
    # By default we use TP=1 for simplicity, i.e. a pure data-parallel mesh.
    mesh_shape = (dp_size, tp_size)
    names = ("dp", "tp")
    return init_device_mesh("cuda", mesh_shape=mesh_shape, mesh_dim_names=names)
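
# Illustrative usage of get_device_mesh (a minimal sketch): an 8-GPU, purely data-parallel mesh.
#
#   mesh = get_device_mesh(dp_size=8)   # shape (8, 1) with dimension names ("dp", "tp")
#   dp_mesh = mesh["dp"]                # 1-D sub-mesh used for FSDP sharding below
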
def wrap_matching_layers(
    model: nn.Module,
    layer_patterns: t.List[str],
    wrapper_fn: t.Callable[[nn.Module], nn.Module],
):
    """
    Recursively wrap submodules, following the order of `layer_patterns`.
    For each pattern (in order), we do one pass over the model and wrap all matches.
    """
    def _wrap_single_pattern(mod: nn.Module, pattern: str):
        """
        Recurse over `mod`, wrapping submodules whose class name matches `pattern`.
        We use a post-order traversal so children get wrapped before their parent.
        """
        for child_name, child_module in list(mod.named_children()):
            # Wrap grandchildren first.
            _wrap_single_pattern(child_module, pattern)
            # Check whether the child's class name matches the pattern.
            if fnmatch(child_module.__class__.__name__, pattern):
                # Replace the child in the parent with its wrapped version.
                wrapped = wrapper_fn(child_module)
                setattr(mod, child_name, wrapped)

    # Do one pass per pattern, in order.
    for pattern in layer_patterns:
        _wrap_single_pattern(model, pattern)
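
# Illustrative usage of wrap_matching_layers (a minimal sketch with a toy model; the wrapper
# below is just a stand-in for e.g. `fully_shard` or an activation-checkpointing wrapper):
#
#   toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
#   wrap_matching_layers(toy, ["Linear"], lambda m: nn.Sequential(m, nn.Dropout(0.1)))
#   # Patterns are shell-style (fnmatch), so "Attention*" would match AttentionBlock, etc.
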
def parallelize_model(args, model: nn.Module, dist: Distributed, device='cuda',
                      block_names=['AttentionBlock']) -> t.Tuple[nn.Module, nn.Module]:
    if not getattr(args, "fsdp", False):  # use standard DDP
        model = model.to(device=device)
        if dist.distributed:
            print("Using DDP")
            model_ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[dist.local_rank])
        else:
            model_ddp = model  # keep the DDP-style interface in the single-process case
        return model, model_ddp

    # Instantiate the mixed-precision policy.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
        cast_forward_inputs=True,
    )
    print(f"Using FSDP2 with: {mp_policy}")

    # Build the data-parallel mesh used for FSDP sharding.
    dp_mesh = get_device_mesh(dist.world_size)["dp"]

    # Configure the core FSDP parameters.
    fsdp_config = {"mp_policy": mp_policy, "mesh": dp_mesh, "reshard_after_forward": True}

    # Wrap the specified layer patterns with FSDP.
    wrap_matching_layers(model, block_names, lambda m: fully_shard(m, **fsdp_config))

    # Then wrap the full model (remaining modules are captured by this root wrap).
    model = fully_shard(model, **fsdp_config)
    model = model.to(device=device)
    return model, model  # same object twice, for compatibility with the DDP return signature
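
# Illustrative usage of parallelize_model (a minimal sketch; `args` is any namespace-like
# object with an `fsdp` flag, and the block name is whatever transformer block class the
# model defines):
#
#   dist = Distributed()
#   model, model_ddp = parallelize_model(args, model, dist, block_names=['AttentionBlock'])
#   # Run forward/backward through `model_ddp`; with FSDP enabled both returned names
#   # refer to the same fully-sharded module.
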
def save_model(args, dist, model, model_ckpt_file):
    states = model.state_dict()
    if not getattr(args, "fsdp", False):  # save a regular (DDP) checkpoint from one rank
        if dist.local_rank == 0:
            torch.save(states, model_ckpt_file)
    else:  # save a sharded FSDP checkpoint via torch.distributed.checkpoint
        dcp.save(states, checkpoint_id=str(model_ckpt_file))

def save_optimizer(args, dist, optimizer, lr_schedule, opt_ckpt_file):
    optim_states, lr_states = optimizer.state_dict(), lr_schedule.state_dict()
    if not getattr(args, "fsdp", False):  # save a regular (DDP) checkpoint from one rank
        if dist.local_rank == 0:
            torch.save({"optimizer": optim_states, "lr_schedule": lr_states}, opt_ckpt_file)
    else:
        filename = str(opt_ckpt_file)
        dcp.save(optim_states, checkpoint_id=f"{filename}/optimizer")
        torch.save(lr_states, f"{filename}/lr_schedule.bin")  # the lr_schedule is not sharded by FSDP
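
# Illustrative usage of the checkpoint helpers (a minimal sketch; the paths are placeholders):
#
#   save_model(args, dist, model, "checkpoints/model_step_1000")
#   save_optimizer(args, dist, optimizer, lr_schedule, "checkpoints/opt_step_1000")
#   # With args.fsdp=True these write sharded torch.distributed.checkpoint directories;
#   # otherwise local rank 0 writes single-file torch.save checkpoints.
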

@contextlib.contextmanager
def _fsdp2_no_sync(module, sync):
    # FSDP2 API: toggle gradient synchronization for this step, then always re-enable it.
    module.set_requires_gradient_sync(sync, recurse=True)
    try:
        yield
    finally:
        module.set_requires_gradient_sync(True, recurse=True)


def sync_ctx(model, sync=True):
    # Return a context manager that controls gradient sync for FSDP2, DDP, or a plain module.
    if hasattr(model, 'set_requires_gradient_sync'):
        return _fsdp2_no_sync(model, sync)
    elif not sync and hasattr(model, 'no_sync'):
        return model.no_sync()
    return contextlib.nullcontext()
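
# Illustrative usage of sync_ctx for gradient accumulation (a minimal sketch; `accum_steps`
# and the loss computation are placeholders):
#
#   for i, batch in enumerate(loader):
#       is_last = (i + 1) % accum_steps == 0
#       with sync_ctx(model_ddp, sync=is_last):   # skip gradient all-reduce on intermediate steps
#           loss = model_ddp(batch).mean() / accum_steps
#           loss.backward()
#       if is_last:
#           optimizer.step()
#           optimizer.zero_grad()
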
# ==== Utility Functions ====
def set_random_seed(seed: int) -> None:
    """Set random seeds for python, numpy, and torch for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)