construct_rope_2d_cache_blh

construct_rope_2d_cache_blh(
height,
width,
dim_half,
device,
dtype,
rope_base,
)

Construct the 2D RoPE cache for a given (height, width) and per-axis dimension.

Parameters:
  • height (int) – The height of the input tensor.

  • width (int) – The width of the input tensor.

  • dim_half (int) – The per-axis channel dimension (half of the per-head dimension); must be even.

  • device (torch.device) – The device to store the cache on.

  • dtype (torch.dtype) – The dtype of the cache.

  • rope_base (float) – The base of the RoPE.

Returns:

The 2D RoPE cache, organized as (cos_y, sin_y, cos_x, sin_x), with shapes:

  • cos_y, sin_y: [height, dim_half]

  • cos_x, sin_x: [width, dim_half]

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
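To make the returned shapes concrete, here is a minimal, dependency-free sketch of how such a cache could be built. It is not the library's implementation: the helper names rope_axis_cache_sketch and rope_2d_cache_sketch are hypothetical, nested Python lists stand in for torch tensors (so device and dtype are omitted), the default base of 10000.0 is an assumption, and the layout of the angles within dim_half (each angle tiled twice) is also an assumption, since the page does not specify it.

```python
import math


def rope_axis_cache_sketch(length, dim_half, rope_base=10000.0):
    """Hypothetical helper: cos/sin tables of shape [length, dim_half] for one axis."""
    if dim_half % 2 != 0:
        raise ValueError("dim_half must be even for pairwise rotations")
    num_pairs = dim_half // 2
    # One inverse frequency per rotated channel pair (the usual RoPE schedule).
    inv_freq = [rope_base ** (-2.0 * i / dim_half) for i in range(num_pairs)]
    cos, sin = [], []
    for pos in range(length):
        angles = [pos * f for f in inv_freq]
        # Assumption: the angles are tiled twice to fill dim_half channels.
        angles = angles + angles
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


def rope_2d_cache_sketch(height, width, dim_half, rope_base=10000.0):
    """Hypothetical stand-in returning (cos_y, sin_y, cos_x, sin_x) as nested lists."""
    cos_y, sin_y = rope_axis_cache_sketch(height, dim_half, rope_base)
    cos_x, sin_x = rope_axis_cache_sketch(width, dim_half, rope_base)
    return cos_y, sin_y, cos_x, sin_x
```

In this sketch the y tables have one row per vertical position and the x tables one row per horizontal position, matching the [height, dim_half] and [width, dim_half] shapes listed above.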

Notes

  • For pairwise rotations, each per-axis channel size (dim_half) must be even.

  • The overall per-head dimension D must be divisible by 4 (since D = 2 * dim_half, and dim_half must be even).
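The divisibility constraint in the notes can be checked before the cache is built. The sketch below is illustrative only: dim_half_from_head_dim is a hypothetical helper, not part of the documented API, and it simply encodes the arithmetic D = 2 * dim_half with dim_half even.

```python
def dim_half_from_head_dim(head_dim):
    """Hypothetical helper: derive the per-axis dim_half from the per-head dim D."""
    # D = 2 * dim_half and dim_half must be even, so D must be a multiple of 4.
    if head_dim % 4 != 0:
        raise ValueError(f"per-head dim must be divisible by 4, got {head_dim}")
    return head_dim // 2


# A per-head dim of 64 gives 32 channels per axis, i.e. 16 rotated pairs per axis.
dim_half = dim_half_from_head_dim(64)
pairs_per_axis = dim_half // 2
```

A per-head dimension such as 6 is rejected here: it would give dim_half = 3, which cannot be split into channel pairs.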