RegisterPooling
- class RegisterPooling(num_registers)
Bases: Module
Learnable softmax-weighted average over register tokens.
Produces a single conditioning vector c ∈ ℝ^C per sample from a set of R register tokens. The pooling weights are learned scalars w ∈ ℝ^R passed through a softmax, so the contribution of each register is non-negative and the weights sum to one:
c = Σ_r softmax(w)_r · x_r
where x_r ∈ ℝ^C is the r-th register token of the sample.
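To make the formula concrete, here is a small worked example with R = 2 registers and C = 2 channels. The logit values and register contents are made up for illustration; the computation is only the softmax-weighted sum written above, not the module's own code.

```python
import math

import torch

# Made-up learned logits w for R = 2 registers.
w = torch.tensor([0.0, math.log(3.0)])
# softmax(w) = exp(w) / sum(exp(w)) = [1, 3] / 4, mathematically [0.25, 0.75]
weights = torch.softmax(w, dim=0)

# One sample with R = 2 register tokens x_r of dimension C = 2.
x = torch.tensor([[4.0, 0.0],   # x_1
                  [0.0, 4.0]])  # x_2

# c = sum_r softmax(w)_r * x_r = 0.25 * [4, 0] + 0.75 * [0, 4], mathematically [1, 3]
c = (weights[:, None] * x).sum(dim=0)
```

The weights are positive and sum to one, so c always lies in the convex hull of the register tokens.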
This is a lightweight alternative to RegisterCompressConcat when a scalar summary per register is sufficient. All register information is blended into a single [B, C] vector that can be passed directly to KernelFiLMGenerator. The learned logit vector w is always excluded from weight decay via _no_weight_decay = True.
- logits
Learnable unnormalized pooling weights of shape [num_registers], initialised to zero (uniform softmax at init).
- Type:
nn.Parameter
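A minimal sketch of such a module, assuming it stores the logits as described and pools with torch.einsum using the contraction named in flop_count ("r,brc->bc" is the whitespace-free spelling of "r, b r c -> b c"). The class name RegisterPoolingSketch is hypothetical, and the sketch leaves out the weight-decay flag and flop_count. Because the logits start at zero, every softmax weight starts at 1/R, so at initialization the pooled vector is the plain mean over registers.

```python
import torch
import torch.nn as nn


class RegisterPoolingSketch(nn.Module):
    """Hypothetical stand-in for RegisterPooling, for illustration only."""

    def __init__(self, num_registers: int) -> None:
        super().__init__()
        # Unnormalized pooling weights; zeros give a uniform softmax at init.
        self.logits = nn.Parameter(torch.zeros(num_registers))

    def forward(self, registers: torch.Tensor) -> torch.Tensor:
        # registers: [B, R, C]; weights: [R], positive and summing to one.
        weights = torch.softmax(self.logits, dim=0)
        # Contract the register axis: [B, R, C] -> [B, C].
        return torch.einsum("r,brc->bc", weights, registers)


pool = RegisterPoolingSketch(num_registers=4)
registers = torch.randn(2, 4, 8)  # B=2, R=4, C=8
pooled = pool(registers)
# At initialization each weight is 1/4, so pooled matches registers.mean(dim=1)
# up to floating-point rounding.
```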
- Parameters:
num_registers (int) – Number of register tokens to pool over.
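The class description says w is excluded from weight decay via _no_weight_decay = True, but not where the flag lives or which code reads it. As a hedged sketch, assuming the flag is set as an attribute on the parameter itself, an optimizer setup could honor it by building two parameter groups. The helper split_weight_decay_groups and the toy model are hypothetical; only the standard torch.optim per-group options are real API.

```python
import torch
import torch.nn as nn


def split_weight_decay_groups(model: nn.Module, weight_decay: float):
    """Hypothetical helper: parameters flagged with `_no_weight_decay`
    go into a group with weight_decay=0.0, all others get `weight_decay`."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # Assumption: the flag is an attribute on the Parameter object.
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Toy model: a linear layer plus a logits vector flagged like w above.
model = nn.Module()
model.proj = nn.Linear(8, 8)
model.logits = nn.Parameter(torch.zeros(4))
model.logits._no_weight_decay = True  # assumed placement of the flag

groups = split_weight_decay_groups(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```

Excluding the logits from decay keeps them from being pulled back toward zero, which would otherwise bias the pooling toward the uniform average.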
- __init__(num_registers)
Create the module with a learnable logits vector of shape [num_registers], initialised to zero.
- Parameters:
num_registers (int) – Number of register tokens to pool over.
- flop_count(dim)
Count FLOPs for the learnable weighted average over register tokens. The count depends only on R and D; the batch size B does not appear in it.
- Operations (R = self.logits.shape[0] = num_registers, D = dim):
Softmax over R logits: ~3 * R (exp + sum + divide, amortized).
Weighted sum via einsum("r, b r c -> b c"): R * D multiplies + (R - 1) * D adds ≈ 2 * R * D.
Total: 3 * R + 2 * R * D.
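The breakdown above reduces to a one-line formula. The helper name pooling_flops is hypothetical and simply reproduces the stated total 3 * R + 2 * R * D; for example, R = 4 registers and D = 256 channels give 12 + 2048 = 2060.

```python
def pooling_flops(num_registers: int, dim: int) -> int:
    """Hypothetical helper reproducing the documented FLOP total."""
    softmax_flops = 3 * num_registers             # exp + sum + divide, amortized
    weighted_sum_flops = 2 * num_registers * dim  # R*D multiplies + ~R*D adds
    return softmax_flops + weighted_sum_flops


# R = 4 registers, D = 256 channels: 3*4 + 2*4*256 = 12 + 2048
flops = pooling_flops(4, 256)
```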
- forward(registers)
Pool register tokens into a single conditioning vector.
Applies softmax to the learned logits and computes a weighted sum over the register dimension:
c_b = Σ_r softmax(logits)_r · registers_{b,r,:}
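The einsum contraction and the summation formula describe the same computation. The sketch below uses made-up shapes to check an einsum form against explicit broadcasting and a loop over registers. It also shows that a register count disagreeing with the length of the logits is rejected by torch.einsum with a RuntimeError, matching the shape requirement stated for the registers parameter. It uses only standard PyTorch calls and is not the module's actual forward code.

```python
import torch

B, R, C = 3, 4, 5
logits = torch.randn(R)
registers = torch.randn(B, R, C)
weights = torch.softmax(logits, dim=0)

# Einsum form: contract the register axis r, keep batch b and channel c.
pooled_einsum = torch.einsum("r,brc->bc", weights, registers)

# Broadcasting form: [1, R, 1] * [B, R, C], then sum over the register axis.
pooled_broadcast = (weights[None, :, None] * registers).sum(dim=1)

# Loop form, mirroring c_b = sum_r softmax(logits)_r * registers[b, r, :].
pooled_loop = sum(weights[r] * registers[:, r, :] for r in range(R))

# A register count that disagrees with len(weights) is rejected by einsum.
try:
    torch.einsum("r,brc->bc", weights, torch.randn(B, R - 1, C))
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
```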
- Parameters:
registers (Tensor) – Register token tensor of shape [B, num_registers, C], where B is the batch size and C is the channel dimension. The number of registers along axis 1 must equal num_registers passed to __init__; a mismatch will raise a RuntimeError from torch.einsum.
- Returns:
Pooled conditioning vector of shape [B, C].
- Return type: