apply_qk_norm
- apply_qk_norm(query, key, dim=-1, eps=1e-12)
L2-normalise query and key tensors along a given dimension.
Computes F.normalize(q, p=2, dim=dim) and the equivalent for k. The resulting vectors have unit L2 norm along dim, so the dot product of a normalised query vector with a normalised key vector equals their cosine similarity and lies in [-1, 1].
- Parameters:
  - query (Tensor) – Query tensor of shape [B, H, T, D] (or any layout in which dim selects the per-head feature axis to normalise over).
  - key (Tensor) – Key tensor; must be broadcast-compatible with query.
  - dim (int) – Axis to normalise over. Default -1 (the last axis, i.e. the feature dimension D in [B, H, T, D] layout).
  - eps (float) – Small constant that keeps the L2-norm denominator away from zero, for numerical stability. Default 1e-12.
- Returns:
  Tuple (query_normed, key_normed) with the same shapes and dtypes as the inputs.
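As an illustration, the sketch below re-implements the documented behaviour on top of torch.nn.functional.normalize and checks the unit-norm and [-1, 1] claims on random tensors in [B, H, T, D] layout. The name apply_qk_norm_sketch is hypothetical, and the sketch assumes the real function normalises query and key independently with F.normalize; the actual implementation may differ (for example in exactly how eps enters the denominator).

```python
import torch
import torch.nn.functional as F


def apply_qk_norm_sketch(query, key, dim=-1, eps=1e-12):
    """Hypothetical stand-in for apply_qk_norm, for illustration only.

    Assumes each tensor is L2-normalised on its own via F.normalize,
    which divides by max(||x||_2, eps) along `dim`.
    """
    query_normed = F.normalize(query, p=2.0, dim=dim, eps=eps)
    key_normed = F.normalize(key, p=2.0, dim=dim, eps=eps)
    return query_normed, key_normed


# [B, H, T, D] layout: batch 2, 4 heads, 8 tokens, head dim 16.
torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
qn, kn = apply_qk_norm_sketch(q, k)

# Shapes and dtypes are unchanged.
print(qn.shape == q.shape, qn.dtype == q.dtype)

# Every vector along the last axis now has (approximately) unit L2 norm.
print(torch.allclose(qn.norm(dim=-1), torch.ones(2, 4, 8), atol=1e-5))

# Attention logits are cosine similarities, so they lie in [-1, 1]
# (up to float32 rounding).
scores = qn @ kn.transpose(-2, -1)  # shape [2, 4, 8, 8]
print(scores.abs().max().item())

# An all-zero vector stays zero instead of producing NaN, because the
# norm in the denominator is clamped to at least eps.
z, _ = apply_qk_norm_sketch(torch.zeros(1, 16), torch.zeros(1, 16))
print(z.abs().sum().item())
```

The bound on the logits holds regardless of how large the raw query and key projections grow, which is the point of normalising both sides before the dot product.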