DistributedDepthwiseConv1d

class DistributedDepthwiseConv1d(hidden_dim, kernel_size, causal=False, num_groups=None, bias=False, dtype=None, device=None)

Bases: Module

1-D depthwise convolution with CP-aware channel slicing and weight sharing.

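Stores a compact weight of shape [G, K] (G groups, kernel size K) and expands it to [C, K] on every forward pass via repeat_interleave, so each group's filter is shared by C // G consecutive channels. When cp_group is provided, the expanded weight is further sliced to rows [r·(C/P) : (r+1)·(C/P)] before F.conv1d is called, where r is the rank within cp_group and P is the CP world size, so each CP rank processes only its local channel shard.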
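The index arithmetic behind the expansion and the CP slice can be sketched in plain Python. This is an illustration only, not the module's code: the helper names expand_weight and local_channel_slice are hypothetical, nested lists stand in for tensors, and it assumes that C is divisible by both G and P.

```python
def expand_weight(weight, group_dim):
    """Repeat each [K] row group_dim times in a row.

    Mirrors what repeat_interleave does along dim 0: [G, K] -> [G * group_dim, K].
    """
    return [list(row) for row in weight for _ in range(group_dim)]


def local_channel_slice(hidden_dim, cp_rank, cp_world_size):
    """Return (start, stop) of the channel shard owned by cp_rank.

    Assumption: hidden_dim is divisible by cp_world_size.
    """
    if hidden_dim % cp_world_size != 0:
        raise ValueError("hidden_dim must be divisible by cp_world_size")
    shard = hidden_dim // cp_world_size
    return cp_rank * shard, (cp_rank + 1) * shard


# Hypothetical sizes: C = 8 channels, G = 2 groups, K = 3 taps, P = 4 ranks.
weight = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]       # compact weight, [G, K]
full = expand_weight(weight, group_dim=8 // 2)    # expanded weight, [C, K]
start, stop = local_channel_slice(hidden_dim=8, cp_rank=2, cp_world_size=4)
local = full[start:stop]                          # rank-local weight, [C/P, K]
```

With these sizes, rank 2 owns channels 4 and 5, both of which fall in the second group and therefore share the second prototype filter.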

Supports two padding modes:

  • causal=True: pads K-1 elements on the left only before the convolution; the output length equals the input length, and no output position depends on future inputs.

  • causal=False (default): padding="same" (symmetric); suitable for spatial axes in multi-dimensional Hyena.
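The difference between the two padding modes can be seen on a single channel. The following plain-Python sketch is not the module's implementation: conv1d_single_channel is a hypothetical helper, zero padding is assumed, and for even K the extra padded element is assumed to go on the right.

```python
def conv1d_single_channel(x, w, causal):
    """Cross-correlate one channel x (length L) with taps w (length K), stride 1."""
    k = len(w)
    if causal:
        # All K-1 padded elements go on the left, so output t sees inputs <= t.
        left, right = k - 1, 0
    else:
        # "same" padding: split K-1 across both sides (assumption: extra on the right).
        left = (k - 1) // 2
        right = (k - 1) - left
    padded = [0.0] * left + list(x) + [0.0] * right
    # One output per input position, so the length stays L in both modes.
    return [sum(w[j] * padded[t + j] for j in range(k)) for t in range(len(x))]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [1.0, 1.0, 1.0]

causal_out = conv1d_single_channel(x, w, causal=True)
same_out = conv1d_single_channel(x, w, causal=False)

# Perturb only the last input and compare the earlier outputs.
x_future = x[:-1] + [50.0]
causal_future = conv1d_single_channel(x_future, w, causal=True)
same_future = conv1d_single_channel(x_future, w, causal=False)
```

In the causal case, every output before the perturbed position is unchanged; with "same" padding, the output one step before it changes as well, because the kernel window reaches one step into the future.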

hidden_dim

Total number of input/output channels C.

Type:

int

kernel_size

Convolution kernel size K.

Type:

int

causal

Whether left-only (causal) padding is used.

Type:

bool

num_groups

Number of weight prototype groups G.

Type:

int

group_dim

Channels per group C // G.

Type:

int

weight

Filter weights of shape [G, K].

Type:

nn.Parameter

bias

Optional bias of shape [G].

Type:

nn.Parameter | None

Parameters:
  • hidden_dim (int) – Total number of input/output channels C.

  • kernel_size (int) – Convolution kernel size K.

  • causal (bool) – Apply left-only padding. Default False.

  • num_groups (int | None) – Weight prototype groups G. None → G = C (standard depthwise, no sharing).

  • bias (bool) – Include a learnable bias. Default False.

  • dtype (dtype | None) – Parameter dtype. Default torch.float32.

  • device (device | None) – Parameter device. Defaults to the current CUDA device if available.
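How num_groups relates to group_dim and to the size of the stored weight can be sketched as follows. The helper resolve_groups is hypothetical, and the divisibility check is an assumption suggested by group_dim = C // G rather than documented behaviour.

```python
def resolve_groups(hidden_dim, num_groups=None):
    """Return (G, group_dim) for C = hidden_dim channels.

    None falls back to G = C, i.e. one filter per channel and no sharing.
    """
    g = hidden_dim if num_groups is None else num_groups
    # Assumption: G must be positive and divide C so that group_dim is exact.
    if g <= 0 or hidden_dim % g != 0:
        raise ValueError("num_groups must be positive and divide hidden_dim")
    return g, hidden_dim // g


C, K = 1024, 7  # hypothetical sizes

g_default, dim_default = resolve_groups(C)       # standard depthwise
g_shared, dim_shared = resolve_groups(C, 16)     # 16 prototype filters

# The stored weight has shape [G, K], so its element count is G * K.
params_default = g_default * K
params_shared = g_shared * K
```

In this example, sharing 16 prototype filters across 1024 channels stores 112 weight elements instead of 7168, at the cost of expanding the weight on each forward pass.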

__init__(hidden_dim, kernel_size, causal=False, num_groups=None, bias=False, dtype=None, device=None)

Initialise DistributedDepthwiseConv1d.

Parameters:
  • hidden_dim (int) – Total number of input/output channels C.

  • kernel_size (int) – Convolution kernel size K.

  • causal (bool) – Apply left-only causal padding. Default False.

  • num_groups (int | None) – Weight prototype groups G ≤ C. None → G = C.

  • bias (bool) – Include a learnable bias. Default False.

  • dtype (dtype | None) – Parameter dtype. Default torch.float32.

  • device (device | None) – Parameter device. Defaults to current CUDA device.

init_weights()

Initialize the weights and bias from a uniform distribution.

forward(x, cp_group=None)

Apply 1-D depthwise convolution with optional CP channel slicing.

The full weight [G, K] is expanded to [C, K] via repeat_interleave. If cp_group is given, only the slice for the current rank is retained before calling F.conv1d.

Parameters:
  • x (Tensor) –

    Input tensor of shape [B, C_local, L] where

    • C_local = hidden_dim when not using CP, or

    • C_local = hidden_dim // cp_world_size on CP rank r.

  • cp_group (ProcessGroup | None) – Context-parallel process group. None → single-device mode (no slicing).

Returns:

Output of shape [B, C_local, L]; same length as input for stride=1.

Return type:

torch.Tensor

Raises:
  • AssertionError – If x.ndim != 3.

  • RuntimeError – If x.shape[1] does not match the expected local channel count after CP slicing.
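The shape contract described above can be written out as a small stand-alone check. This is a sketch under assumptions, not the module's code: both helper names are hypothetical, and the real module may raise the RuntimeError from inside F.conv1d rather than from an explicit comparison.

```python
def expected_local_channels(hidden_dim, cp_world_size=None):
    """C_local is hidden_dim without CP, or hidden_dim // cp_world_size with CP."""
    if cp_world_size is None:
        return hidden_dim
    return hidden_dim // cp_world_size


def check_input_shape(shape, hidden_dim, cp_world_size=None):
    """Validate a [B, C_local, L] shape tuple and return it unchanged."""
    assert len(shape) == 3, f"expected [B, C_local, L], got {len(shape)} dims"
    expected = expected_local_channels(hidden_dim, cp_world_size)
    if shape[1] != expected:
        raise RuntimeError(f"expected {expected} local channels, got {shape[1]}")
    return shape


# Single-device mode: the input carries all 512 channels.
ok_single = check_input_shape((2, 512, 128), hidden_dim=512)

# CP with 4 ranks: each rank passes its 128-channel shard.
ok_cp = check_input_shape((2, 128, 128), hidden_dim=512, cp_world_size=4)
```

Passing the full 512-channel tensor together with a 4-rank CP group would fail this check, since each rank is expected to hold only its own shard.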