AllToAllSingleFunction

class AllToAllSingleFunction(*args, **kwargs)

Bases: Function

Differentiable all-to-all collective for CP sequence ↔ channel redistribution.

Wraps all_to_all_single_fn() in a torch.autograd.Function so that gradients flow correctly through the collective boundary. The backward pass runs the collective in the dual (opposite) direction:

forward  split_to_full → backward full_to_split
forward  full_to_split → backward split_to_full
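The duality in the two lines above can be checked with a small single-process simulation. This is a sketch, not the library's code. The helper names, the nested-list layout `[channel][sequence]` and the contiguous (non-zigzag) sharding are all assumptions. Which helper corresponds to `split_to_full` is defined by the module docstring, not here.

```python
# Single-process simulation of a CP all-to-all with `world` ranks.
# Hypothetical layout: each rank's tensor is a nested list [channel][sequence].
# Sharding is contiguous (no zigzag permutation).


def all_to_all(send):
    """send[src][dst] is the chunk rank `src` sends to rank `dst`.
    Returns recv with recv[dst][src] == send[src][dst]."""
    world = len(send)
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def seq_shard_to_channel_shard(shards):
    """Hypothetical direction: [C][S/P] per rank -> [C/P][S] per rank."""
    world = len(shards)
    c_local = len(shards[0]) // world
    # rank r cuts its channels into `world` blocks and sends block j to rank j
    send = [[shard[j * c_local:(j + 1) * c_local] for j in range(world)]
            for shard in shards]
    recv = all_to_all(send)
    # each rank concatenates the received sequence pieces, in rank order
    return [[[v for piece in pieces for v in piece[ch]] for ch in range(c_local)]
            for pieces in recv]


def channel_shard_to_seq_shard(shards):
    """Hypothetical inverse direction: [C/P][S] per rank -> [C][S/P] per rank."""
    world = len(shards)
    s_local = len(shards[0][0]) // world
    # rank r cuts its sequence into `world` blocks and sends block j to rank j
    send = [[[row[j * s_local:(j + 1) * s_local] for row in shard]
             for j in range(world)]
            for shard in shards]
    recv = all_to_all(send)
    # each rank stacks the received channel blocks, in rank order
    return [[row for piece in pieces for row in piece] for pieces in recv]
```

The collective only moves elements, and each direction undoes the other. That is why backward() can produce the gradient by issuing the opposite direction instead of the same one.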

Usage:

out = AllToAllSingleFunction.apply(x, cp_group, "split_to_full", True)

The apply arguments correspond to the positional parameters of forward() (excluding ctx).

ctx.group

Process group saved for the backward collective.

ctx.type

Communication direction saved for reversal in backward.

ctx.with_zigzag_splitting

Zigzag flag saved for backward.
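backward() needs only these three ctx attributes to issue the reverse collective. The sketch below makes that concrete. It uses a plain `SimpleNamespace` in place of the real autograd context. The helper names `stash_state` and `backward_plan` are hypothetical; only the direction mapping comes from the documentation above.

```python
from types import SimpleNamespace

# Forward direction -> direction issued by backward(), per the documented mapping.
DUAL_DIRECTION = {
    "split_to_full": "full_to_split",
    "full_to_split": "split_to_full",
}


def stash_state(ctx, group, direction, with_zigzag_splitting):
    """Hypothetical helper: record the state documented as ctx.group,
    ctx.type and ctx.with_zigzag_splitting."""
    if direction not in DUAL_DIRECTION:
        raise ValueError(f"unknown direction: {direction!r}")
    ctx.group = group
    ctx.type = direction
    ctx.with_zigzag_splitting = with_zigzag_splitting


def backward_plan(ctx):
    """Same group and zigzag flag as forward, opposite direction."""
    return ctx.group, DUAL_DIRECTION[ctx.type], ctx.with_zigzag_splitting
```

For example, after `stash_state(ctx, "cp_group", "split_to_full", True)` on a fresh `SimpleNamespace()`, `backward_plan(ctx)` returns `("cp_group", "full_to_split", True)`.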

static forward(
ctx,
input_tensor,
group,
type,
with_zigzag_splitting,
)

Execute the all-to-all collective and save state for backward.

Parameters:
  • ctx – Autograd context; stores group, type, and with_zigzag_splitting for use in backward().

  • input_tensor (Tensor) – Input tensor of shape [B, C_local, *spatial]. The tensor is detached before communication so that autograd does not record the operations inside the collective; backward() supplies the gradient instead.

  • group (ProcessGroup) – CP process group.

  • type (Literal['split_to_full', 'full_to_split']) – "split_to_full" or "full_to_split" (see module docstring for the reshape semantics of each direction).

  • with_zigzag_splitting (bool) – Apply a zigzag chunk permutation to balance load across ranks. The value is saved on ctx, and backward() reads it from there, so the gradient pass undoes the same permutation.
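The exact permutation behind with_zigzag_splitting is not specified in this section. The sketch below shows one common zigzag scheme for causal attention, as an assumption. The sequence is cut into 2 × world_size chunks, and rank r holds chunks r and 2 × world_size − 1 − r. This pairs a cheap early chunk with an expensive late one. All function names are hypothetical.

```python
# Hypothetical zigzag layout: the sequence is cut into 2 * world_size chunks and
# rank r holds chunks r and 2 * world_size - 1 - r. This is a common scheme for
# causal attention; it is an assumption, not necessarily this library's layout.


def zigzag_order(world_size):
    """Global chunk indices in the order they are laid out across ranks."""
    order = []
    for r in range(world_size):
        order.extend([r, 2 * world_size - 1 - r])
    return order


def apply_zigzag(chunks):
    order = zigzag_order(len(chunks) // 2)
    return [chunks[i] for i in order]


def undo_zigzag(chunks):
    order = zigzag_order(len(chunks) // 2)
    restored = [None] * len(chunks)
    for pos, i in enumerate(order):
        restored[i] = chunks[pos]
    return restored


def causal_cost_per_rank(world_size):
    """Under causal attention, query chunk i attends to i + 1 key chunks."""
    order = zigzag_order(world_size)
    return [(order[2 * r] + 1) + (order[2 * r + 1] + 1) for r in range(world_size)]
```

With 4 ranks, every rank's cost is 2 × 4 + 1 = 9. A contiguous split would instead give rank 0 chunks 0 and 1 (cost 3) and rank 3 chunks 6 and 7 (cost 15).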

Returns:

Redistributed tensor. Its shape is the dual of the input shape under the chosen direction: the sequence and channel sharding are swapped.

Return type:

torch.Tensor

static backward(ctx, grad_output)

Propagate gradients through the dual all-to-all direction.

Reverses the communication pattern (split_to_full ↔ full_to_split). The zigzag permutation and process group are taken from ctx.
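Why reversing the direction yields the correct gradient: model the forward collective, including any zigzag permutation, as a linear map that only moves elements. Assume, as the direction mapping states, that the opposite direction with the same flag undoes it. Then:

```latex
y = A\,x,\quad A \text{ a permutation matrix} \;\Rightarrow\; A^{\top} = A^{-1},
\qquad
\frac{\partial \mathcal{L}}{\partial x}
  = A^{\top}\,\frac{\partial \mathcal{L}}{\partial y}
  = A^{-1}\,\frac{\partial \mathcal{L}}{\partial y}.
```

Applying A^{-1} is the dual collective. Because an all-to-all performs no reduction, no extra scaling or summation is needed.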

Parameters:
  • ctx – Autograd context with saved group, type, and with_zigzag_splitting.

  • grad_output (Tensor) – Upstream gradient tensor; same shape as the forward output.

Returns:

(grad_input, None, None, None). Only the first element is a gradient. The three None entries correspond to the non-tensor arguments group, type, and with_zigzag_splitting.

Return type:

A four-element tuple, with one entry per forward() argument (excluding ctx).
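The gradient rule and the four-element return can be checked numerically on flat buffers. This sketch models an equal-split all-to-all over `world` ranks as an index permutation, with plain Python lists standing in for tensors. Every helper name is hypothetical.

```python
def all_to_all_perm(world, chunk):
    """Source index for every output slot of an equal-split all-to-all.
    All rank buffers are concatenated; each rank holds world * chunk elements,
    and output rank dst, block src comes from input rank src, block dst."""
    n = world * chunk
    return [src * n + dst * chunk + k
            for dst in range(world)
            for src in range(world)
            for k in range(chunk)]


def forward(x, perm):
    # gather: output slot i reads input slot perm[i]
    return [x[j] for j in perm]


def backward(grad_y, perm):
    # the adjoint of a gather is a scatter back to the source slots
    grad_x = [0] * len(grad_y)
    for i, j in enumerate(perm):
        grad_x[j] += grad_y[i]
    return grad_x


def autograd_style_backward(grad_output, perm):
    # mirrors the documented return: a gradient for input_tensor, then None
    # for group, type and with_zigzag_splitting
    return backward(grad_output, perm), None, None, None


def linear_loss(x, perm, w):
    # L = sum_i w_i * y_i, so dL/dy = w
    return sum(wi * yi for wi, yi in zip(w, forward(x, perm)))
```

On raw flat buffers, an equal-split all-to-all is its own inverse. Once the sequence and channel reshapes are wrapped around the collective, the inverse becomes the opposite direction, which matches the reversal that backward() performs.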