apply_rope_3d_blh

apply_rope_3d_blh(x, rope_3d_cache)

Apply 3D rotary position embedding (RoPE) to a tensor laid out as [batch_size, D, H, W, hidden_dim], where D, H, and W index the Z, Y, and X axes.

The channel dimension C (hidden_dim) is split into three equal parts, C_z, C_y, and C_x, each of size C/3. RoPE is applied independently along Z (to C_z), Y (to C_y), and X (to C_x). Because channels are rotated in pairs, each third must have an even size, so C must be divisible by 6.
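The split-and-rotate scheme can be sketched out of place as below. This is a minimal sketch, not the library's implementation. It assumes the "rotate-half" pairing convention (within each third, channel i rotates together with channel i + C/6), and it assumes each cached cos/sin tensor has shape [axis_length, C/3] with its angles duplicated across the two halves. The names rotate_half, _rotate, and apply_rope_3d_reference are hypothetical.

```python
import torch


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # Pair channel i with channel i + n/2 and map each pair (a, b) -> (-b, a).
    a, b = t.chunk(2, dim=-1)
    return torch.cat((-b, a), dim=-1)


def _rotate(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Standard RoPE update; cos/sin must broadcast against t.
    return t * cos + rotate_half(t) * sin


def apply_rope_3d_reference(x: torch.Tensor, cache) -> torch.Tensor:
    # Hypothetical out-of-place reference; assumes rotate-half pairing.
    cos_z, sin_z, cos_y, sin_y, cos_x, sin_x = cache
    _, d, h, w, c = x.shape
    if c % 6 != 0:
        raise ValueError("hidden_dim must be divisible by 6")
    third = c // 3
    x_z, x_y, x_x = x.split(third, dim=-1)
    # Each third is modulated only by positions along its own axis.
    shape_z, shape_y, shape_x = (1, d, 1, 1, third), (1, 1, h, 1, third), (1, 1, 1, w, third)
    out_z = _rotate(x_z, cos_z.view(shape_z), sin_z.view(shape_z))
    out_y = _rotate(x_y, cos_y.view(shape_y), sin_y.view(shape_y))
    out_x = _rotate(x_x, cos_x.view(shape_x), sin_x.view(shape_x))
    return torch.cat((out_z, out_y, out_x), dim=-1)
```

Because the Z third is multiplied only by cos_z/sin_z, its rotation angle depends on the depth index alone, regardless of the H and W positions; the same holds for the other two thirds along their own axes.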

Parameters:
  • x (Tensor) – Input tensor of shape [batch_size, D, H, W, hidden_dim] with hidden_dim % 6 == 0.

  • rope_3d_cache (tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]) – The cached 3D RoPE cos/sin tables for the z, y, and x axes, ordered as (cos_z, sin_z, cos_y, sin_y, cos_x, sin_x).
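A cache in the documented order could be built as follows. This sketch assumes the standard RoPE frequency schedule (inverse frequencies base^(-2k/n) with base = 10000, where n = hidden_dim / 3) and assumes each table has shape [axis_length, hidden_dim / 3], with its angles duplicated across the two halves of each third (the rotate-half convention). The function name build_rope_3d_cache and its parameters are hypothetical; the real cache may use a different frequency schedule or layout.

```python
import torch


def build_rope_3d_cache(d, h, w, hidden_dim, base=10000.0, device=None, dtype=torch.float32):
    # Hypothetical helper: standard RoPE inverse frequencies applied to each third.
    if hidden_dim % 6 != 0:
        raise ValueError("hidden_dim must be divisible by 6")
    third = hidden_dim // 3
    k = torch.arange(0, third, 2, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (k / third))         # shape [third / 2]

    def axis(n):
        pos = torch.arange(n, device=device, dtype=torch.float32)
        angles = torch.outer(pos, inv_freq)         # shape [n, third / 2]
        emb = torch.cat((angles, angles), dim=-1)   # duplicated for rotate-half pairing
        return emb.cos().to(dtype), emb.sin().to(dtype)

    cos_z, sin_z = axis(d)
    cos_y, sin_y = axis(h)
    cos_x, sin_x = axis(w)
    # Same order as the documented rope_3d_cache tuple.
    return cos_z, sin_z, cos_y, sin_y, cos_x, sin_x
```

For example, build_rope_3d_cache(3, 4, 5, 12) returns six tables with shapes [3, 4], [3, 4], [4, 4], [4, 4], [5, 4], and [5, 4].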

Returns:

A tensor with the same shape as x. The rotated values are written back in place through views to reduce allocations.

Return type:

Tensor
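Writing the rotations back in place through views might look like the sketch below. Each third of x is a slice view that shares x's storage, so copy_ overwrites x without allocating a second full-size output. The name apply_rope_3d_inplace is hypothetical, the rotate-half pairing and the [axis_length, hidden_dim / 3] cache shapes are assumptions, and this sketch mutates its argument; whether apply_rope_3d_blh does the same should be checked against its source.

```python
import torch


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # Pair channel i with channel i + n/2 and map each pair (a, b) -> (-b, a).
    a, b = t.chunk(2, dim=-1)
    return torch.cat((-b, a), dim=-1)


def apply_rope_3d_inplace(x: torch.Tensor, cache) -> torch.Tensor:
    # Hypothetical in-place variant: each third is a view into x's storage.
    cos_z, sin_z, cos_y, sin_y, cos_x, sin_x = cache
    _, d, h, w, c = x.shape
    third = c // 3
    shapes = ((1, d, 1, 1, third), (1, 1, h, 1, third), (1, 1, 1, w, third))
    tables = ((cos_z, sin_z), (cos_y, sin_y), (cos_x, sin_x))
    for i, ((cos, sin), shape) in enumerate(zip(tables, shapes)):
        part = x[..., i * third:(i + 1) * third]   # a view, not a copy
        # The right-hand side is fully computed before copy_ writes into the view.
        rotated = part * cos.view(shape) + rotate_half(part) * sin.view(shape)
        part.copy_(rotated)                        # write back through the view
    return x
```

Only the per-third temporaries are allocated here. In-place writes can conflict with autograd, for example when x is a leaf that requires grad or has been saved for backward; an out-of-place version is safer in that case.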

Broadcasting:
  • cos_z/sin_z are reshaped to [1, D, 1, 1, hidden_dim/3] for the first third.

  • cos_y/sin_y are reshaped to [1, 1, H, 1, hidden_dim/3] for the second third.

  • cos_x/sin_x are reshaped to [1, 1, 1, W, hidden_dim/3] for the final third.
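The reshapes listed above can be checked directly with torch.broadcast_shapes. This sketch uses random stand-ins for the cached tables and assumes each is stored as [axis_length, hidden_dim / 3]; only the cos tables are shown, and the sin tables broadcast the same way.

```python
import torch

b, d, h, w, hidden_dim = 2, 3, 4, 5, 12
third = hidden_dim // 3

# Stand-ins for cached tables, assumed to be [axis_length, hidden_dim / 3].
cos_z, cos_y, cos_x = torch.rand(d, third), torch.rand(h, third), torch.rand(w, third)

# Size-1 axes let each table repeat across the two spatial axes it does not encode.
cz = cos_z.view(1, d, 1, 1, third)
cy = cos_y.view(1, 1, h, 1, third)
cx = cos_x.view(1, 1, 1, w, third)

x_third = torch.randn(b, d, h, w, third)
for table in (cz, cy, cx):
    print(torch.broadcast_shapes(x_third.shape, table.shape))  # torch.Size([2, 3, 4, 5, 4])

# Along Z, depth slice k is scaled by row k of cos_z at every (h, w) position.
scaled = x_third * cz
k = 1
print(torch.allclose(scaled[:, k], x_third[:, k] * cos_z[k]))  # True
```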