Parallel

Context-parallel communication primitives (zigzag splits / all-to-all) shared by the mixer and conv modules above.

init_parallel_state([...])

Initialize distributed training and Megatron parallel state.

zigzag_split_across_group_ranks(data, group)

Distributes tensor data across group ranks using a zigzag pattern.

zigzag_gather_from_group_ranks(data, group)

Reconstructs the complete tensor from zigzag-distributed chunks.
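The zigzag split and gather described above can be illustrated with a single-process sketch on plain Python lists. It assumes the common zigzag convention, in which the sequence is cut into 2 × world_size chunks and rank r keeps chunks r and 2 × world_size − 1 − r. Whether this module uses exactly that layout is not stated here. The names `zigzag_split` and `zigzag_gather` are hypothetical stand-ins, not the library functions, and no process group or communication is involved.

```python
def zigzag_split(seq, world_size):
    """Return one shard per rank: rank r gets chunks r and 2*world_size-1-r.

    Assumed convention; illustrative only, not the library implementation.
    """
    num_chunks = 2 * world_size
    if len(seq) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * world_size")
    size = len(seq) // num_chunks
    chunks = [seq[i * size:(i + 1) * size] for i in range(num_chunks)]
    # Pair an early chunk with its mirrored late chunk for each rank.
    return [chunks[r] + chunks[num_chunks - 1 - r] for r in range(world_size)]


def zigzag_gather(shards):
    """Invert zigzag_split: reassemble the original order from per-rank shards."""
    world_size = len(shards)
    num_chunks = 2 * world_size
    chunks = [None] * num_chunks
    for r, shard in enumerate(shards):
        half = len(shard) // 2
        chunks[r] = shard[:half]                    # the early chunk
        chunks[num_chunks - 1 - r] = shard[half:]   # the mirrored late chunk
    return [x for chunk in chunks for x in chunk]


tokens = list(range(8))
shards = zigzag_split(tokens, world_size=2)
print(shards)                 # [[0, 1, 6, 7], [2, 3, 4, 5]]
print(zigzag_gather(shards))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Pairing an early chunk with a late one is typically done so that every rank sees a similar amount of work under causal masking. The round trip shows that the gather is the exact inverse of the split.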

setup_rank0_logging([log_file])

Set up logging that only prints to console from rank 0, but logs all ranks to files.
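One way to get this behavior with the standard `logging` module is sketched below. It is not the implementation of `setup_rank0_logging`. The function name `setup_rank_logging_sketch`, the `rank` argument, the `RANK` environment variable lookup, and the `.rank<N>` file suffix are all assumptions made for illustration.

```python
import logging
import os
import sys


def setup_rank_logging_sketch(log_file=None, rank=None):
    """Hypothetical stand-in: console output on rank 0 only, one file per rank."""
    if rank is None:
        # Assumption: the launcher exports RANK (torchrun does this).
        rank = int(os.environ.get("RANK", "0"))

    logger = logging.getLogger(f"cp_sketch.rank{rank}")
    logger.setLevel(logging.INFO)
    logger.propagate = False   # keep records away from the root logger
    logger.handlers.clear()    # make repeated calls idempotent
    # A NullHandler keeps ranks without other handlers silent.
    logger.addHandler(logging.NullHandler())

    fmt = logging.Formatter(f"[rank {rank}] %(levelname)s %(message)s")
    if rank == 0:
        console = logging.StreamHandler(sys.stdout)
        console.setFormatter(fmt)
        logger.addHandler(console)
    if log_file is not None:
        # Every rank writes to its own file, e.g. train.log.rank3.
        file_handler = logging.FileHandler(f"{log_file}.rank{rank}")
        file_handler.setFormatter(fmt)
        logger.addHandler(file_handler)
    return logger


log0 = setup_rank_logging_sketch(rank=0)
log1 = setup_rank_logging_sketch(rank=1)
log0.info("visible on the console")
log1.info("dropped: rank 1 has no console handler and no log file")
```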

AllToAllSingleFunction(*args, **kwargs)

Differentiable all-to-all collective for CP sequence ↔ channel redistribution.
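The data movement behind a sequence ↔ channel redistribution can be simulated in one process with nested lists. The sketch assumes contiguous (not zigzag) sequence shards and a channel count divisible by the number of ranks. `seq_to_channel` and `channel_to_seq` are hypothetical names, and nothing here calls a real collective or `AllToAllSingleFunction`.

```python
def seq_to_channel(shards):
    """shards[r] is rank r's [local_seq][channels] block.

    Returns per-rank [full_seq][local_channels] blocks (simulated all-to-all).
    """
    world = len(shards)
    channels = len(shards[0][0])
    assert channels % world == 0, "channels must divide evenly across ranks"
    width = channels // world
    out = []
    for dst in range(world):
        lo, hi = dst * width, (dst + 1) * width
        # Rank dst receives its channel slice from every source rank, in rank order.
        out.append([row[lo:hi] for src in range(world) for row in shards[src]])
    return out


def channel_to_seq(shards):
    """Inverse exchange: back to [local_seq][channels] blocks per rank."""
    world = len(shards)
    seq_len = len(shards[0])
    assert seq_len % world == 0, "sequence must divide evenly across ranks"
    rows = seq_len // world
    out = []
    for dst in range(world):
        block = []
        for i in range(dst * rows, (dst + 1) * rows):
            # Concatenate the channel slices held by each source rank for row i.
            block.append([c for src in range(world) for c in shards[src][i]])
        out.append(block)
    return out


# Full tensor: 4 positions x 4 channels, value = 10 * position + channel.
full = [[10 * s + c for c in range(4)] for s in range(4)]
seq_sharded = [full[0:2], full[2:4]]            # 2 ranks, sequence-sharded
chan_sharded = seq_to_channel(seq_sharded)
print(chan_sharded[0])   # [[0, 1], [10, 11], [20, 21], [30, 31]]
print(channel_to_seq(chan_sharded) == seq_sharded)   # True
```

Because the exchange only permutes elements between ranks, the reverse exchange restores the original layout. The same property lets a differentiable wrapper route gradients with an all-to-all in the opposite direction.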