from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup

from videosys.core.parallel_mgr import get_sequence_parallel_size

# ======================================================
# Model
# ======================================================


def model_sharding(model: torch.nn.Module):
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for _, param in model.named_parameters():
        padding_size = (world_size - param.numel() % world_size) % world_size
        if padding_size > 0:
            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
        else:
            padding_param = param.data.view(-1)
        splited_params = padding_param.split(padding_param.numel() // world_size)
        splited_params = splited_params[global_rank]
        param.data = splited_params
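
# Usage sketch (illustrative, hypothetical helper; assumes torch.distributed is already
# initialized, e.g. via torchrun): every rank builds the same model, then keeps only its
# own 1/world_size slice of each flattened, padded parameter.
def _example_model_sharding():
    model = torch.nn.Linear(1024, 1024)
    model_sharding(model)
    # each parameter is now a flat shard of roughly numel / world_size elements
    for name, param in model.named_parameters():
        print(name, param.shape)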

# ======================================================
# AllGather & ReduceScatter
# ======================================================


class AsyncAllGatherForTwo(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        weight: Tensor,
        bias: Tensor,
        sp_rank: int,
        sp_size: int,
        group: Optional[ProcessGroup] = None,
    ) -> Tensor:
        """
        All-gather the sequence-parallel inputs across the two ranks and apply the
        linear projection to both the local and the remote shard.

        Returns:
            qkv: Tensor of shape (b, sp_size * n, c) holding the projection of the
                full (gathered) sequence.
        """
        from torch.distributed._functional_collectives import all_gather_tensor

        ctx.group = group
        ctx.sp_rank = sp_rank
        ctx.sp_size = sp_size

        # all gather inputs
        all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
        # compute local qkv
        local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)

        # remote compute
        remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
        # compute remote qkv
        remote_qkv = F.linear(remote_inputs, weight, bias)

        # concat local and remote qkv
        if sp_rank == 0:
            qkv = torch.cat([local_qkv, remote_qkv], dim=0)
        else:
            qkv = torch.cat([remote_qkv, local_qkv], dim=0)
        qkv = rearrange(qkv, "sp b n c -> b (sp n) c")

        ctx.save_for_backward(inputs, weight, remote_inputs)
        return qkv

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, Tensor, Tensor, None, None, None]:
        from torch.distributed._functional_collectives import reduce_scatter_tensor

        group = ctx.group
        sp_rank = ctx.sp_rank
        sp_size = ctx.sp_size
        inputs, weight, remote_inputs = ctx.saved_tensors

        # split qkv_grad
        qkv_grad = grad_outputs[0]
        qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
        qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
        if sp_rank == 0:
            local_qkv_grad, remote_qkv_grad = qkv_grad
        else:
            remote_qkv_grad, local_qkv_grad = qkv_grad

        # compute remote grad
        remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
        weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
        bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)

        # launch async reduce scatter
        remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
        if sp_rank == 0:
            remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
        else:
            remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
        remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)

        # compute local grad and wait for reduce scatter
        local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
        weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
        bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)

        # sum remote and local grad
        inputs_grad = remote_inputs_grad + local_input_grad
        return inputs_grad, weight_grad, bias_grad, None, None, None
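
# Illustrative call-site sketch (assumption, not part of the original file): with
# sp_size == 2, AsyncAllGatherForTwo overlaps the gather of the other rank's sequence
# shard with the local qkv projection. `x` is this rank's half of the sequence with
# shape (batch, seq_len // 2, hidden); `qkv_weight`/`qkv_bias` are a fused qkv linear.
def _example_async_allgather_for_two(x: Tensor, qkv_weight: Tensor, qkv_bias: Tensor, sp_group: ProcessGroup) -> Tensor:
    sp_rank = dist.get_rank(sp_group)
    sp_size = dist.get_world_size(sp_group)
    assert sp_size == 2, "AsyncAllGatherForTwo only supports two sequence-parallel ranks"
    # returns qkv for the full sequence: (batch, seq_len, 3 * hidden)
    return AsyncAllGatherForTwo.apply(x, qkv_weight, qkv_bias, sp_rank, sp_size, sp_group)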

class AllGather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.unsqueeze(0), None

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        else:
            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )


class ReduceScatter(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.squeeze(0), None

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        output_shape = inputs.shape[1:]
        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
        if not overlap:
            dist.reduce_scatter(outputs, buffer_list, group=group)
            return outputs, None
        else:
            handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        return (
            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
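
# Illustrative sketch (assumption): AllGather and ReduceScatter are duals, so gathering
# a rank-local tensor and immediately reduce-scattering it hands each rank back its own
# tensor summed over all ranks, i.e. world_size * x. Assumes `group` is initialized.
def _example_allgather_reducescatter_roundtrip(x: Tensor, group: ProcessGroup) -> Tensor:
    gathered, _ = AllGather.apply(x, group, False)  # (world_size, *x.shape)
    reduced, _ = ReduceScatter.apply(gathered, group, False)  # back to x.shape
    return reduced  # equals dist.get_world_size(group) * x on every rank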

# ======================================================
# AlltoAll
# ======================================================


def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)
        return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        process_group = ctx.process_group
        # the backward pass is an all-to-all with the scatter/gather dims swapped
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)


def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
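
# Illustrative sketch (assumption): in Ulysses-style sequence parallelism, all_to_all_comm
# is used around attention to trade a sharded sequence dimension for a sharded head
# dimension and back. With x of shape (batch, seq_len // sp, num_heads, head_dim):
def _example_all_to_all_attention_layout(x: torch.Tensor, sp_group: ProcessGroup) -> torch.Tensor:
    # gather the sequence (dim 1), scatter the heads (dim 2):
    # (b, s / sp, h, d) -> (b, s, h / sp, d)
    x = all_to_all_comm(x, sp_group, scatter_dim=2, gather_dim=1)
    # ... attention over the full sequence with a subset of heads would run here ...
    # swap back: (b, s, h / sp, d) -> (b, s / sp, h, d)
    x = all_to_all_comm(x, sp_group, scatter_dim=1, gather_dim=2)
    return x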

# ======================================================
# Sequence Gather & Split
# ======================================================


def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    # skip if only one rank involved
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)
    if world_size == 1:
        return input_

    # pad the sequence so it is evenly divisible across ranks
    if pad > 0:
        pad_size = list(input_.shape)
        pad_size[dim] = pad
        input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)

    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    output = tensor_list[rank].contiguous()
    return output


def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    # skip if only one rank involved
    input_ = input_.contiguous()
    world_size = dist.get_world_size(pg)
    if world_size == 1:
        return input_

    # all gather
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    assert input_.device.type == "cuda"
    torch.distributed.all_gather(tensor_list, input_, group=pg)

    # concat and strip the padding that was added before the split
    output = torch.cat(tensor_list, dim=dim)
    if pad > 0:
        output = output.narrow(dim, 0, output.size(dim) - pad)
    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """
    Gather the input sequence in the forward pass and split it in the backward pass.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension to gather along.
        grad_scale: "up" to multiply the gradient by the world size, "down" to divide it.
        pad: number of padding elements to strip after gathering.
    """

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.pad = pad
        return _gather_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)

        return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input sequence in the forward pass and gather it in the backward pass.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension to split along.
        grad_scale: "up" to multiply the gradient by the world size, "down" to divide it.
        pad: number of padding elements to add before splitting.
    """

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.pad = pad
        return _split_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)

        return _gather_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None


def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)


def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
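
# Illustrative sketch (assumption): split_sequence and gather_sequence are inverses when
# called with the same group, dim, and pad, so a split followed by a gather reproduces
# the input, provided every rank passes the same CUDA tensor x. The "down"/"up"
# grad_scale pair keeps the gradient unscaled over the round trip.
def _example_split_gather_roundtrip(x: torch.Tensor, sp_group: ProcessGroup) -> torch.Tensor:
    sp_size = dist.get_world_size(sp_group)
    pad = (sp_size - x.size(1) % sp_size) % sp_size  # pad dim 1 up to a multiple of sp_size
    shard = split_sequence(x, sp_group, dim=1, grad_scale="down", pad=pad)
    full = gather_sequence(shard, sp_group, dim=1, grad_scale="up", pad=pad)
    return full  # equal to x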

# ==============================
# Pad
# ==============================

SPATIAL_PAD = 0
TEMPORAL_PAD = 0


def set_spatial_pad(dim_size: int):
    sp_size = get_sequence_parallel_size()
    pad = (sp_size - (dim_size % sp_size)) % sp_size
    global SPATIAL_PAD
    SPATIAL_PAD = pad


def get_spatial_pad() -> int:
    return SPATIAL_PAD


def set_temporal_pad(dim_size: int):
    sp_size = get_sequence_parallel_size()
    pad = (sp_size - (dim_size % sp_size)) % sp_size
    global TEMPORAL_PAD
    TEMPORAL_PAD = pad


def get_temporal_pad() -> int:
    return TEMPORAL_PAD


def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    scatter_pad: int = 0,
    gather_pad: int = 0,
):
    # pad the scatter dim so it is evenly divisible across ranks
    if scatter_pad > 0:
        pad_shape = list(input_.shape)
        pad_shape[scatter_dim] = scatter_pad
        pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
        input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)

    assert (
        input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
    ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)

    # strip the padding that was gathered back along the gather dim
    if gather_pad > 0:
        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)

    return input_
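
# Illustrative sketch (assumption): the pad setters record how much padding makes the
# spatial/temporal lengths divisible by the sequence-parallel size, and
# all_to_all_with_pad adds that padding on the scattered dim and trims it from the
# gathered dim around the exchange. Here x is (batch, temporal, spatial // sp, dim)
# with the spatial axis already sharded and padded per get_spatial_pad().
def _example_all_to_all_with_pad(x: torch.Tensor, sp_group: dist.ProcessGroup) -> torch.Tensor:
    set_temporal_pad(x.size(1))  # full temporal length seen by this rank
    # scatter temporal (dim 1), gather the spatial shards back (dim 2):
    # (b, t, s / sp, d) -> (b, t / sp, s, d)
    return all_to_all_with_pad(
        x,
        sp_group,
        scatter_dim=1,
        gather_dim=2,
        scatter_pad=get_temporal_pad(),
        gather_pad=get_spatial_pad(),
    )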