vllm.model_executor.layers.rotary_embedding.llama4_vision_rope

Llama4VisionRotaryEmbedding

Bases: RotaryEmbedding

Source code in vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
class Llama4VisionRotaryEmbedding(RotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
        return inv_freqs

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)

        # self.max_position_embeddings here is number of image patches
        # i.e. (image_size // patch_size) ** 2
        num_patches = self.max_position_embeddings
        img_idx = torch.arange(num_patches,
                               dtype=torch.int32).reshape(num_patches, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # set to ID_CLS_TOKEN
        num_patches_single_dim = int(math.sqrt(num_patches))
        frequencies_x = img_idx % num_patches_single_dim
        frequencies_y = img_idx // num_patches_single_dim
        freqs_x = ((frequencies_x + 1)[..., None] *
                   inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
        freqs_y = ((frequencies_y + 1)[..., None] *
                   inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
        freqs = torch.cat([freqs_x, freqs_y],
                          dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        cache = torch.view_as_complex(
            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
        return cache

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert key is not None
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
        query_ = torch.view_as_complex(query.float().reshape(
            *query.shape[:-1], -1, 2))
        key_ = torch.view_as_complex(key.float().reshape(
            *key.shape[:-1], -1, 2))
        broadcast_shape = [
            d if i == 1 or i == (query_.ndim - 1) else 1
            for i, d in enumerate(query_.shape)
        ]
        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
        query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
        key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
        return query_out.type_as(query), key_out.type_as(key)
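
This is the rotary position embedding used by the Llama 4 vision encoder: each patch of a square image grid receives rotary angles derived from its (x, y) grid coordinates, an extra CLS slot is appended and left unrotated, and the rotation is applied to query and key by complex multiplication. The snippet below is a minimal standalone sketch (plain torch, an illustrative 3x3 grid, not part of vLLM) of how the flat patch index is decomposed into grid coordinates and how the appended CLS entry is flagged, mirroring the indexing in _compute_cos_sin_cache:

import torch

num_patches = 9                                     # illustrative 3x3 grid
grid = int(num_patches ** 0.5)                      # patches per side

img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)  # one extra row for the CLS slot
img_idx[-1, -1] = -2                                # sentinel: CLS position gets no rotation

x = img_idx % grid    # column index of each patch
y = img_idx // grid   # row index of each patch
print(x[:num_patches].flatten().tolist())  # [0, 1, 2, 0, 1, 2, 0, 1, 2]
print(y[:num_patches].flatten().tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
# The sentinel row is negative, so the later masked_fill(img_idx < 0, 0)
# zeroes its frequencies and the CLS token is left unrotated.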

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
)
Source code in vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
):
    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style, dtype)
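
The constructor only forwards its arguments to RotaryEmbedding, which builds the cos/sin cache via _compute_cos_sin_cache below. For a query laid out as (batch, positions, heads, head_size), the reshape in forward implies rotary_dim == head_size // 2. The values below are purely illustrative assumptions, not numbers read from any Llama 4 config:

# Illustrative values only; the real numbers come from the model config.
image_size, patch_size = 336, 14
max_position_embeddings = (image_size // patch_size) ** 2  # number of image patches (576)
head_size = 88                                             # per-attention-head width (assumed)
rotary_dim = head_size // 2                                # as the reshape in forward implies
base = 10000.0                                             # RoPE theta (assumed)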

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    inv_freq = self._compute_inv_freq(self.base)

    # self.max_position_embeddings here is number of image patches
    # i.e. (image_size // patch_size) ** 2
    num_patches = self.max_position_embeddings
    img_idx = torch.arange(num_patches,
                           dtype=torch.int32).reshape(num_patches, 1)
    img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
    img_idx[-1, -1] = -2  # set to ID_CLS_TOKEN
    num_patches_single_dim = int(math.sqrt(num_patches))
    frequencies_x = img_idx % num_patches_single_dim
    frequencies_y = img_idx // num_patches_single_dim
    freqs_x = ((frequencies_x + 1)[..., None] *
               inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
    freqs_y = ((frequencies_y + 1)[..., None] *
               inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
    freqs = torch.cat([freqs_x, freqs_y],
                      dim=-1).float().contiguous()[..., ::2]
    freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
    cache = torch.view_as_complex(
        torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
    return cache
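
The cache is complex-valued: for every position (patch or CLS slot) and every rotary channel it stores cos(theta) + i*sin(theta), where the first half of the channels takes its angle from the patch's x coordinate and the second half from its y coordinate. Below is a minimal standalone sketch that walks through the same construction for a tiny 2x2 grid with 4 rotary channels so the shapes are easy to follow; it is illustrative only, and it assumes the standard RoPE inverse-frequency schedule for inv_freq rather than calling the base class:

import math
import torch

num_patches, rotary_dim = 4, 4                  # tiny 2x2 grid, 4 rotary channels (illustrative)
grid = int(math.sqrt(num_patches))
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))  # (2,)

img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)   # (5, 1), extra CLS slot
img_idx[-1, -1] = -2

x, y = img_idx % grid, img_idx // grid
freqs_x = ((x + 1)[..., None] * inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
freqs_y = ((y + 1)[..., None] * inv_freq[None, None, :]).repeat_interleave(2, dim=-1)
# First half of the channels carries x angles, second half y angles.
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)   # CLS slot -> zero angle
cache = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
print(cache.shape)   # torch.Size([5, 1, 4]) -> (num_patches + 1, 1, rotary_dim)
print(cache[-1, 0])  # CLS entry: all 1+0j, i.e. no rotation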

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    inv_freqs = super()._compute_inv_freq(base)
    inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
    return inv_freqs
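
The override reuses the parent's inverse frequencies and keeps at most the first rotary_dim // 2 of them, so one frequency serves each pair of rotated channels. For reference, the conventional RoPE inverse-frequency schedule (a sketch; the exact parent implementation lives in the base class) looks like this:

import torch

def rope_inv_freq(base: float, rotary_dim: int) -> torch.Tensor:
    # Standard RoPE schedule: inv_freq[i] = base ** (-2 * i / rotary_dim),
    # one value per pair of rotated channels (length rotary_dim // 2).
    return 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))

print(rope_inv_freq(10000.0, 8))   # tensor([1.0000, 0.1000, 0.0100, 0.0010])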

forward

forward(
    query: Tensor, key: Optional[Tensor] = None
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
def forward(
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert key is not None
    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
    query_ = torch.view_as_complex(query.float().reshape(
        *query.shape[:-1], -1, 2))
    key_ = torch.view_as_complex(key.float().reshape(
        *key.shape[:-1], -1, 2))
    broadcast_shape = [
        d if i == 1 or i == (query_.ndim - 1) else 1
        for i, d in enumerate(query_.shape)
    ]
    freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
    query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
    key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
    return query_out.type_as(query), key_out.type_as(key)
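
forward treats consecutive channel pairs of query and key as complex numbers and multiplies them by the cached unit complex numbers, which is the standard complex-multiplication form of applying a rotary embedding. Below is a standalone sketch of the same operation with assumed shapes (batch 2, 10 positions, 3 heads, head size 8, so the stand-in cache holds 10 positions x 4 channel pairs); it does not call the vLLM class:

import torch

bsz, seq_len, num_heads, head_size = 2, 10, 3, 8        # assumed shapes
query = torch.randn(bsz, seq_len, num_heads, head_size)
key = torch.randn(bsz, seq_len, num_heads, head_size)

# Unit complex rotations, one per (position, channel pair); stands in for cos_sin_cache.
angles = torch.randn(seq_len, 1, head_size // 2)
freqs_ci = torch.view_as_complex(torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1))

query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
# Broadcast the cache to (1, seq_len, 1, head_size // 2), as in forward above.
freqs_ci = freqs_ci.view(1, seq_len, 1, head_size // 2)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)

print(query_out.shape, key_out.shape)   # both (2, 10, 3, 8)
# Each channel pair is rotated in its plane by a unit complex number,
# so its magnitude is unchanged; only the phase moves.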