RegisterPooling

class RegisterPooling(num_registers)

Bases: Module

Learnable softmax-weighted average over register tokens.

Produces a single conditioning vector c ∈ ℝ^C per sample from a set of R register tokens. The pooling weights are learned scalars w ∈ ℝ^R passed through a softmax, so the contribution of each register is non-negative and the weights sum to one:

c = Σ_r softmax(w)_r · x_r

This is a lightweight alternative to RegisterCompressConcat when a single scalar weight per register is sufficient. All register information is blended into one C-dimensional vector per sample (a [B, C] tensor for the batch) that can be passed directly to KernelFiLMGenerator.

The learned logit vector w is always excluded from weight decay via _no_weight_decay = True.

logits

Learnable unnormalized pooling weights of shape [num_registers], initialised to zero (uniform softmax at init).

Type:

nn.Parameter

Parameters:

num_registers (int) – Number of register tokens to pool over.
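
The behaviour described above can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the library source: the class name RegisterPoolingSketch is hypothetical, and the _no_weight_decay flag is left out because this page does not say where it is attached.

```python
import torch
from torch import nn


class RegisterPoolingSketch(nn.Module):
    """Hypothetical re-implementation for illustration only."""

    def __init__(self, num_registers: int) -> None:
        super().__init__()
        # Zero logits give a uniform softmax, so pooling starts as a plain mean.
        self.logits = nn.Parameter(torch.zeros(num_registers))

    def forward(self, registers: torch.Tensor) -> torch.Tensor:
        # [R] weights: non-negative and summing to one.
        weights = torch.softmax(self.logits, dim=0)
        # Weighted sum over the register axis: [R] x [B, R, C] -> [B, C].
        return torch.einsum("r,brc->bc", weights, registers)


pool = RegisterPoolingSketch(num_registers=4)
registers = torch.randn(2, 4, 8)  # [B, R, C]
pooled = pool(registers)          # [B, C]
```

With the zero initialisation, the output of this sketch equals registers.mean(dim=1) until the logits are trained away from zero.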

__init__(num_registers)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:

num_registers (int)

flop_count(dim)

Count FLOPs for learnable weighted average over register tokens.

Operations (R = self.logits.shape[0] = num_registers, D = dim):

  1. Softmax over R logits: ~3 * R (exp + sum + divide, amortized).

  2. Weighted sum via einsum("r, b r c -> b c"): R * D multiplies + (R - 1) * D adds ≈ 2 * R * D.

Total: 3 * R + 2 * R * D.

Parameters:

dim (int) – Channel dimension (C) of the register tokens.

Returns:

Total FLOPs as an integer.

Return type:

int
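
As a worked check of the formula above, the following sketch evaluates 3 * R + 2 * R * D for a concrete configuration. The function name register_pooling_flops is hypothetical and only mirrors the documented arithmetic; it is not the method itself. The documented total has no batch-size term, so it reads as a per-sample count, which is an interpretation rather than something this page states.

```python
def register_pooling_flops(num_registers: int, dim: int) -> int:
    """Approximate FLOPs following the documented breakdown (illustrative)."""
    softmax_flops = 3 * num_registers             # exp + sum + divide
    weighted_sum_flops = 2 * num_registers * dim  # multiplies + adds, rounded up
    return softmax_flops + weighted_sum_flops


flops = register_pooling_flops(num_registers=4, dim=256)
print(flops)  # 3*4 + 2*4*256 = 2060
```

The weighted sum dominates: for R = 4 and D = 256 the softmax contributes 12 of the 2060 FLOPs.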

forward(registers)

Pool register tokens into a single conditioning vector.

Applies softmax to the learned logits and computes a weighted sum over the register dimension:

c_b = Σ_r softmax(logits)_r · registers_{b,r,:}

Parameters:

registers (Tensor) – Register token tensor of shape [B, num_registers, C], where B is the batch size and C is the channel dimension. The number of registers along axis 1 must equal num_registers passed to __init__; a mismatch will raise a RuntimeError from torch.einsum.

Returns:

Pooled conditioning vector of shape [B, C].

Return type:

Tensor
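
The shape contract and the mismatch error described above can be reproduced with torch.einsum directly. This is a sketch that assumes the forward pass reduces to a softmax followed by this einsum, as the description suggests; the variable names are illustrative and the zero tensor stands in for the learned logits.

```python
import torch

num_registers = 4
logits = torch.zeros(num_registers)      # stand-in for the learned logits
weights = torch.softmax(logits, dim=0)   # [R]

good = torch.randn(2, num_registers, 8)  # [B, R, C]
pooled = torch.einsum("r,brc->bc", weights, good)
print(pooled.shape)  # torch.Size([2, 8])

bad = torch.randn(2, num_registers + 1, 8)  # wrong register count on axis 1
try:
    torch.einsum("r,brc->bc", weights, bad)
    mismatch_raised = False
except RuntimeError:
    # einsum rejects operands whose sizes disagree on the shared index r.
    mismatch_raised = True
```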

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Return type:

str