Spaces:
Running
Running
| from typing import Optional | |
| import torch | |
| import torch.distributed as dist | |
| from colossalai.cluster.process_group_mesh import ProcessGroupMesh | |
| from torch.distributed import ProcessGroup | |
| from videosys.utils.logging import init_dist_logger, logger | |
| from videosys.utils.utils import set_seed | |
| PARALLEL_MANAGER = None | |
| class ParallelManager(ProcessGroupMesh): | |
| def __init__(self, dp_size, cp_size, sp_size): | |
| super().__init__(dp_size, cp_size, sp_size) | |
| dp_axis, cp_axis, sp_axis = 0, 1, 2 | |
| self.dp_size = dp_size | |
| self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis) | |
| self.dp_rank = dist.get_rank(self.dp_group) | |
| self.cp_size = cp_size | |
| self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis) | |
| self.cp_rank = dist.get_rank(self.cp_group) | |
| self.sp_size = sp_size | |
| self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis) | |
| self.sp_rank = dist.get_rank(self.sp_group) | |
| self.enable_sp = sp_size > 1 | |
| logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}") | |
| def set_parallel_manager(dp_size, cp_size, sp_size): | |
| global PARALLEL_MANAGER | |
| PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size) | |
| def get_data_parallel_group(): | |
| return PARALLEL_MANAGER.dp_group | |
| def get_data_parallel_size(): | |
| return PARALLEL_MANAGER.dp_size | |
| def get_data_parallel_rank(): | |
| return PARALLEL_MANAGER.dp_rank | |
| def get_sequence_parallel_group(): | |
| return PARALLEL_MANAGER.sp_group | |
| def get_sequence_parallel_size(): | |
| return PARALLEL_MANAGER.sp_size | |
| def get_sequence_parallel_rank(): | |
| return PARALLEL_MANAGER.sp_rank | |
| def get_cfg_parallel_group(): | |
| return PARALLEL_MANAGER.cp_group | |
| def get_cfg_parallel_size(): | |
| return PARALLEL_MANAGER.cp_size | |
| def enable_sequence_parallel(): | |
| if PARALLEL_MANAGER is None: | |
| return False | |
| return PARALLEL_MANAGER.enable_sp | |
| def get_parallel_manager(): | |
| return PARALLEL_MANAGER | |
| def initialize( | |
| rank=0, | |
| world_size=1, | |
| init_method=None, | |
| seed: Optional[int] = None, | |
| sp_size: Optional[int] = None, | |
| enable_cp: bool = True, | |
| ): | |
| if not dist.is_initialized(): | |
| try: | |
| dist.destroy_process_group() | |
| except Exception: | |
| pass | |
| dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank) | |
| torch.cuda.set_device(rank) | |
| init_dist_logger() | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| # init sequence parallel | |
| if sp_size is None: | |
| sp_size = dist.get_world_size() | |
| dp_size = 1 | |
| else: | |
| assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size" | |
| dp_size = dist.get_world_size() // sp_size | |
| # update cfg parallel | |
| if enable_cp and sp_size % 2 == 0: | |
| sp_size = sp_size // 2 | |
| cp_size = 2 | |
| else: | |
| cp_size = 1 | |
| set_parallel_manager(dp_size, cp_size, sp_size) | |
| if seed is not None: | |
| set_seed(seed + get_data_parallel_rank()) | |