construct_rope_1d_cache_blh

construct_rope_1d_cache_blh(seq_len, dim, device, dtype, rope_base)

Construct the 1D RoPE (rotary position embedding) cache of cosine and sine tables for a given sequence length and hidden dimension.

Parameters:
  • seq_len (int) – The length of the input sequence.

  • dim (int) – The hidden dimension, i.e. the size of the last axis of cos and sin. Must be even.

  • device (torch.device) – The device on which to allocate the cache.

  • dtype (torch.dtype) – The dtype of the returned cos and sin tensors.

  • rope_base (float) – The base of the RoPE frequency schedule.

Returns:

The 1D RoPE cache as a tuple (cos, sin) with shapes:

  • cos: [seq_len, dim]

  • sin: [seq_len, dim]

Return type:

tuple[torch.Tensor, torch.Tensor]
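As an illustration of what the returned tables contain, here is a minimal sketch that builds a cache with the documented shapes. It assumes the standard RoPE schedule, where channel pair i rotates at angle p * rope_base^(-2i/dim) for position p, and a half-split layout in which channel i and channel i + dim // 2 share a frequency. Both are assumptions; the actual function may use an interleaved layout or compute in a different precision. The name `rope_1d_cache_sketch` is hypothetical.

```python
from typing import Optional

import torch


def rope_1d_cache_sketch(
    seq_len: int,
    dim: int,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    rope_base: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical re-implementation; not the library's own code."""
    # Channels are rotated in pairs, so dim must be even (see Notes).
    if dim % 2 != 0:
        raise ValueError(f"dim must be even for pairwise rotations, got {dim}")
    # One inverse frequency per channel pair: rope_base ** (-2i / dim), shape [dim // 2].
    # Computed in float32 so low-precision dtypes do not distort the angles.
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    inv_freq = 1.0 / (rope_base ** exponents)
    # Rotation angle for every (position, pair): shape [seq_len, dim // 2].
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Assumed half-split layout: duplicate so channels i and i + dim // 2 share an angle.
    angles = torch.cat([angles, angles], dim=-1)  # [seq_len, dim]
    return angles.cos().to(dtype), angles.sin().to(dtype)


cos, sin = rope_1d_cache_sketch(8, 16, torch.device("cpu"), torch.float32, 10000.0)
```

Position 0 has angle 0 on every channel, so its row of cos is all ones and its row of sin is all zeros. Casting to `dtype` only at the end is a design choice that keeps the angle computation in float32 even when the cache is stored in a lower-precision dtype.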

Notes

  • For pairwise rotations, each per-axis channel size (dim) must be even.

  • The overall per-head dim D = 2 * dim must therefore be divisible by 4, since dim itself must be even.
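To show why the channel size must be even, the sketch below applies a (cos, sin) cache as a pairwise rotation using the common `rotate_half` formulation, which splits the last axis into two equal halves. It assumes that "blh" refers to an input layout of [batch, seq_len, heads, dim] and that the cache uses the half-split layout; both are guesses, not confirmed by this page. The helpers `_rope_cache`, `rotate_half`, and `apply_rope_blh_sketch` are hypothetical.

```python
import torch


def _rope_cache(seq_len: int, dim: int, base: float = 10000.0):
    # Hypothetical helper: half-split RoPE tables, each of shape [seq_len, dim].
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos(), angles.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Pair channel i with channel i + dim // 2 and map (x1, x2) -> (-x2, x1).
    # The two halves only have equal size when the last dimension is even.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope_blh_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Assumption: x is [batch, seq_len, heads, dim]; cos and sin are [seq_len, dim].
    # Add singleton axes so the tables broadcast over batch and heads.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    # 2D rotation of each (x_i, x_{i + dim // 2}) pair by its position-dependent angle.
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
x = torch.randn(2, 8, 3, 16)  # batch=2, seq_len=8, heads=3, dim=16
cos, sin = _rope_cache(8, 16)
y = apply_rope_blh_sketch(x, cos, sin)
```

Because each pair is rotated rather than scaled, the per-head vector norm is unchanged, and position 0 (angle 0) is left untouched.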