DistributedDepthwiseConv2d
- class DistributedDepthwiseConv2d(hidden_dim, kernel_size, num_groups=None, bias=False, dtype=None, device=None)
Bases: Module

2-D depthwise convolution with CP-aware channel slicing and weight sharing.
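Stores weights of shape [G, Kh, Kw] and expands them to [C, Kh, Kw] at runtime via repeat_interleave, so G prototype kernels are shared across all C channels. When cp_group is provided, the channel slice of the expanded weight that belongs to the current context-parallel (CP) rank is extracted before calling F.conv2d(padding="same"). Expects input in channel-first layout: [B, C_local, H, W], where C_local is the number of channels held by the current rank.

- hidden_dim
  Total number of channels C.
  - Type: int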
- weight
  Shape [G, Kh, Kw].
  - Type: nn.Parameter
- bias
  Shape [G] or None.
  - Type: nn.Parameter | None
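A minimal sketch of the runtime weight expansion described above, assuming C is a multiple of G so that each prototype kernel is shared by C // G consecutive channels. The helper `expand_weight` is hypothetical and not part of the documented API; how the real module handles a C that G does not divide is not specified here.

```python
import torch


def expand_weight(weight: torch.Tensor, hidden_dim: int) -> torch.Tensor:
    """Expand prototype kernels [G, Kh, Kw] to per-channel kernels [C, Kh, Kw].

    Hypothetical helper: assumes hidden_dim (C) is a multiple of G.
    """
    num_groups = weight.shape[0]
    if hidden_dim % num_groups != 0:
        raise ValueError(f"C={hidden_dim} is not divisible by G={num_groups}")
    # repeat_interleave along dim 0 repeats each prototype C // G times in a row:
    # [w0, w1] -> [w0, w0, w1, w1] when C // G == 2.
    return weight.repeat_interleave(hidden_dim // num_groups, dim=0)


# Usage: 2 prototype kernels (all zeros, all ones) shared across 6 channels.
proto = torch.arange(2, dtype=torch.float32).view(2, 1, 1).expand(2, 3, 3)
full = expand_weight(proto, hidden_dim=6)
print(full.shape)     # torch.Size([6, 3, 3])
print(full[:, 0, 0])  # tensor([0., 0., 0., 1., 1., 1.])
```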
- Parameters:
  - hidden_dim (int) – Total number of channels C.
  - kernel_size (int | Tuple[int, int]) – Kernel size; an int K becomes (K, K).
  - num_groups (int | None) – Weight prototype groups G. None → G = C.
  - bias (bool) – Include a learnable bias. Default False.
  - dtype (dtype | None) – Parameter dtype. Default torch.float32.
  - device (device | None) – Parameter device. Defaults to the current CUDA device.
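The argument conventions above can be expressed as a small normalization step. This is a sketch only: `normalize_args` is a hypothetical helper, the G ≤ C check follows the __init__ documentation, and requiring C to be divisible by G is an assumption based on the repeat_interleave expansion.

```python
from typing import Optional, Tuple, Union


def normalize_args(
    hidden_dim: int,
    kernel_size: Union[int, Tuple[int, int]],
    num_groups: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Return (G, Kh, Kw) following the documented conventions (hypothetical helper)."""
    # An int kernel size K is broadcast to a square (K, K) kernel.
    if isinstance(kernel_size, int):
        kh, kw = kernel_size, kernel_size
    else:
        kh, kw = kernel_size
    # None means no sharing: one prototype kernel per channel (G = C).
    groups = hidden_dim if num_groups is None else num_groups
    if not 1 <= groups <= hidden_dim:
        raise ValueError(f"need 1 <= G <= C, got G={groups}, C={hidden_dim}")
    # Assumption: the repeat_interleave expansion needs C to be a multiple of G.
    if hidden_dim % groups != 0:
        raise ValueError(f"C={hidden_dim} is not divisible by G={groups}")
    return groups, kh, kw


print(normalize_args(64, 3))          # (64, 3, 3)
print(normalize_args(64, (3, 5), 8))  # (8, 3, 5)
```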
- __init__(hidden_dim, kernel_size, num_groups=None, bias=False, dtype=None, device=None)
Initialise DistributedDepthwiseConv2d.
- Parameters:
  - hidden_dim (int) – Total number of input/output channels C.
  - kernel_size (int | Tuple[int, int]) – Kernel size; an int K is broadcast to (K, K).
  - num_groups (int | None) – Weight prototype groups G ≤ C. None → G = C.
  - bias (bool) – Include a learnable bias. Default False.
  - dtype (dtype | None) – Parameter dtype. Default torch.float32.
  - device (device | None) – Parameter device. Defaults to the current CUDA device.
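Because weights are shared, only G·Kh·Kw kernel values (plus G bias values when bias=True) are stored, rather than C·Kh·Kw. The sketch below allocates parameters with those shapes and the documented dtype/device defaults; `allocate_params` is a hypothetical helper, and leaving the tensors uninitialized until init_weights() runs is an assumption.

```python
from typing import Optional, Tuple

import torch
from torch import nn


def allocate_params(
    num_groups: int,
    kh: int,
    kw: int,
    bias: bool = False,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> Tuple[nn.Parameter, Optional[nn.Parameter]]:
    """Hypothetical helper: allocate the stored (unexpanded) parameters."""
    # Documented defaults: torch.float32 and the current CUDA device.
    dtype = torch.float32 if dtype is None else dtype
    if device is None:
        device = torch.device("cuda", torch.cuda.current_device())
    # Uninitialized storage; values are filled in later by an init routine.
    weight = nn.Parameter(torch.empty(num_groups, kh, kw, dtype=dtype, device=device))
    bias_param = (
        nn.Parameter(torch.empty(num_groups, dtype=dtype, device=device)) if bias else None
    )
    return weight, bias_param


# Pass device explicitly on a CPU-only machine; the CUDA default would fail there.
w, b = allocate_params(8, 3, 3, bias=True, device=torch.device("cpu"))
print(w.shape, b.shape)  # torch.Size([8, 3, 3]) torch.Size([8])
```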
- init_weights()
Initialize weights and bias using a uniform distribution.
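The exact bounds used by init_weights() are not documented here. As an assumption, the sketch below mirrors the default nn.Conv2d scheme for a depthwise kernel, U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = Kh·Kw, since each depthwise output channel sees a single input channel. `init_uniform_` is a hypothetical name.

```python
import math
from typing import Optional

import torch


@torch.no_grad()
def init_uniform_(weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
    """Hypothetical init: U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = Kh * Kw."""
    # A depthwise kernel sees one input channel, so fan_in is just the kernel area.
    kh, kw = weight.shape[-2:]
    bound = 1.0 / math.sqrt(kh * kw)
    weight.uniform_(-bound, bound)
    if bias is not None:
        bias.uniform_(-bound, bound)


w = torch.empty(4, 3, 3)  # [G, Kh, Kw] prototypes, so the bound is 1/3
b = torch.empty(4)        # [G] bias
init_uniform_(w, b)
print(bool(w.abs().max() <= 1 / 3))  # True
```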
- forward(x, cp_group=None)
Apply 2-D depthwise convolution with optional CP channel slicing.
- Parameters:
  - x (Tensor) – Input tensor of shape [B, C_local, H, W].
  - cp_group (ProcessGroup | None) – Context-parallel process group. None → single-device.
- Returns:
  Output of shape [B, C_local, H, W] (same spatial size because padding="same").
- Return type:
  Tensor
- Raises:
  - AssertionError – If x.ndim != 4.
  - RuntimeError – If x.shape[1] does not match the expected local channel count after CP slicing.
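Putting the pieces together, a functional sketch of forward(). It assumes C is a multiple of G and that each CP rank owns one contiguous block of C // world_size channels (rank r gets channels [r·C_local, (r+1)·C_local)); the actual slicing layout is not specified above. `depthwise_forward` is a hypothetical free function, not the module's method.

```python
from __future__ import annotations

from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F


def depthwise_forward(
    x: torch.Tensor,                   # [B, C_local, H, W]
    weight: torch.Tensor,              # [G, Kh, Kw] prototype kernels
    bias: Optional[torch.Tensor],      # [G] or None
    hidden_dim: int,                   # total channel count C
    cp_group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    assert x.ndim == 4, f"expected [B, C_local, H, W], got {x.ndim} dims"
    repeats = hidden_dim // weight.shape[0]  # assumes G divides C
    w = weight.repeat_interleave(repeats, dim=0)                              # [C, Kh, Kw]
    b = bias.repeat_interleave(repeats, dim=0) if bias is not None else None  # [C]
    if cp_group is not None:
        # Assumption: each rank owns one contiguous block of C // world_size channels.
        world_size = dist.get_world_size(cp_group)
        rank = dist.get_rank(cp_group)
        c_local = hidden_dim // world_size
        start = rank * c_local
        w = w[start : start + c_local]
        b = b[start : start + c_local] if b is not None else None
    if x.shape[1] != w.shape[0]:
        raise RuntimeError(f"expected {w.shape[0]} local channels, got {x.shape[1]}")
    # Depthwise conv: one group per channel, kernel reshaped to [C_local, 1, Kh, Kw].
    return F.conv2d(x, w.unsqueeze(1), b, padding="same", groups=w.shape[0])


x = torch.randn(2, 8, 16, 16)
out = depthwise_forward(x, torch.randn(4, 3, 3), torch.zeros(4), hidden_dim=8)
print(out.shape)  # torch.Size([2, 8, 16, 16])
```

With cp_group=None the full [C, Kh, Kw] kernel is used, so x must carry all C channels. Note that padding="same" in F.conv2d only supports stride 1, which is the default used here.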