construct_rope_3d_cache_bhl
construct_rope_3d_cache_bhl(depth, height, width, dim_third, device, dtype, rope_base)
Construct the 3D RoPE cache for given (depth, height, width) and per-axis dimension.
- Parameters:
depth (int) – The depth of the input tensor (Z axis).
height (int) – The height of the input tensor (Y axis).
width (int) – The width of the input tensor (X axis).
dim_third (int) – The per-axis channel dimension (one third of the per-head dim).
device (torch.device) – The device on which to store the cache.
dtype (torch.dtype) – The dtype of the cache.
rope_base (float) – The base of the RoPE.
- Returns:
The 3D RoPE cache, organized as (cos_z, sin_z, cos_y, sin_y, cos_x, sin_x), with shapes:
- cos_z, sin_z: [dim_third, depth]
- cos_y, sin_y: [dim_third, height]
- cos_x, sin_x: [dim_third, width]
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
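The following is a minimal sketch of how a cache with the documented return layout could be built. It is not the library's implementation. It assumes the conventional RoPE frequency schedule rope_base^(-2i/dim_third) and an interleaved layout in which both channels of a rotation pair share one angle. The names rope_axis_cache_sketch and rope_3d_cache_sketch are hypothetical.

```python
import torch


def rope_axis_cache_sketch(length, dim_third, device, dtype, rope_base):
    """Hypothetical helper: cos/sin tables of shape [dim_third, length] for one axis."""
    if dim_third % 2 != 0:
        raise ValueError("dim_third must be even for pairwise rotations")
    # Assumption: one frequency per rotation pair, rope_base^(-2i/dim_third).
    pair_idx = torch.arange(0, dim_third, 2, device=device, dtype=torch.float32)
    inv_freq = rope_base ** (-pair_idx / dim_third)  # [dim_third / 2]
    pos = torch.arange(length, device=device, dtype=torch.float32)  # [length]
    angles = torch.outer(inv_freq, pos)  # [dim_third / 2, length]
    # Assumption: interleaved layout, so each angle is repeated for both
    # channels of its pair, giving [dim_third, length].
    angles = angles.repeat_interleave(2, dim=0)
    return angles.cos().to(dtype), angles.sin().to(dtype)


def rope_3d_cache_sketch(depth, height, width, dim_third, device, dtype, rope_base):
    """Hypothetical sketch returning (cos_z, sin_z, cos_y, sin_y, cos_x, sin_x)."""
    cos_z, sin_z = rope_axis_cache_sketch(depth, dim_third, device, dtype, rope_base)
    cos_y, sin_y = rope_axis_cache_sketch(height, dim_third, device, dtype, rope_base)
    cos_x, sin_x = rope_axis_cache_sketch(width, dim_third, device, dtype, rope_base)
    return cos_z, sin_z, cos_y, sin_y, cos_x, sin_x
```

Because each axis is handled independently, the cache grows with depth + height + width rather than with their product; the actual function may order or pair the channels differently.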
Notes
For pairwise rotations, each per-axis channel size (dim_third) must be even.
Consequently, the overall per-head dim D must be divisible by 6 (since D = 3 * dim_third, and dim_third must be even).
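As an illustration of the divisibility constraint, a caller might derive dim_third from the per-head dim with a small check like the one below. The helper name dim_third_from_head_dim is hypothetical and is not part of the documented API.

```python
def dim_third_from_head_dim(head_dim: int) -> int:
    """Hypothetical helper: split a per-head dim evenly across the Z, Y, X axes."""
    # D = 3 * dim_third with dim_third even means D must be a multiple of 6.
    if head_dim <= 0 or head_dim % 6 != 0:
        raise ValueError(
            f"per-head dim must be a positive multiple of 6, got {head_dim}"
        )
    return head_dim // 3


# 96 is a multiple of 6, so each axis gets 32 channels (16 rotation pairs).
dim_third = dim_third_from_head_dim(96)
```

A per-head dim of 64 would be rejected because it is not divisible by 3, and 9 would be rejected because the resulting dim_third of 3 is odd.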