construct_rope_3d_cache_bhl

construct_rope_3d_cache_bhl(depth, height, width, dim_third, device, dtype, rope_base)

Construct the 3D RoPE cache for a given (depth, height, width) grid and per-axis channel dimension.

Parameters:
  • depth (int) – Depth of the input tensor (Z axis).

  • height (int) – Height of the input tensor (Y axis).

  • width (int) – Width of the input tensor (X axis).

  • dim_third (int) – Per-axis channel dimension, i.e. one third of the per-head dimension.

  • device (torch.device) – Device on which to allocate the cache.

  • dtype (torch.dtype) – Data type of the cache tensors.

  • rope_base (float) – Base used to compute the RoPE rotation frequencies.

Returns:

The 3D RoPE cache as a tuple (cos_z, sin_z, cos_y, sin_y, cos_x, sin_x) with shapes:

  • cos_z, sin_z: [dim_third, depth]

  • cos_y, sin_y: [dim_third, height]

  • cos_x, sin_x: [dim_third, width]

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
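To make the cache layout concrete, here is a minimal sketch of how tables with these shapes could be built. It is not the library's implementation: build_axis_cache and sketch_rope_3d_cache are hypothetical names, and the sketch assumes the standard RoPE schedule inv_freq[i] = rope_base ** (-2i / dim_third), with each pair's angle repeated on two adjacent rows (interleaved pairing). The real function may order channels differently.

```python
import torch


def build_axis_cache(length, dim_third, device, dtype, rope_base):
    """Hypothetical per-axis (cos, sin) tables, each of shape [dim_third, length].

    Assumes the standard RoPE schedule inv_freq[i] = rope_base ** (-2i / dim_third)
    and interleaved pairing (rows 2i and 2i+1 share the same angle).
    """
    if dim_third % 2 != 0:
        raise ValueError("dim_third must be even for pairwise rotations")
    # One frequency per channel pair: dim_third // 2 values.
    exponents = torch.arange(0, dim_third, 2, device=device, dtype=torch.float32) / dim_third
    inv_freq = rope_base ** (-exponents)
    positions = torch.arange(length, device=device, dtype=torch.float32)
    # Outer product gives rotation angles of shape [dim_third // 2, length].
    angles = torch.outer(inv_freq, positions)
    # Repeat each pair's angle so rows line up with channels: [dim_third, length].
    angles = torch.repeat_interleave(angles, 2, dim=0)
    return angles.cos().to(dtype), angles.sin().to(dtype)


def sketch_rope_3d_cache(depth, height, width, dim_third, device, dtype, rope_base):
    """Hypothetical stand-in returning (cos_z, sin_z, cos_y, sin_y, cos_x, sin_x)."""
    cos_z, sin_z = build_axis_cache(depth, dim_third, device, dtype, rope_base)
    cos_y, sin_y = build_axis_cache(height, dim_third, device, dtype, rope_base)
    cos_x, sin_x = build_axis_cache(width, dim_third, device, dtype, rope_base)
    return cos_z, sin_z, cos_y, sin_y, cos_x, sin_x


cache = sketch_rope_3d_cache(4, 8, 16, 32, torch.device("cpu"), torch.float32, 10000.0)
print([tuple(t.shape) for t in cache])
# [(32, 4), (32, 4), (32, 8), (32, 8), (32, 16), (32, 16)]
```

In this sketch each axis gets its own table, so Z, Y and X share one frequency schedule but are indexed by their own position along that axis.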

Notes

  • Because RoPE rotates channels in pairs, the per-axis channel size dim_third must be even.

  • Consequently, the per-head dimension D = 3 * dim_third must be divisible by 6.
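Both constraints can be checked before building the cache. The sketch below uses two hypothetical helpers: split_head_dim derives dim_third from a per-head dimension and rejects values not divisible by 6, and rotate_pairs shows why pairing needs an even channel count, since it rotates channels (2i, 2i+1) together. It assumes interleaved pairing and a [dim_third, length] layout, which may not match the library's internal convention.

```python
import torch


def split_head_dim(head_dim):
    """Hypothetical helper: return dim_third for per-head dim D.

    D must split into three equal per-axis chunks, each of even size,
    which is equivalent to D being divisible by 6.
    """
    if head_dim % 6 != 0:
        raise ValueError(f"per-head dim {head_dim} must be divisible by 6")
    return head_dim // 3


def rotate_pairs(x, cos, sin):
    """Rotate adjacent channel pairs (x[2i], x[2i+1]) along dim 0.

    Assumes interleaved pairing: x, cos and sin are [dim_third, length],
    and cos/sin repeat each pair's angle on rows 2i and 2i+1.
    """
    x_even, x_odd = x[0::2], x[1::2]
    c, s = cos[0::2], sin[0::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * c - x_odd * s
    out[1::2] = x_even * s + x_odd * c
    return out


torch.manual_seed(0)
dim_third = split_head_dim(96)  # 32 channels per axis
length = 5
angles = torch.rand(dim_third // 2, length).repeat_interleave(2, dim=0)
x = torch.randn(dim_third, length)
y = rotate_pairs(x, angles.cos(), angles.sin())
# A rotation preserves the norm of every position's channel vector.
print(torch.allclose(x.norm(dim=0), y.norm(dim=0), atol=1e-5))  # True
```

With an odd dim_third the last channel would have no partner, which is why the check rejects per-head dimensions such as 96 + 3 = 99 (odd thirds) as well as 100 (not divisible by 3).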