vllm.model_executor.models.siglip2navit

Implementation of SiglipVisionModel, intended to be used only within a vision-language model.

Siglip2Attention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                _Backend.ROCM_AITER_FA
        }:
            self.attn_backend = _Backend.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor,
                                            torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Input shape: (seq_length, embed_dim); sequences are packed
        along dim 0 and delimited by cu_seqlens."""

        seq_length, embed_dim = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)

        queries = queries.view(seq_length, self.num_heads_per_partition,
                               self.head_dim)
        keys = keys.view(seq_length, self.num_heads_per_partition,
                         self.head_dim)
        values = values.view(seq_length, self.num_heads_per_partition,
                             self.head_dim)

        if self.use_rope:
            cos, sin = position_embeddings
            queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
                                                 keys.unsqueeze(0), cos, sin,
                                                 self.is_flash_attn_backend)
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        if self.is_flash_attn_backend:
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            else:
                from flash_attn import flash_attn_varlen_func
            attn_output = flash_attn_varlen_func(
                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                max_seqlen).reshape(seq_length, -1)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Run attention per sequence: avoids padding, saving time and VRAM.
            batch_size = cu_seqlens.shape[0] - 1
            outputs = []
            cu = cu_seqlens.tolist()
            for i in range(batch_size):
                start_idx = cu[i]
                end_idx = cu[i + 1]

                # Each sequence is processed independently.
                q_i = queries[start_idx:end_idx].unsqueeze(0)
                k_i = keys[start_idx:end_idx].unsqueeze(0)
                v_i = values[start_idx:end_idx].unsqueeze(0)

                # (1, seq_len, num_heads, head_dim) ->
                # (1, num_heads, seq_len, head_dim)
                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]

                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
                output_i = output_i.transpose(1, 2).reshape(
                    end_idx - start_idx, -1)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
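
To make the cu_seqlens convention concrete, here is a minimal, self-contained sketch of the packed-sequence attention that the TORCH_SDPA branch performs. It uses plain PyTorch only; the sizes below are illustrative, not taken from any real config.

import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 8
# Two packed sequences of lengths 3 and 5 -> cu_seqlens = [0, 3, 8].
cu_seqlens = torch.tensor([0, 3, 8])
seq_len = int(cu_seqlens[-1])

q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)

# Each sequence attends only within its own [start, end) slice, which is
# what flash_attn_varlen_func computes in a single fused kernel.
outputs = []
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    q_i = q[start:end].unsqueeze(0).transpose(1, 2)  # (1, heads, len_i, dim)
    k_i = k[start:end].unsqueeze(0).transpose(1, 2)
    v_i = v[start:end].unsqueeze(0).transpose(1, 2)
    o_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
    outputs.append(o_i.transpose(1, 2).reshape(end - start, -1))
attn_output = torch.cat(outputs, dim=0)
print(attn_output.shape)  # torch.Size([8, 32])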

attn_backend instance-attribute

attn_backend: _Backend = get_vit_attn_backend(
    support_fa=True
)

config instance-attribute

config = config

dropout instance-attribute

dropout = attention_dropout

embed_dim instance-attribute

embed_dim = hidden_size

head_dim instance-attribute

head_dim = embed_dim // num_heads

is_causal instance-attribute

is_causal = False

is_flash_attn_backend instance-attribute

is_flash_attn_backend = attn_backend in {
    FLASH_ATTN,
    ROCM_AITER_FA,
}

num_heads instance-attribute

num_heads = num_attention_heads

num_heads_per_partition instance-attribute

num_heads_per_partition = divide(num_heads, tp_size)

out_proj instance-attribute

out_proj = RowParallelLinear(
    input_size=embed_dim,
    output_size=embed_dim,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size=embed_dim,
    head_size=head_dim,
    total_num_heads=num_heads,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

scale instance-attribute

scale = head_dim ** -0.5

tp_size instance-attribute

tp_size = (
    1
    if use_data_parallel
    else get_tensor_model_parallel_world_size()
)

use_rope instance-attribute

use_rope = use_rope

__init__

__init__(
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(
    self,
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()
    self.config = config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads
    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(
            f"embed_dim must be divisible by num_heads "
            f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
            f" {self.num_heads}).")
    self.scale = self.head_dim**-0.5
    self.dropout = config.attention_dropout
    self.is_causal = False

    # TODO(Isotr0py): Enable data parallel after we support
    # disabling TP on parallel linear layer
    self.qkv_proj = QKVParallelLinear(
        hidden_size=self.embed_dim,
        head_size=self.head_dim,
        total_num_heads=self.num_heads,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.out_proj = RowParallelLinear(
        input_size=self.embed_dim,
        output_size=self.embed_dim,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )

    self.tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
    self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
    self.use_rope = config.use_rope

    # Detect attention implementation.
    self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
    if self.attn_backend not in {
            _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
            _Backend.ROCM_AITER_FA
    }:
        self.attn_backend = _Backend.TORCH_SDPA
    self.is_flash_attn_backend = self.attn_backend in {
        _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
    }

forward

forward(
    hidden_states: Tensor,
    cu_seqlens: Tensor,
    position_embeddings: Optional[
        tuple[Tensor, Tensor]
    ] = None,
) -> Tensor

Input shape: (seq_length, embed_dim); sequences are packed along dim 0 and delimited by cu_seqlens.

Source code in vllm/model_executor/models/siglip2navit.py
def forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    position_embeddings: Optional[tuple[torch.Tensor,
                                        torch.Tensor]] = None,
) -> torch.Tensor:
    """Input shape: (seq_length, embed_dim); sequences are packed
    along dim 0 and delimited by cu_seqlens."""

    seq_length, embed_dim = hidden_states.shape

    qkv_states, _ = self.qkv_proj(hidden_states)
    queries, keys, values = qkv_states.chunk(3, dim=-1)

    queries = queries.view(seq_length, self.num_heads_per_partition,
                           self.head_dim)
    keys = keys.view(seq_length, self.num_heads_per_partition,
                     self.head_dim)
    values = values.view(seq_length, self.num_heads_per_partition,
                         self.head_dim)

    if self.use_rope:
        cos, sin = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
                                             keys.unsqueeze(0), cos, sin,
                                             self.is_flash_attn_backend)
        queries = queries.squeeze(0)
        keys = keys.squeeze(0)

    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    if self.is_flash_attn_backend:
        if self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func
        else:
            from flash_attn import flash_attn_varlen_func
        attn_output = flash_attn_varlen_func(
            queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
            max_seqlen).reshape(seq_length, -1)
    elif self.attn_backend == _Backend.TORCH_SDPA:
        # Run attention per sequence: avoids padding, saving time and VRAM.
        batch_size = cu_seqlens.shape[0] - 1
        outputs = []
        cu = cu_seqlens.tolist()
        for i in range(batch_size):
            start_idx = cu[i]
            end_idx = cu[i + 1]

            # Each sequence is processed independently.
            q_i = queries[start_idx:end_idx].unsqueeze(0)
            k_i = keys[start_idx:end_idx].unsqueeze(0)
            v_i = values[start_idx:end_idx].unsqueeze(0)

            # (1, seq_len, num_heads, head_dim) ->
            # (1, num_heads, seq_len, head_dim)
            q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]

            output_i = F.scaled_dot_product_attention(q_i,
                                                      k_i,
                                                      v_i,
                                                      dropout_p=0.0)
            # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
            output_i = output_i.transpose(1, 2).reshape(
                end_idx - start_idx, -1)
            outputs.append(output_i)

        attn_output = torch.cat(outputs, dim=0)
    attn_output, _ = self.out_proj(attn_output)
    return attn_output

Siglip2Encoder

Bases: Module

Transformer encoder consisting of config.num_hidden_layers self-attention layers. Each layer is a [Siglip2EncoderLayer].

Parameters:

config (Siglip2VisionConfig, required): PretrainedConfig
Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` 
    self attention layers. Each layer is a [`Siglip2EncoderLayer`].

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            Siglip2EncoderLayer(config,
                                quant_config=quant_config,
                                prefix=f"{prefix}.layers.{idx}",
                                use_data_parallel=use_data_parallel)
            for idx in range(config.num_hidden_layers)
        ])

        self.rotary_pos_emb = VisionRotaryEmbedding(
            config.hidden_size // config.num_attention_heads // 2)
        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
        if config.fullatt_block_indexes is None:
            self.fullatt_block_indexes = None
        else:
            self.fullatt_block_indexes = [
                int(i) for i in config.fullatt_block_indexes.split('|')
            ]

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # number of patches (after merging) in each window
        vit_merger_window_size = (self.window_size // self.hidden_stride //
                                  self.patch_size)

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patches after merging
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w)
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(
                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape
                `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation. This is useful if
                you want more control over how to convert `input_ids` indices
                into associated vectors than the model's internal embedding
                lookup matrix.
            grid_thws (`torch.LongTensor`):
                grid shape (num_patches, 3)
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See
                `hidden_states` under returned tensors for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of
                a plain tuple.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = inputs_embeds.size()
        inputs_embeds = inputs_embeds.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have
            #    same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852
            # for more information
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        reverse_indices = torch.argsort(window_index)

        hidden_states = inputs_embeds
        for index, block in enumerate(self.layers):
            if (not self.fullatt_block_indexes
                    or index in self.fullatt_block_indexes):
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp,
                                  position_embeddings)

        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

        return hidden_states

config instance-attribute

config = config

fullatt_block_indexes instance-attribute

fullatt_block_indexes = None

hidden_stride instance-attribute

hidden_stride = hidden_stride

layers instance-attribute

layers = ModuleList(
    [
        (
            Siglip2EncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{idx}",
                use_data_parallel=use_data_parallel,
            )
        )
        for idx in (range(num_hidden_layers))
    ]
)

patch_size instance-attribute

patch_size = patch_size

rotary_pos_emb instance-attribute

rotary_pos_emb = VisionRotaryEmbedding(
    hidden_size // num_attention_heads // 2
)

spatial_merge_unit instance-attribute

spatial_merge_unit = hidden_stride * hidden_stride

window_size instance-attribute

window_size = window_size

__init__

__init__(
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(
    self,
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()
    self.config = config
    self.layers = nn.ModuleList([
        Siglip2EncoderLayer(config,
                            quant_config=quant_config,
                            prefix=f"{prefix}.layers.{idx}",
                            use_data_parallel=use_data_parallel)
        for idx in range(config.num_hidden_layers)
    ])

    self.rotary_pos_emb = VisionRotaryEmbedding(
        config.hidden_size // config.num_attention_heads // 2)
    self.patch_size = config.patch_size
    self.hidden_stride = config.hidden_stride
    self.window_size = config.window_size
    self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
    if config.fullatt_block_indexes is None:
        self.fullatt_block_indexes = None
    else:
        self.fullatt_block_indexes = [
            int(i) for i in config.fullatt_block_indexes.split('|')
        ]

forward

forward(inputs_embeds: Tensor, grid_thws: Tensor) -> Tensor

Parameters:

inputs_embeds (torch.FloatTensor of shape (sequence_length, hidden_size), required): packed patch embeddings of all images in the batch.

grid_thws (torch.LongTensor of shape (num_images, 3), required): the temporal, height and width patch counts of each image.
Source code in vllm/model_executor/models/siglip2navit.py
def forward(
    self,
    inputs_embeds: torch.Tensor,
    grid_thws: torch.Tensor,
) -> torch.Tensor:
    r"""
    Args:
        inputs_embeds (`torch.FloatTensor` of shape
            `(batch_size, sequence_length, hidden_size)`):
            Optionally, instead of passing `input_ids` you can choose to
            directly pass an embedded representation. This is useful if
            you want more control over how to convert `input_ids` indices
            into associated vectors than the model's internal embedding
            lookup matrix.
        grid_thws (`torch.LongTensor`):
            grid shape (num_patches, 3)
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See
            `hidden_states` under returned tensors for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of
            a plain tuple.
    """
    rotary_pos_emb = self.rot_pos_emb(grid_thws)
    window_index, cu_window_seqlens = self.get_window_index(grid_thws)
    cu_window_seqlens = torch.tensor(
        cu_window_seqlens,
        device=inputs_embeds.device,
        dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

    seq_len, _ = inputs_embeds.size()
    inputs_embeds = inputs_embeds.reshape(
        seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    inputs_embeds = inputs_embeds[window_index, :, :]
    inputs_embeds = inputs_embeds.reshape(seq_len, -1)
    rotary_pos_emb = rotary_pos_emb.reshape(
        seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    rotary_pos_emb = rotary_pos_emb[window_index, :, :]
    rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    position_embeddings = (emb.cos(), emb.sin())

    cu_seqlens = torch.repeat_interleave(
        grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
    ).cumsum(
        dim=0,
        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have
        #    same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852
        # for more information
        dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    reverse_indices = torch.argsort(window_index)

    hidden_states = inputs_embeds
    for index, block in enumerate(self.layers):
        if (not self.fullatt_block_indexes
                or index in self.fullatt_block_indexes):
            cu_seqlens_tmp = cu_seqlens
        else:
            cu_seqlens_tmp = cu_window_seqlens
        hidden_states = block(hidden_states, cu_seqlens_tmp,
                              position_embeddings)

    hidden_states = hidden_states.reshape(
        seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

    return hidden_states
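
As a quick illustration of the full-attention cu_seqlens computed above: each (t, h, w) row of grid_thws contributes t sequences of h * w patches each. A runnable sketch with made-up grids:

import torch
import torch.nn.functional as F

# Hypothetical inputs: one 2-frame 4x6 video and one single 8x8 image.
grid_thws = torch.tensor([[2, 4, 6], [1, 8, 8]])

cu_seqlens = torch.repeat_interleave(
    grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
        dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([  0,  24,  48, 112], dtype=torch.int32)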

get_window_index

get_window_index(grid_thw)
Source code in vllm/model_executor/models/siglip2navit.py
def get_window_index(self, grid_thw):
    window_index: list = []
    cu_window_seqlens: list = [0]
    window_index_id = 0
    # number of patches (after merging) in each window
    vit_merger_window_size = (self.window_size // self.hidden_stride //
                              self.patch_size)

    for grid_t, grid_h, grid_w in grid_thw:
        llm_grid_h, llm_grid_w = (
            grid_h // self.hidden_stride,  # number of patches after merging
            grid_w // self.hidden_stride,
        )
        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
            grid_t, llm_grid_h, llm_grid_w)
        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
        index_padded = index_padded.reshape(
            grid_t,
            num_windows_h,
            vit_merger_window_size,
            num_windows_w,
            vit_merger_window_size,
        )
        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
            grid_t,
            num_windows_h * num_windows_w,
            vit_merger_window_size,
            vit_merger_window_size,
        )
        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
        index_padded = index_padded.reshape(-1)
        index_new = index_padded[index_padded != -100]
        window_index.append(index_new + window_index_id)
        cu_seqlens_tmp = seqlens.cumsum(
            0) * self.spatial_merge_unit + cu_window_seqlens[-1]
        cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
        window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
    window_index = torch.cat(window_index, dim=0)

    return window_index, cu_window_seqlens
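
The pad/reshape/permute dance above is easier to see on a single small grid. The following standalone sketch replicates the core of get_window_index with made-up sizes (hidden_stride=2, patch_size=2, window_size=16, one 12x12 patch grid); -100 marks padding slots that are dropped after windowing.

import torch
import torch.nn.functional as F

hidden_stride, patch_size, window_size = 2, 2, 16
merger_window = window_size // hidden_stride // patch_size  # 4 merged patches

grid_t, grid_h, grid_w = 1, 12, 12
llm_h, llm_w = grid_h // hidden_stride, grid_w // hidden_stride  # 6 x 6

index = torch.arange(grid_t * llm_h * llm_w).reshape(grid_t, llm_h, llm_w)
pad_h = merger_window - llm_h % merger_window  # 2
pad_w = merger_window - llm_w % merger_window  # 2
padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
padded = padded.reshape(grid_t, (llm_h + pad_h) // merger_window,
                        merger_window, (llm_w + pad_w) // merger_window,
                        merger_window)
# Group each window's entries together, then drop the padding slots.
padded = padded.permute(0, 1, 3, 2, 4).reshape(-1)
window_index = padded[padded != -100]
print(window_index[:8])  # tensor([0, 1, 2, 3, 6, 7, 8, 9]): the first window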

rot_pos_emb

rot_pos_emb(grid_thw)
Source code in vllm/model_executor/models/siglip2navit.py
def rot_pos_emb(self, grid_thw):
    pos_ids = []
    for t, h, w in grid_thw:
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(
            h // self.hidden_stride,
            self.hidden_stride,
            w // self.hidden_stride,
            self.hidden_stride,
        )
        hpos_ids = hpos_ids.permute(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos_ids = wpos_ids.reshape(
            h // self.hidden_stride,
            self.hidden_stride,
            w // self.hidden_stride,
            self.hidden_stride,
        )
        wpos_ids = wpos_ids.permute(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()
        pos_ids.append(
            torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    pos_ids = torch.cat(pos_ids, dim=0)
    max_grid_size = grid_thw[:, 1:].max()
    rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
    rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
    return rotary_pos_emb
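
To see what rot_pos_emb produces, here is the (h, w) position-id construction for one tiny 4x4 grid with hidden_stride=2 (all values illustrative). Patches come out ordered merge-block by merge-block, so the first four ids all belong to the top-left 2x2 block.

import torch

hidden_stride, h, w = 2, 4, 4
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos = hpos.reshape(h // hidden_stride, hidden_stride,
                    w // hidden_stride, hidden_stride)
hpos = hpos.permute(0, 2, 1, 3).flatten()
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos = wpos.reshape(h // hidden_stride, hidden_stride,
                    w // hidden_stride, hidden_stride)
wpos = wpos.permute(0, 2, 1, 3).flatten()
print(torch.stack([hpos, wpos], dim=-1)[:4])
# tensor([[0, 0], [0, 1], [1, 0], [1, 1]])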

Siglip2EncoderLayer

Bases: Module

Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2EncoderLayer(nn.Module):

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(config,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.self_attn",
                                          use_data_parallel=use_data_parallel)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp",
                              use_data_parallel=use_data_parallel)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer, of shape `(seq_len, embed_dim)`.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       cu_seqlens=cu_seqlens,
                                       position_embeddings=position_embeddings)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

embed_dim instance-attribute

embed_dim = hidden_size

layer_norm1 instance-attribute

layer_norm1 = LayerNorm(embed_dim, eps=layer_norm_eps)

layer_norm2 instance-attribute

layer_norm2 = LayerNorm(embed_dim, eps=layer_norm_eps)

mlp instance-attribute

mlp = Siglip2MLP(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
    use_data_parallel=use_data_parallel,
)

self_attn instance-attribute

self_attn = Siglip2Attention(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
    use_data_parallel=use_data_parallel,
)

__init__

__init__(
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(
    self,
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()
    self.embed_dim = config.hidden_size
    self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                    eps=config.layer_norm_eps)
    self.self_attn = Siglip2Attention(config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self_attn",
                                      use_data_parallel=use_data_parallel)
    self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                    eps=config.layer_norm_eps)
    self.mlp = Siglip2MLP(config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.mlp",
                          use_data_parallel=use_data_parallel)

forward

forward(
    hidden_states: Tensor,
    cu_seqlens: Tensor,
    position_embeddings: tuple[Tensor, Tensor],
) -> Tensor

Parameters:

hidden_states (torch.FloatTensor, required): input to the layer, of shape (seq_len, embed_dim).
Source code in vllm/model_executor/models/siglip2navit.py
def forward(
    self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        hidden_states (`torch.FloatTensor`):
            Input to the layer, of shape `(seq_len, embed_dim)`.
    """
    residual = hidden_states

    hidden_states = self.layer_norm1(hidden_states)
    hidden_states = self.self_attn(hidden_states=hidden_states,
                                   cu_seqlens=cu_seqlens,
                                   position_embeddings=position_embeddings)
    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = self.layer_norm2(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states
    return hidden_states
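
The layer is a standard pre-LayerNorm residual block. A minimal sketch of the same dataflow in plain torch (nn.Identity stands in for self-attention and a two-layer MLP for Siglip2MLP; sizes are made up):

import torch
import torch.nn as nn

dim = 16
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.Identity()  # stand-in for the attention sub-module
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                    nn.Linear(4 * dim, dim))

x = torch.randn(10, dim)
x = x + attn(norm1(x))  # attention sub-block with residual
x = x + mlp(norm2(x))   # MLP sub-block with residual
print(x.shape)          # torch.Size([10, 16])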

Siglip2MLP

Bases: Module

Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2MLP(nn.Module):

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
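
Why fc1 is column-parallel and fc2 row-parallel: splitting fc1 by output columns and fc2 by input rows lets each rank apply the activation to its own slice, with a single all-reduce at the end. A plain-torch sketch of that equivalence (illustrative sizes; tanh stands in for the configured activation):

import torch

hidden, intermediate, tp = 8, 16, 2
x = torch.randn(5, hidden)
w1 = torch.randn(intermediate, hidden)  # fc1 weight
w2 = torch.randn(hidden, intermediate)  # fc2 weight

ref = torch.tanh(x @ w1.T) @ w2.T       # unsharded reference

shards = []
for rank in range(tp):
    w1_shard = w1.chunk(tp, dim=0)[rank]  # column-parallel: split outputs
    w2_shard = w2.chunk(tp, dim=1)[rank]  # row-parallel: split inputs
    shards.append(torch.tanh(x @ w1_shard.T) @ w2_shard.T)
out = sum(shards)                         # the all-reduce step
print(torch.allclose(ref, out, atol=1e-5))  # True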

activation_fn instance-attribute

activation_fn = get_act_fn(hidden_act)

config instance-attribute

config = config

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    hidden_size,
    intermediate_size,
    quant_config=quant_config,
    prefix=f"{prefix}.fc1",
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    intermediate_size,
    hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.fc2",
)

__init__

__init__(
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(
    self,
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()
    self.config = config
    self.activation_fn = get_act_fn(config.hidden_act)
    # TODO(Isotr0py): Enable data parallel after we support
    # disabling TP on parallel linear layer
    self.fc1 = ColumnParallelLinear(
        config.hidden_size,
        config.intermediate_size,
        quant_config=quant_config,
        prefix=f"{prefix}.fc1",
    )
    self.fc2 = RowParallelLinear(
        config.intermediate_size,
        config.hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.fc2",
    )

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/siglip2navit.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states, _ = self.fc1(hidden_states)
    hidden_states = self.activation_fn(hidden_states)
    hidden_states, _ = self.fc2(hidden_states)
    return hidden_states

Siglip2NavitModel

Bases: Module

Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2NavitModel(torch.nn.Module):

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            use_data_parallel=use_data_parallel)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

vision_model instance-attribute

vision_model = Siglip2VisionTransformer(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.vision_model",
    use_data_parallel=use_data_parallel,
)

__init__

__init__(
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(
    self,
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()

    self.vision_model = Siglip2VisionTransformer(
        config,
        quant_config=quant_config,
        prefix=f"{prefix}.vision_model",
        use_data_parallel=use_data_parallel)

forward

forward(
    pixel_values: FloatTensor, grid_thws: LongTensor
) -> Tensor
Source code in vllm/model_executor/models/siglip2navit.py
def forward(
    self,
    pixel_values: torch.FloatTensor,
    grid_thws: torch.LongTensor,
) -> torch.Tensor:
    return self.vision_model(
        pixel_values=pixel_values,
        grid_thws=grid_thws,
    )

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/siglip2navit.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
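
A small illustration of the name rewrite this loop performs: checkpoints store separate q_proj/k_proj/v_proj tensors, and each is routed into the fused qkv_proj parameter under its shard id. The checkpoint name below is hypothetical.

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
ckpt_name = "vision_model.encoder.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in ckpt_name:
        fused_name = ckpt_name.replace(weight_name, param_name)
        print(fused_name, shard_id)
        # vision_model.encoder.layers.0.self_attn.qkv_proj.weight k
        break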

Siglip2VisionEmbeddings

Bases: Module

Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2VisionEmbeddings(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            self.patch_embedding = ReplicatedLinear(
                input_size=config.num_channels * self.patch_size *
                self.patch_size,
                output_size=self.embed_dim,
                return_bias=False,
            )
            if self.preserve_original_pe:
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches,
                                                       self.embed_dim)

        else:
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                self.num_patches = (self.image_size // self.patch_size)**2
                self.position_embedding_size = (self.image_size //
                                                self.patch_size)
                self.position_embedding = nn.Embedding(self.num_patches,
                                                       self.embed_dim)

    def forward(self,
                pixel_values: torch.FloatTensor,
                grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (
                    num_patches,
                    num_channels * temporal_patch_size * patch_size * patch_size
                )
            grid_thws (`torch.LongTensor`):
                grid of shape (num_images, 3)
        """

        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, LinearBase):
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
        elif isinstance(self.patch_embedding, nn.Conv2d):
            pixel_values = pixel_values.view(
                -1, self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size, self.patch_size)
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
            patch_embeds = patch_embeds.reshape(-1, self.embed_dim)

        if self.preserve_original_pe:
            assert grid_thws is not None
            pos_embed_new = torch.zeros_like(patch_embeds)
            positional_embeddings = self.position_embedding.weight.reshape(
                self.position_embedding_size, self.position_embedding_size,
                -1).unsqueeze(0).permute(0, 3, 1, 2)
            cnt = 0
            for t, h, w in grid_thws:
                volume = t * h * w
                pe = F.interpolate(positional_embeddings,
                                   size=(h, w),
                                   mode='bicubic',
                                   align_corners=False)
                pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                pe = pe[0].repeat(t, 1)
                pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
                                w // self.hidden_stride, self.hidden_stride,
                                -1)
                pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
                pos_embed_new[cnt:cnt + volume] = pe
                cnt += volume
            patch_embeds = patch_embeds + pos_embed_new

        return patch_embeds

config instance-attribute

config = config

embed_dim instance-attribute

embed_dim = hidden_size

hidden_stride instance-attribute

hidden_stride = hidden_stride

image_size instance-attribute

image_size = image_size

num_patches instance-attribute

num_patches = num_patches

patch_embedding instance-attribute

patch_embedding = ReplicatedLinear(
    input_size=num_channels * patch_size * patch_size,
    output_size=embed_dim,
    return_bias=False,
)

patch_size instance-attribute

patch_size = patch_size

position_embedding instance-attribute

position_embedding = Embedding(num_patches, embed_dim)

position_embedding_size instance-attribute

position_embedding_size = int(num_patches ** 0.5)

preserve_original_pe instance-attribute

preserve_original_pe = preserve_original_pe

__init__

__init__(config: PretrainedConfig)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(self, config: PretrainedConfig):
    super().__init__()
    self.config = config
    self.embed_dim = config.hidden_size
    self.patch_size = config.patch_size
    self.image_size = config.image_size
    self.num_patches = config.num_patches
    self.preserve_original_pe = config.preserve_original_pe
    self.hidden_stride = config.hidden_stride

    # siglip2 naflex
    if self.num_patches > 0:
        self.patch_embedding = ReplicatedLinear(
            input_size=config.num_channels * self.patch_size *
            self.patch_size,
            output_size=self.embed_dim,
            return_bias=False,
        )
        if self.preserve_original_pe:
            self.position_embedding_size = int(self.num_patches**0.5)
            self.position_embedding = nn.Embedding(self.num_patches,
                                                   self.embed_dim)

    else:
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        if self.preserve_original_pe:
            self.num_patches = (self.image_size // self.patch_size)**2
            self.position_embedding_size = (self.image_size //
                                            self.patch_size)
            self.position_embedding = nn.Embedding(self.num_patches,
                                                   self.embed_dim)

forward

forward(
    pixel_values: FloatTensor,
    grid_thws: Optional[LongTensor] = None,
) -> Tensor

Parameters:

pixel_values (torch.FloatTensor, required): pixel values of shape (num_patches, num_channels * temporal_patch_size * patch_size * patch_size).

grid_thws (Optional[torch.LongTensor], default None): grid of shape (num_images, 3).
Source code in vllm/model_executor/models/siglip2navit.py
def forward(self,
            pixel_values: torch.FloatTensor,
            grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
    """
    Args:
        pixel_values (`torch.FloatTensor`):
            Pixel values of shape (
                num_patches,
                num_channels * temporal_patch_size * patch_size * patch_size
            )
        grid_thws (`torch.LongTensor`):
            grid of shape (num_images, 3)
    """

    # Apply patch embeddings to already patchified pixel values
    target_dtype = self.patch_embedding.weight.dtype
    if isinstance(self.patch_embedding, LinearBase):
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))
    elif isinstance(self.patch_embedding, nn.Conv2d):
        pixel_values = pixel_values.view(
            -1, self.config.num_channels * self.config.temporal_patch_size,
            self.patch_size, self.patch_size)
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))
        patch_embeds = patch_embeds.reshape(-1, self.embed_dim)

    if self.preserve_original_pe:
        assert grid_thws is not None
        pos_embed_new = torch.zeros_like(patch_embeds)
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size, self.position_embedding_size,
            -1).unsqueeze(0).permute(0, 3, 1, 2)
        cnt = 0
        for t, h, w in grid_thws:
            volume = t * h * w
            pe = F.interpolate(positional_embeddings,
                               size=(h, w),
                               mode='bicubic',
                               align_corners=False)
            pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
            pe = pe[0].repeat(t, 1)
            pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
                            w // self.hidden_stride, self.hidden_stride,
                            -1)
            pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
            pos_embed_new[cnt:cnt + volume] = pe
            cnt += volume
        patch_embeds = patch_embeds + pos_embed_new

    return patch_embeds
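
The position-embedding resampling above, in isolation: a learned S x S grid of embeddings is bicubically resized to each image's (h, w) patch grid. A runnable sketch with made-up sizes:

import torch
import torch.nn.functional as F

S, dim = 16, 32                     # original PE grid side and width
h, w = 10, 24                       # target patch grid for one image
pe_table = torch.randn(S * S, dim)  # stands in for position_embedding.weight

pe = pe_table.reshape(S, S, dim).unsqueeze(0).permute(0, 3, 1, 2)  # (1,d,S,S)
pe = F.interpolate(pe, size=(h, w), mode="bicubic", align_corners=False)
pe = pe.permute(0, 2, 3, 1).reshape(h * w, dim)
print(pe.shape)  # torch.Size([240, 32])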

Siglip2VisionTransformer

Bases: Module

Source code in vllm/model_executor/models/siglip2navit.py
class Siglip2VisionTransformer(nn.Module):

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.encoder",
                                      use_data_parallel=use_data_parallel)
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width)
            of the input images.
        """
        hidden_states = self.embeddings(pixel_values, grid_thws)

        last_hidden_state = self.encoder(hidden_states, grid_thws)
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state

config instance-attribute

config = config

embeddings instance-attribute

embeddings = Siglip2VisionEmbeddings(config)

encoder instance-attribute

encoder = Siglip2Encoder(
    config,
    quant_config=quant_config,
    prefix=f"{prefix}.encoder",
    use_data_parallel=use_data_parallel,
)

post_layernorm instance-attribute

post_layernorm = LayerNorm(embed_dim, eps=layer_norm_eps)

__init__

__init__(
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
)
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(
    self,
    config: Siglip2VisionConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
):
    super().__init__()
    self.config = config
    embed_dim = config.hidden_size

    self.embeddings = Siglip2VisionEmbeddings(config)
    self.encoder = Siglip2Encoder(config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.encoder",
                                  use_data_parallel=use_data_parallel)
    self.post_layernorm = nn.LayerNorm(embed_dim,
                                       eps=config.layer_norm_eps)

forward

forward(
    pixel_values: FloatTensor, grid_thws: LongTensor
) -> Tensor

Parameters:

pixel_values (torch.FloatTensor, required): patchified pixel values.

grid_thws (torch.LongTensor of shape (num_images, 3), required): the (t, h, w) patch grid of each image.

Source code in vllm/model_executor/models/siglip2navit.py
def forward(
    self,
    pixel_values: torch.FloatTensor,
    grid_thws: torch.LongTensor,
) -> torch.Tensor:
    r"""
    spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
        Tensor containing the spatial dimensions (height, width)
        of the input images.
    """
    hidden_states = self.embeddings(pixel_values, grid_thws)

    last_hidden_state = self.encoder(hidden_states, grid_thws)
    last_hidden_state = self.post_layernorm(last_hidden_state)

    return last_hidden_state
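
A quick shape walkthrough of the expected inputs, with made-up numbers: pixel_values arrives already patchified, and grid_thws carries one (t, h, w) row per image, so its rows must account for every patch.

import torch

channels, temporal_patch, patch = 3, 1, 16
grid_thws = torch.tensor([[1, 4, 6], [1, 8, 8]])  # two images
num_patches = int(grid_thws.prod(dim=-1).sum())   # 24 + 64 = 88
pixel_values = torch.randn(num_patches,
                           channels * temporal_patch * patch * patch)
print(pixel_values.shape)  # torch.Size([88, 768])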

VisionRotaryEmbedding

Bases: Module

Source code in vllm/model_executor/models/siglip2navit.py
class VisionRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen,
                           device=self.inv_freq.device,
                           dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs

__init__

__init__(dim: int, theta: float = 10000.0) -> None
Source code in vllm/model_executor/models/siglip2navit.py
def __init__(self, dim: int, theta: float = 10000.0) -> None:
    super().__init__()
    inv_freq = 1.0 / (theta
                      **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
    self.register_buffer("inv_freq", inv_freq, persistent=False)

forward

forward(seqlen: int) -> Tensor
Source code in vllm/model_executor/models/siglip2navit.py
def forward(self, seqlen: int) -> torch.Tensor:
    seq = torch.arange(seqlen,
                       device=self.inv_freq.device,
                       dtype=self.inv_freq.dtype)
    freqs = torch.outer(seq, self.inv_freq)
    return freqs
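
A quick look at the frequency table this module produces, with a deliberately tiny dim: one angle per position per rotary pair.

import torch

dim, theta, seqlen = 4, 10000.0, 6
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq)
print(freqs.shape)  # torch.Size([6, 2])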

apply_rotary_emb_torch

apply_rotary_emb_torch(x, cos, sin, interleaved=False)

x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

Source code in vllm/model_executor/models/siglip2navit.py
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(
        sin,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos +
            rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
        ],
        dim=-1,
    )
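
A usage sketch for the helper (imported from this module; sizes are illustrative). Rotary embedding is a pure rotation of (x_i, x_{i+d}) pairs, so per-position vector norms are preserved:

import torch
from vllm.model_executor.models.siglip2navit import apply_rotary_emb_torch

batch, seqlen, nheads, headdim = 1, 5, 2, 8
x = torch.randn(batch, seqlen, nheads, headdim)
# One angle per position, shared by all pairs: (seqlen, rotary_dim / 2).
angles = torch.outer(torch.arange(seqlen, dtype=torch.float),
                     torch.ones(headdim // 2))
cos, sin = angles.cos(), angles.sin()

y = apply_rotary_emb_torch(x, cos, sin)
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # True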

apply_rotary_pos_emb

apply_rotary_pos_emb(
    q: Tensor,
    k: Tensor,
    cos: Tensor,
    sin: Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/siglip2navit.py
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    if is_flash_attn_backend:
        from flash_attn.layers.rotary import apply_rotary_emb
        apply_rotary_emb_func = apply_rotary_emb
    else:
        apply_rotary_emb_func = apply_rotary_emb_torch
    q_embed = apply_rotary_emb_func(q.float(), cos.float(),
                                    sin.float()).type_as(q)
    k_embed = apply_rotary_emb_func(k.float(), cos.float(),
                                    sin.float()).type_as(k)
    return q_embed, k_embed

rotate_half

rotate_half(x, interleaved=False)
Source code in vllm/model_executor/models/siglip2navit.py
def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1),
                         "... d two -> ... (d two)",
                         two=2)
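
A quick illustration of the two pairing conventions (importing the helper from this module): the default splits the vector into halves, while interleaved=True pairs adjacent elements.

import torch
from vllm.model_executor.models.siglip2navit import rotate_half

x = torch.tensor([[1., 2., 3., 4.]])
print(rotate_half(x))                    # tensor([[-3., -4.,  1.,  2.]])
print(rotate_half(x, interleaved=True))  # tensor([[-2.,  1., -4.,  3.]])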