construct_rope_2d_cache_bhl
- construct_rope_2d_cache_bhl(height, width, dim_half, device, dtype, rope_base)
Construct the 2D RoPE cache for a given (height, width) and per-axis dimension.
- Parameters:
height (int) – The height of the input tensor, i.e. the number of positions along the y axis.
width (int) – The width of the input tensor, i.e. the number of positions along the x axis.
dim_half (int) – The per-axis channel dimension.
device (torch.device) – The device to store the cache on.
dtype (torch.dtype) – The dtype of the cache.
rope_base (float) – The base of the RoPE.
- Returns:
The 2D RoPE cache, organized as (cos_y, sin_y, cos_x, sin_x), with shapes:
- cos_y, sin_y: [dim_half, height]
- cos_x, sin_x: [dim_half, width]
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
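The return layout can be illustrated with a small dependency-free sketch. This is not the library's implementation: the helper name rope_2d_cache_sketch is hypothetical, plain nested lists stand in for tensors (so device and dtype are omitted), and both the frequency schedule rope_base ** (-2i / dim_half) and the choice to give the two channels of each rotation pair the same angle are assumptions made for illustration.

```python
import math


def rope_2d_cache_sketch(height, width, dim_half, rope_base):
    """Illustrative stand-in that returns nested lists instead of tensors."""
    if dim_half % 2 != 0:
        raise ValueError("dim_half must be even for pairwise rotations")

    # Assumed schedule: one inverse frequency per channel pair.
    n_pairs = dim_half // 2
    inv_freq = [rope_base ** (-2.0 * i / dim_half) for i in range(n_pairs)]
    # Assumed layout: both channels of a pair share the same frequency,
    # so the per-channel list has length dim_half.
    per_channel = [f for f in inv_freq for _ in range(2)]

    def axis_tables(length):
        # Each table has shape [dim_half][length]: channel first, position second.
        cos = [[math.cos(p * f) for p in range(length)] for f in per_channel]
        sin = [[math.sin(p * f) for p in range(length)] for f in per_channel]
        return cos, sin

    cos_y, sin_y = axis_tables(height)  # [dim_half][height]
    cos_x, sin_x = axis_tables(width)   # [dim_half][width]
    return cos_y, sin_y, cos_x, sin_x
```

The y and x tables are built independently, which is why the cache needs only height + width positions per channel rather than height * width.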
Notes
For pairwise rotations, each per-axis channel size (dim_half) must be even.
The overall per-head dimension D must therefore be divisible by 4, since D = 2 * dim_half and dim_half must be even.
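The divisibility constraint can be checked before building the cache. The sketch below is only an illustration of the arithmetic in these notes; the helper name dim_half_from_head_dim is hypothetical and is not part of the documented API.

```python
def dim_half_from_head_dim(head_dim):
    """Split a per-head dimension D into the per-axis size dim_half = D / 2."""
    # D = 2 * dim_half and dim_half must be even, so D must be a multiple of 4.
    if head_dim % 4 != 0:
        raise ValueError(f"head_dim must be divisible by 4, got {head_dim}")
    return head_dim // 2
```

For example, a per-head dimension of 64 gives dim_half = 32, while 6 is rejected because it would give dim_half = 3, which is odd.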