RegisterCompressConcat
- class RegisterCompressConcat(num_registers, hidden_dim, compressed_dim)
Bases: Module
Compress each register token and concatenate into a single conditioning vector.
Each register token is passed through a shared linear projection that reduces its dimensionality from hidden_dim to compressed_dim. The compressed tokens are then concatenated along the feature axis, producing a conditioning vector of size num_registers * compressed_dim that preserves per-register identity (unlike RegisterPooling, which averages them).

Formally, for a batch of register tensors X ∈ ℝ^{B × R × D}:

    compressed_r = W · x_r,   W ∈ ℝ^{compressed_dim × hidden_dim}   (shared across r)
    c = [compressed_0 ‖ compressed_1 ‖ … ‖ compressed_{R-1}] ∈ ℝ^{R · compressed_dim}

The output c is then consumed by KernelFiLMGenerator, whose cond_dim must equal num_registers * compressed_dim (see out_dim).

Inspired by Mamba-Reg (Wang et al., 2024), which distributes and individually reads out register tokens rather than pooling them.
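To make the formula concrete, here is a minimal plain-Python sketch for a single sample (no batch axis). The helper name compress_concat and the toy values are illustrative assumptions, not part of this class; the module itself works on batched tensors through nn.Linear.

```python
def compress_concat(registers, weight):
    """Apply one shared matrix to each register, then concatenate.

    registers: R lists of length hidden_dim (one sample, no batch axis).
    weight:    compressed_dim rows of length hidden_dim (no bias).
    """
    conditioning = []
    for x_r in registers:  # the same `weight` is reused for every register
        compressed_r = [sum(w * x for w, x in zip(row, x_r)) for row in weight]
        conditioning.extend(compressed_r)  # concatenate along the feature axis
    return conditioning


# Toy sizes: R = 2 registers, hidden_dim = 3, compressed_dim = 2.
W = [[1, 0, 0],
     [0, 1, 1]]
X = [[1, 2, 3],
     [4, 5, 6]]

c = compress_concat(X, W)
print(c)  # [1, 5, 4, 11]
```

The result has R * compressed_dim = 4 entries, and each register keeps its own slot in the vector: the first two entries come from register 0, the last two from register 1.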
Input channel dimension of each register token.
- Type:
- compress
Shared weight-only (no bias) projection hidden_dim → compressed_dim, applied independently to each register.
- Type:
nn.Linear
- Parameters:
- __init__(num_registers, hidden_dim, compressed_dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- property out_dim: int
Dimensionality of the output conditioning vector.
- Returns:
num_registers * compressed_dim, which must match cond_dim of any downstream KernelFiLMGenerator.
- flop_count(hidden_dim)
Count FLOPs for compress-and-concatenate (one sample).
- Operations (R = num_registers, D_in = hidden_dim, D_out = compressed_dim):
Shared linear applied R times: R * 2 * D_in * D_out
Concatenation is a view/copy, not a FLOP.
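As a worked instance of the count R * 2 * D_in * D_out, the sketch below evaluates it for one set of sizes. The standalone function compress_concat_flops and the chosen sizes are hypothetical, used only to illustrate the arithmetic; the documented method takes hidden_dim alone and reads the other sizes from the module.

```python
def compress_concat_flops(num_registers, hidden_dim, compressed_dim):
    """FLOPs for one sample: R matrix-vector products, 2 FLOPs per multiply-add."""
    per_register = 2 * hidden_dim * compressed_dim
    return num_registers * per_register  # concatenation contributes no FLOPs


flops = compress_concat_flops(num_registers=4, hidden_dim=256, compressed_dim=32)
print(flops)  # 65536
```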
- forward(registers)
Compress each register token and concatenate into a flat conditioning vector.
Applies the shared linear projection to every register independently (via broadcasting over the register axis), then flattens the register and compressed-channel axes into a single vector per sample.
- Parameters:
registers (Tensor) – Register token tensor of shape [B, num_registers, hidden_dim], where B is the batch size. The number of registers along axis 1 must equal num_registers passed to __init__.
- Returns:
Flat conditioning vector of shape [B, num_registers * compressed_dim]. Pass this directly as the conditioning argument of KernelFiLMGenerator.
- Return type:
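The forward pass described above can be sketched with a bare nn.Linear, assuming the flattening is done with Tensor.flatten (the actual implementation may use reshape or view instead). The sizes are arbitrary, and the per-register loop serves only as a reference to show that flattening [B, R, C] matches concatenating the individual projections.

```python
import torch
from torch import nn

B, R, D, C = 2, 4, 16, 3  # batch, num_registers, hidden_dim, compressed_dim

compress = nn.Linear(D, C, bias=False)  # shared, weight-only projection
registers = torch.randn(B, R, D)

# nn.Linear acts on the last axis, so every register is projected independently.
compressed = compress(registers)        # [B, R, C]
cond = compressed.flatten(start_dim=1)  # [B, R * C]

# Reference: project each register separately, then concatenate the results.
reference = torch.cat([compress(registers[:, r]) for r in range(R)], dim=1)

print(cond.shape)  # torch.Size([2, 12])
print(torch.allclose(cond, reference, atol=1e-5))
```

Because the flattened layout is register-major, entries r * C through (r + 1) * C - 1 of each row always belong to register r.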