RegisterCompressConcat#

class RegisterCompressConcat(num_registers, hidden_dim, compressed_dim)#

Bases: Module

Compress each register token and concatenate into a single conditioning vector.

Each register token is passed through a shared linear projection that reduces its dimensionality from hidden_dim to compressed_dim. The compressed tokens are then concatenated along the feature axis, producing a conditioning vector of size num_registers * compressed_dim that preserves per-register identity (unlike RegisterPooling, which averages them).

Formally, for a batch of register tensors X ∈ ℝ^{B × R × D}, with R = num_registers and D = hidden_dim:

compressed_r = W · x_r, W ∈ ℝ^{compressed_dim × hidden_dim} (shared across r)

c = [compressed_0 ‖ compressed_1 ‖ … ‖ compressed_{R-1}] ∈ ℝ^{R · compressed_dim}
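As a small worked instance of the definition above, the following dependency-free sketch applies one shared matrix W to each register of a single sample and concatenates the results. The helper names matvec and compress_concat and the numeric values are illustrative assumptions, not part of the library.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector (list of numbers)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def compress_concat(W, registers):
    """Apply the same W to every register, then concatenate in register order."""
    c = []
    for x_r in registers:          # r = 0 .. R-1
        c.extend(matvec(W, x_r))   # append compressed_r to the running vector
    return c


# Hypothetical toy sizes: hidden_dim = 3, compressed_dim = 2, R = 2.
W = [[1, 0, 1],
     [0, 2, 0]]
registers = [[1, 2, 3],   # x_0
             [4, 5, 6]]   # x_1

c = compress_concat(W, registers)
print(c)  # [4, 4, 10, 10]
```

The first two entries come from register 0 and the last two from register 1, so each register keeps its own slot in c, and the length is R * compressed_dim = 4.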

The output c is then consumed by KernelFiLMGenerator whose cond_dim must equal num_registers * compressed_dim (see out_dim).

Inspired by Mamba-Reg (Wang et al., 2024), which distributes register tokens and reads them out individually rather than pooling them.

num_registers#

Number of register tokens expected on the sequence axis.

Type:

int

hidden_dim#

Input channel dimension of each register token.

Type:

int

compressed_dim#

Output channel dimension per register after compression.

Type:

int

compress#

Shared weight-only (no bias) projection hidden_dim → compressed_dim applied independently to each register.

Type:

nn.Linear

Parameters:
  • num_registers (int) – Number of register tokens expected.

  • hidden_dim (int) – Channel dimension of each register token.

  • compressed_dim (int) – Output dimension per register after compression. The final conditioning vector has size num_registers * compressed_dim.

__init__(
num_registers,
hidden_dim,
compressed_dim,
)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • num_registers (int)

  • hidden_dim (int)

  • compressed_dim (int)

property out_dim: int#

Dimensionality of the output conditioning vector.

Returns:

num_registers * compressed_dim, which must match cond_dim of any downstream KernelFiLMGenerator.

flop_count(hidden_dim)#

Count FLOPs for compress-and-concatenate (one sample).

Operations (R = num_registers, D_in = hidden_dim, D_out = compressed_dim):
  1. Shared linear applied R times: R * 2 * D_in * D_out

  2. Concatenation is a view/copy, not a FLOP.

Parameters:

hidden_dim (int) – Channel dimension of the register tokens; must equal the hidden_dim argument passed to __init__.

Returns:

Total FLOPs as an integer.

Return type:

int
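The count described above can be reproduced with a few lines of arithmetic. This is a standalone sketch, assuming the usual convention of one multiply plus one add per weight; the function name compress_concat_flops is hypothetical and takes all three sizes explicitly, unlike the flop_count method.

```python
def compress_concat_flops(num_registers: int, hidden_dim: int, compressed_dim: int) -> int:
    """FLOPs for one sample: the shared linear is applied once per register."""
    flops_per_register = 2 * hidden_dim * compressed_dim  # multiply + add per weight
    return num_registers * flops_per_register              # concatenation adds nothing


# Illustrative sizes only.
flops = compress_concat_flops(num_registers=4, hidden_dim=256, compressed_dim=32)
print(flops)  # 65536
```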

forward(registers)#

Compress each register token and concatenate into a flat conditioning vector.

Applies the shared linear projection to every register independently (via broadcasting over the register axis), then flattens the register and compressed-channel axes into a single vector per sample.

Parameters:

registers (Tensor) – Register token tensor of shape [B, num_registers, hidden_dim], where B is the batch size. The number of registers along axis 1 must equal num_registers passed to __init__.

Returns:

Flat conditioning vector of shape [B, num_registers * compressed_dim]. Pass this directly as the conditioning argument of KernelFiLMGenerator.

Return type:

Tensor
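The shape behaviour of forward can be sketched in PyTorch as follows. This is a minimal re-creation based only on the description on this page, not the library's actual source; the class name CompressConcatSketch and the tensor sizes are assumptions made for illustration.

```python
import torch
from torch import nn


class CompressConcatSketch(nn.Module):
    """Illustrative stand-in: shared bias-free linear, then flatten."""

    def __init__(self, num_registers: int, hidden_dim: int, compressed_dim: int):
        super().__init__()
        self.num_registers = num_registers
        self.hidden_dim = hidden_dim
        self.compressed_dim = compressed_dim
        # One weight matrix shared by every register, no bias term.
        self.compress = nn.Linear(hidden_dim, compressed_dim, bias=False)

    @property
    def out_dim(self) -> int:
        return self.num_registers * self.compressed_dim

    def forward(self, registers: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last axis: [B, R, D] -> [B, R, compressed_dim]
        compressed = self.compress(registers)
        # Merge register and channel axes: [B, R * compressed_dim]
        return compressed.flatten(start_dim=1)


module = CompressConcatSketch(num_registers=4, hidden_dim=256, compressed_dim=32)
registers = torch.randn(8, 4, 256)   # [B, num_registers, hidden_dim]
cond = module(registers)
print(cond.shape)  # torch.Size([8, 128])
```

Because the flatten keeps registers in order, the first compressed_dim entries of each row correspond to register 0, the next compressed_dim to register 1, and so on.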

extra_repr()#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str