import torch

from .initialize import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
)
from .utils import split_tensor_along_last_dim

|
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""

    # Bypass the all-reduce if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce across the tensor model parallel group.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_
|
|
def _split_along_last_dim(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the split if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along the last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output
|
|
def _split_along_first_dim(input_):
    """Split the tensor along its first dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the split if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Compute the slice owned by this rank along the first dimension.
    dim_size = input_.size()[0]
    assert dim_size % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    rank = get_tensor_model_parallel_rank()
    dim_offset = rank * local_dim_size

    output = input_[dim_offset:dim_offset + local_dim_size].contiguous()

    return output
|
|
def _gather_along_last_dim(input_):
    """Gather tensors and concatenate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the gather if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_,
                                 group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
|
|
def _gather_along_first_dim(input_):
    """Gather tensors and concatenate along the first dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the gather if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # The gathered output is world_size times larger along the first dimension.
    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._all_gather_base(output, input_.contiguous(),
                                       group=get_tensor_model_parallel_group())

    return output
|
|
def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()

    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"

    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
                                           group=get_tensor_model_parallel_group())
    return output
|
|
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)
|
|
class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
|
|
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_last_dim(grad_output)
|
|
class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_last_dim(grad_output)
|
|
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)
|
|
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_, tensor_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

        # If the computation graph after the gather runs in tensor parallel
        # mode, the output gradient needs to be reduce-scattered; if the
        # computation is duplicated across ranks, the output gradient only
        # needs to be split (scattered) along the first dimension.
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None
|
|
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)
|
|
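# -----------------
# Helper functions.
# -----------------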
def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)


def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)


def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)


def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)