apply_qk_norm

apply_qk_norm(query, key, dim=-1, eps=1e-12)

L2-normalise query and key tensors along a given dimension.

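Computes F.normalize(query, p=2, dim=dim) and the same for key. Each resulting vector has unit L2 norm along dim (a vector whose norm is below eps, such as an all-zero vector, is scaled by 1/eps instead and does not reach unit norm), so the dot product of a normalised query with a normalised key is their cosine similarity and lies in [-1, 1].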
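Written out for a single query vector q and key vector k taken along dim, and assuming the function follows the F.normalize definition (the norm is clamped from below by eps), the operation and the resulting bound are:

```latex
\hat{q} = \frac{q}{\max\left(\lVert q \rVert_2,\ \epsilon\right)}, \qquad
\hat{k} = \frac{k}{\max\left(\lVert k \rVert_2,\ \epsilon\right)}, \qquad
\lvert \hat{q} \cdot \hat{k} \rvert \;\le\; \lVert \hat{q} \rVert_2 \, \lVert \hat{k} \rVert_2 \;\le\; 1 .
```

The first inequality is Cauchy-Schwarz; the second holds because each normalised vector has norm at most 1.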

Parameters:
  • query (Tensor) – Query tensor of shape [B, H, T, D] (or any layout where dim selects the per-head feature axis to normalise over).

  • key (Tensor) – Key tensor; must be broadcast-compatible with query.

  • dim (int) – Axis to normalise over. Default -1 (last axis = feature dimension in [B, H, T, D] layout).

  • eps (float) – Small lower bound on the L2 norm for numerical stability: as in F.normalize, each vector is divided by max(norm, eps), which avoids division by zero. Default 1e-12.

Returns:

Tuple (query_normed, key_normed) with the same shapes and dtypes as the inputs.
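
As a rough illustration of the arithmetic performed along dim, the sketch below normalises a single feature vector in plain Python. It is not the library's implementation: the helpers l2_normalize and dot are hypothetical names introduced here, Python lists stand in for one slice of a tensor along dim, and the clamping of the norm by eps is assumed to match the F.normalize formula.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Divide vec by max(||vec||_2, eps). Hypothetical helper, assumed to
    mirror what F.normalize does to one vector along `dim`."""
    norm = math.sqrt(sum(x * x for x in vec))
    denom = max(norm, eps)  # clamp so an all-zero vector does not divide by zero
    return [x / denom for x in vec]


def dot(a, b):
    """Plain dot product of two equal-length vectors (hypothetical helper)."""
    return sum(x * y for x, y in zip(a, b))


q = l2_normalize([3.0, 4.0])     # norm is 5, so q is approximately [0.6, 0.8]
k = l2_normalize([-4.0, 3.0])    # orthogonal to q
zero = l2_normalize([0.0, 0.0])  # norm 0: divisor is eps, result stays all zeros

cos_qq = dot(q, q)  # cosine similarity of q with itself, close to 1
cos_qk = dot(q, k)  # cosine similarity of orthogonal vectors, close to 0
```

Because every normalised vector has norm at most 1, any such dot product stays within [-1, 1], which is the property that keeps attention logits bounded before any scaling is applied.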