DistributedDepthwiseConv2d

class DistributedDepthwiseConv2d(hidden_dim, kernel_size, num_groups=None, bias=False, dtype=None, device=None)

Bases: Module

2-D depthwise convolution with context-parallel (CP) aware channel slicing and grouped weight sharing.

Stores weights of shape [G, Kh, Kw] and expands them to [C, Kh, Kw] at runtime via repeat_interleave, so each block of C // G consecutive channels shares one kernel. When cp_group is provided, the channel slice owned by the current rank is extracted before F.conv2d(padding="same") is called.

Expects input in channel-first layout: [B, C_local, H, W].
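The data flow described above (expand the prototypes, slice the local channels, run a depthwise convolution) can be sketched functionally in PyTorch. This is a minimal sketch under stated assumptions, not the module's implementation: `depthwise_forward` is a hypothetical name, explicit `rank` and `world_size` arguments stand in for `cp_group`, and each rank is assumed to own an equal, contiguous block of channels.

```python
import torch
import torch.nn.functional as F


def depthwise_forward(x, weight, bias, hidden_dim, rank=0, world_size=1):
    """Hypothetical functional sketch of the forward pass.

    x:      [B, C_local, H, W] channel-first input
    weight: [G, Kh, Kw] prototype kernels
    bias:   [G] prototype biases, or None
    rank / world_size replace the CP process group (assumption).
    """
    assert x.ndim == 4, "expected [B, C_local, H, W]"
    num_groups = weight.shape[0]
    group_dim = hidden_dim // num_groups  # channels sharing each prototype

    # Expand [G, Kh, Kw] -> [C, Kh, Kw]: prototype g is copied to the
    # consecutive channels g*group_dim ... (g+1)*group_dim - 1.
    full_w = weight.repeat_interleave(group_dim, dim=0)
    full_b = bias.repeat_interleave(group_dim, dim=0) if bias is not None else None

    # Assumed CP layout: an even, contiguous split of C across ranks.
    c_local = hidden_dim // world_size
    start, end = rank * c_local, (rank + 1) * c_local
    w_local = full_w[start:end]
    b_local = full_b[start:end] if full_b is not None else None

    if x.shape[1] != c_local:
        raise RuntimeError(f"expected {c_local} local channels, got {x.shape[1]}")

    # Depthwise conv: groups == channels, weight reshaped to [C_local, 1, Kh, Kw].
    return F.conv2d(x, w_local.unsqueeze(1), b_local, padding="same", groups=c_local)


w = torch.randn(2, 3, 3)     # G = 2 prototypes, 3x3 kernels
x = torch.randn(1, 4, 8, 8)  # C = 8 split over 2 ranks -> C_local = 4
y = depthwise_forward(x, w, None, hidden_dim=8, rank=1, world_size=2)
print(y.shape)  # torch.Size([1, 4, 8, 8])
```

Passing `groups=c_local` is what makes the convolution depthwise: every channel is convolved only with its own [Kh, Kw] kernel, and sharing comes entirely from the repeat_interleave expansion.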

hidden_dim

Total number of channels C.

Type:

int

kernel_size

2-D kernel dimensions (Kh, Kw).

Type:

Tuple[int, int]

num_groups

Weight prototype groups G.

Type:

int

group_dim

Channels per group C // G.

Type:

int

weight

Shape [G, Kh, Kw].

Type:

nn.Parameter

bias

Shape [G] or None.

Type:

nn.Parameter | None

Parameters:
  • hidden_dim (int) – Total number of channels C.

  • kernel_size (int | Tuple[int, int]) – Kernel size; an int K is broadcast to (K, K).

  • num_groups (int | None) – Weight prototype groups G. None → G = C.

  • bias (bool) – Include a learnable bias. Default False.

  • dtype (dtype | None) – Parameter dtype. Default torch.float32.

  • device (device | None) – Parameter device. Defaults to current CUDA device.
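How the constructor arguments resolve into (Kh, Kw), G and group_dim can be sketched in plain Python. `normalize_conv_args` and `prototype_for_channel` are hypothetical helpers, and the requirement that G divide C evenly is an assumption inferred from group_dim = C // G and the expansion to exactly C kernels.

```python
from typing import Optional, Tuple, Union


def normalize_conv_args(
    hidden_dim: int,
    kernel_size: Union[int, Tuple[int, int]],
    num_groups: Optional[int] = None,
) -> Tuple[Tuple[int, int], int, int]:
    """Hypothetical helper: returns ((Kh, Kw), G, group_dim)."""
    # An int kernel size K is broadcast to a square (K, K) kernel.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    kh, kw = kernel_size

    # num_groups=None means one prototype per channel: G = C.
    g = hidden_dim if num_groups is None else num_groups
    # Assumption: G must divide C evenly so the expansion yields exactly C kernels.
    if g <= 0 or hidden_dim % g != 0:
        raise ValueError(f"num_groups={g} must evenly divide hidden_dim={hidden_dim}")
    return (kh, kw), g, hidden_dim // g


def prototype_for_channel(channel: int, group_dim: int) -> int:
    # repeat_interleave along dim 0 gives channel c the kernel of prototype c // group_dim.
    return channel // group_dim


print(normalize_conv_args(64, 3))          # ((3, 3), 64, 1)
print(normalize_conv_args(64, (3, 5), 8))  # ((3, 5), 8, 8)
print(prototype_for_channel(17, 8))        # 2
```

With the default num_groups=None every channel keeps its own kernel (group_dim = 1), which is an ordinary depthwise convolution; smaller G trades per-channel flexibility for fewer parameters (G * Kh * Kw instead of C * Kh * Kw).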

__init__(hidden_dim, kernel_size, num_groups=None, bias=False, dtype=None, device=None)

Initialize DistributedDepthwiseConv2d.

Parameters:
  • hidden_dim (int) – Total number of input/output channels C.

  • kernel_size (int | Tuple[int, int]) – Kernel size; int is broadcast to (K, K).

  • num_groups (int | None) – Weight prototype groups G ≤ C. None → G = C.

  • bias (bool) – Include a learnable bias. Default False.

  • dtype (dtype | None) – Parameter dtype. Default torch.float32.

  • device (device | None) – Parameter device. Defaults to current CUDA device.

init_weights()

Initialize the weight and, if present, the bias from a uniform distribution.
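The bounds used by init_weights are not stated here. As an assumption only, the sketch below borrows the default that nn.Conv2d applies to a depthwise layer: both weight and bias are drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = Kh * Kw, since each output channel sees a single input channel. `init_depthwise_params` is a hypothetical name.

```python
import math

import torch
import torch.nn as nn


def init_depthwise_params(weight: torch.Tensor, bias=None) -> float:
    """Hypothetical init for a [G, Kh, Kw] prototype weight and [G] bias.

    Assumed bound 1/sqrt(Kh * Kw), matching nn.Conv2d's default for groups == channels.
    """
    _, kh, kw = weight.shape
    bound = 1.0 / math.sqrt(kh * kw)
    with torch.no_grad():
        nn.init.uniform_(weight, -bound, bound)
        if bias is not None:
            nn.init.uniform_(bias, -bound, bound)
    return bound


w = nn.Parameter(torch.empty(4, 3, 3))
b = nn.Parameter(torch.empty(4))
print(init_depthwise_params(w, b))  # bound for a 3x3 kernel, about 0.333
```

Because the prototypes are shared by C // G channels after expansion, the fan-in per output channel is still Kh * Kw, so grouping does not change this assumed bound.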

forward(x, cp_group=None)

Apply 2-D depthwise convolution with optional CP channel slicing.

Parameters:
  • x (Tensor) – Input tensor of shape [B, C_local, H, W].

  • cp_group (ProcessGroup | None) – Context-parallel process group. None → single-device execution with no channel slicing.

Returns:

Output of shape [B, C_local, H, W]; the spatial size is preserved because padding="same".

Return type:

torch.Tensor

Raises:
  • AssertionError – If x.ndim != 4.

  • RuntimeError – If x.shape[1] does not match the expected local channel count after CP slicing.
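The shape contract of forward can be checked without a GPU. The helpers below are hypothetical: `local_channel_range` assumes an even, contiguous split of C across the CP group (the actual slicing scheme is not specified here), `check_local_channels` mirrors the documented RuntimeError, and `same_padding` shows why padding="same" keeps H and W unchanged at stride 1.

```python
from typing import Tuple


def local_channel_range(hidden_dim: int, rank: int, world_size: int) -> Tuple[int, int]:
    """Hypothetical: [start, end) channels owned by `rank` (even, contiguous split)."""
    if hidden_dim % world_size != 0:
        raise ValueError("hidden_dim must be divisible by the CP world size")
    c_local = hidden_dim // world_size
    return rank * c_local, (rank + 1) * c_local


def check_local_channels(x_channels: int, hidden_dim: int, world_size: int) -> None:
    # Mirrors the documented RuntimeError when x.shape[1] != C_local.
    expected = hidden_dim // world_size
    if x_channels != expected:
        raise RuntimeError(f"expected {expected} local channels, got {x_channels}")


def same_padding(kernel: int, dilation: int = 1) -> Tuple[int, int]:
    # Total padding dilation * (K - 1) cancels the size reduction of the
    # convolution; for an even K it is odd, and PyTorch puts the extra
    # row/column after the input (right/bottom).
    total = dilation * (kernel - 1)
    return total // 2, total - total // 2


print(local_channel_range(64, 3, 4))  # (48, 64)
print(same_padding(3))                # (1, 1)
print(same_padding(4))                # (1, 2)
```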