AllToAllSingleFunction
- class AllToAllSingleFunction(*args, **kwargs)
Bases: Function
Differentiable all-to-all collective for CP sequence ↔ channel redistribution.
Wraps all_to_all_single_fn() in a torch.autograd.Function so that gradients flow correctly through the collective boundary. The backward pass runs the dual communication direction:
forward split_to_full → backward full_to_split
forward full_to_split → backward split_to_full
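To see why the dual direction works as the backward pass, it helps to model the collective as pure data movement. The sketch below simulates both directions on nested Python lists instead of tensors. It assumes one reading of the names: "split_to_full" turns per-rank blocks holding all channels and one sequence shard into per-rank blocks holding one channel shard and the full sequence, and "full_to_split" does the reverse. The helper names, the 2-D [channels][sequence] layout, and the omission of zigzag reordering are simplifications for illustration, not the library's implementation.

```python
from typing import List

Block = List[List[int]]  # [channels][sequence] block held by one simulated rank


def split_to_full(shards: List[Block], world_size: int) -> List[Block]:
    """shards[r] is C x (S/P): all channels, sequence shard r.
    Returns out[r] as (C/P) x S: channel shard r, full sequence."""
    num_channels = len(shards[0])
    c_loc = num_channels // world_size
    out = []
    for dst in range(world_size):
        rows = []
        for c in range(dst * c_loc, (dst + 1) * c_loc):
            row = []
            # Every source rank sends its sequence piece of channel c to dst.
            for src in range(world_size):
                row.extend(shards[src][c])
            rows.append(row)
        out.append(rows)
    return out


def full_to_split(shards: List[Block], world_size: int) -> List[Block]:
    """Inverse direction: shards[r] is (C/P) x S; returns out[r] as C x (S/P)."""
    seq_len = len(shards[0][0])
    s_loc = seq_len // world_size
    out = []
    for dst in range(world_size):
        rows = []
        # Every source rank sends sequence shard dst of its channel rows to dst.
        for src in range(world_size):
            for row in shards[src]:
                rows.append(row[dst * s_loc:(dst + 1) * s_loc])
        out.append(rows)
    return out


P, C, S = 2, 4, 6
full = [[100 * c + s for s in range(S)] for c in range(C)]  # global C x S tensor
s_loc = S // P
seq_shards = [[row[r * s_loc:(r + 1) * s_loc] for row in full] for r in range(P)]

chan_shards = split_to_full(seq_shards, P)
assert chan_shards[0] == full[0:2] and chan_shards[1] == full[2:4]
assert full_to_split(chan_shards, P) == seq_shards  # round trip is the identity
```

Because each direction only relocates elements, one is the exact inverse of the other. For a pure permutation, the gradient is grad_output sent back through the inverse permutation, so running the dual direction in backward is the correct choice.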
Usage:
out = AllToAllSingleFunction.apply(x, cp_group, "split_to_full", True)
The apply arguments correspond to the positional parameters of forward() (excluding ctx).
- ctx.group
Process group saved for the backward collective.
- ctx.type
Communication direction saved for reversal in backward.
- ctx.with_zigzag_splitting
Zigzag flag saved for backward.
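The split of responsibilities between forward(), the saved ctx attributes, and backward() can be mimicked without a GPU or process group. In the sketch below, SketchAllToAll, fake_all_to_all, and the DUAL table are hypothetical names. ctx is a plain SimpleNamespace instead of a real autograd context, and the collective is replaced by a stand-in that only records which direction it was asked to run. It illustrates the contract described above and is not the library's code.

```python
from types import SimpleNamespace

# Hypothetical lookup: each direction maps to its dual.
DUAL = {"split_to_full": "full_to_split", "full_to_split": "split_to_full"}

calls = []  # (type, with_zigzag_splitting, group) for every simulated collective


def fake_all_to_all(x, group, type, with_zigzag_splitting):
    """Hypothetical stand-in for all_to_all_single_fn(): records the call, returns x."""
    calls.append((type, with_zigzag_splitting, group))
    return x


class SketchAllToAll:
    @staticmethod
    def forward(ctx, input_tensor, group, type, with_zigzag_splitting):
        # Save everything backward() needs to run the reverse collective.
        ctx.group = group
        ctx.type = type
        ctx.with_zigzag_splitting = with_zigzag_splitting
        return fake_all_to_all(input_tensor, group, type, with_zigzag_splitting)

    @staticmethod
    def backward(ctx, grad_output):
        # Same group and zigzag flag, opposite direction.
        grad_input = fake_all_to_all(
            grad_output, ctx.group, DUAL[ctx.type], ctx.with_zigzag_splitting
        )
        # One slot per forward() argument after ctx; non-tensor args get None.
        return grad_input, None, None, None


ctx = SimpleNamespace()
out = SketchAllToAll.forward(ctx, [1.0, 2.0], "cp_group", "split_to_full", True)
grads = SketchAllToAll.backward(ctx, [0.5, 0.25])
print(calls)
# [('split_to_full', True, 'cp_group'), ('full_to_split', True, 'cp_group')]
```

With the real class, torch.autograd.Function.apply() creates ctx and calls forward() itself, which is why the usage line passes only the tensor, the group, the direction, and the zigzag flag.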
- static forward(ctx, input_tensor, group, type, with_zigzag_splitting)
Execute the all-to-all collective and save state for backward.
- Parameters:
ctx – Autograd context; stores group, type, and with_zigzag_splitting for use in backward().
input_tensor (Tensor) – Input tensor of shape [B, C_local, *spatial]. The tensor is detached before communication to prevent PyTorch from tracking in-collective ops.
group (ProcessGroup) – CP process group.
type (Literal['split_to_full', 'full_to_split']) – "split_to_full" or "full_to_split" (see the module docstring for the reshape semantics of each direction).
with_zigzag_splitting (bool) – Apply zigzag chunk permutation to balance load across ranks. Should match the value used in the corresponding backward call.
- Returns:
Redistributed tensor; its shape is the dual of the input shape under the chosen type.
- Return type:
Tensor
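The docstring does not spell out the zigzag permutation itself. A common scheme in context parallelism, assumed here and not confirmed by the source, cuts the sequence into 2P chunks and gives rank r chunks r and 2P - 1 - r. Under causal attention, where query chunk i attends to i + 1 key chunks, every rank then receives the same amount of work. The function names below are hypothetical.

```python
from typing import List, Sequence


def zigzag_order(world_size: int) -> List[int]:
    """Chunk indices in rank order: rank r owns chunks r and 2P - 1 - r."""
    n = 2 * world_size
    return [i for r in range(world_size) for i in (r, n - 1 - r)]


def zigzag_permute(chunks: Sequence, world_size: int) -> list:
    """Reorder 2P contiguous chunks so positions 2r and 2r + 1 belong to rank r."""
    return [chunks[i] for i in zigzag_order(world_size)]


def zigzag_unpermute(chunks: Sequence, world_size: int) -> list:
    """Inverse of zigzag_permute: restore natural sequence order."""
    out = [None] * len(chunks)
    for pos, i in enumerate(zigzag_order(world_size)):
        out[i] = chunks[pos]
    return out


P = 4
chunks = [f"c{i}" for i in range(2 * P)]
permuted = zigzag_permute(chunks, P)
print(permuted)  # ['c0', 'c7', 'c1', 'c6', 'c2', 'c5', 'c3', 'c4']
assert zigzag_unpermute(permuted, P) == chunks

# Assumed causal cost model: query chunk i costs i + 1 units.
order = zigzag_order(P)
rank_costs = [(order[2 * r] + 1) + (order[2 * r + 1] + 1) for r in range(P)]
print(rank_costs)  # [9, 9, 9, 9], i.e. 2P + 1 on every rank
```

With a contiguous layout, rank 0 would own only the cheapest chunks and rank P - 1 the most expensive ones; pairing chunk r with chunk 2P - 1 - r removes that imbalance. Since the reordering is a permutation, the forward and backward collectives must apply it consistently, which is why the flag is saved on ctx.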
- static backward(ctx, grad_output)
Propagate gradients through the dual all-to-all direction.
Reverses the communication pattern: split_to_full ↔ full_to_split. The zigzag setting and the process group are taken from ctx.
- Parameters:
ctx – Autograd context with saved group, type, and with_zigzag_splitting.
grad_output (Tensor) – Upstream gradient tensor; same shape as the forward output.
- Returns:
(grad_input, None, None, None). Only the first element is meaningful; the remaining three are placeholders for the non-tensor arguments group, type, and with_zigzag_splitting.
- Return type:
Tuple of four elements matching the forward signature (one per argument after ctx).
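Why is the dual direction the correct gradient? Each all-to-all direction only relocates elements, so it acts as a permutation. The gradient of y = x[perm] with respect to x is grad_output scattered back through the inverse permutation, and the dual direction is exactly that inverse. The sketch below checks this with the standard dot-product (adjoint) test on flat Python lists, using a random permutation as a stand-in for the real data layout. apply_perm, inverse_perm, and backward_sketch are hypothetical helpers.

```python
import random
from typing import List, Tuple


def apply_perm(x: List[float], perm: List[int]) -> List[float]:
    """y[i] = x[perm[i]]: pure data movement, like either all-to-all direction."""
    return [x[p] for p in perm]


def inverse_perm(perm: List[int]) -> List[int]:
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


def backward_sketch(
    grad_output: List[float], perm: List[int]
) -> Tuple[List[float], None, None, None]:
    # Dual direction = inverse permutation; the three Nones mirror the
    # non-tensor forward arguments (group, type, with_zigzag_splitting).
    return apply_perm(grad_output, inverse_perm(perm)), None, None, None


def dot(a: List[float], b: List[float]) -> float:
    return sum(u * v for u, v in zip(a, b))


rng = random.Random(0)
n = 12
perm = list(range(n))
rng.shuffle(perm)  # stands in for the "split_to_full" data layout
x = [rng.uniform(-1.0, 1.0) for _ in range(n)]
g = [rng.uniform(-1.0, 1.0) for _ in range(n)]  # upstream gradient w.r.t. y

grad_x, *rest = backward_sketch(g, perm)
# Adjoint test: <P x, g> must equal <x, P^T g>.
assert abs(dot(apply_perm(x, perm), g) - dot(x, grad_x)) < 1e-12
assert rest == [None, None, None]
```

The same argument covers the zigzag reordering, which is itself a permutation, so backward() can reuse the saved flag unchanged.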