Skip to content

vllm.model_executor.models.mamba2

PyTorch implementation of the Mamba2 model.

KVCache module-attribute

KVCache = tuple[Tensor, Tensor]

Mamba2DecoderLayer

Bases: Module

Source code in vllm/model_executor/models/mamba2.py
class Mamba2DecoderLayer(nn.Module):
    """One Mamba2 block: an input RMSNorm followed by a ``MambaMixer2``."""

    def __init__(self,
                 config: MambaConfig,
                 model_config: Optional[ModelConfig] = None,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Create the mixer and the input norm from ``config``.

        Args:
            config: HF-style config providing the Mamba2 hyper-parameters.
            model_config: vLLM model config, forwarded to the mixer.
            cache_config: vLLM cache config, forwarded to the mixer.
            quant_config: Optional quantization config, forwarded to the mixer.
            prefix: Parameter-name prefix; the mixer is registered under
                ``f"{prefix}.mixer"``.
        """
        super().__init__()
        self.config = config

        # An explicit `intermediate_size` on the config wins; otherwise the
        # mixer width falls back to `expand * hidden_size`.
        inner_width = getattr(config, "intermediate_size",
                              config.expand * config.hidden_size)
        mixer_prefix = f"{prefix}.mixer"
        self.mixer = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=inner_width,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=config.num_heads,
            head_dim=config.head_dim,
            rms_norm_eps=config.layer_norm_epsilon,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=mixer_prefix,
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        **kwargs,
    ):
        """Normalize the input, run the mixer and return ``(output, residual)``.

        When ``residual`` is None (e.g. the first layer), the raw input becomes
        the residual and only ``hidden_states`` is normalized. Otherwise the
        norm is called with both tensors and returns the normalized states
        together with the updated residual. Extra keyword arguments (such as
        ``positions``) are accepted and ignored.
        """
        if residual is not None:
            normed, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.norm(hidden_states)

        # The mixer receives `mixer_out` as its output buffer; its own return
        # value is not used.
        mixer_out = torch.empty_like(normed)
        self.mixer(normed, mixer_out, mamba_cache_params, mamba2_metadata)
        return mixer_out, residual

config instance-attribute

config = config

mixer instance-attribute

mixer = MambaMixer2(
    hidden_size=hidden_size,
    ssm_state_size=state_size,
    conv_kernel_size=conv_kernel,
    intermediate_size=getattr(
        config, "intermediate_size", expand * hidden_size
    ),
    use_conv_bias=use_conv_bias,
    use_bias=use_bias,
    n_groups=n_groups,
    num_heads=num_heads,
    head_dim=head_dim,
    rms_norm_eps=layer_norm_epsilon,
    activation=hidden_act,
    model_config=model_config,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.mixer",
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)

__init__

__init__(
    config: MambaConfig,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/mamba2.py
def __init__(self,
             config: MambaConfig,
             model_config: Optional[ModelConfig] = None,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = "") -> None:
    """Create the mixer and the input norm from ``config``.

    Args:
        config: HF-style config providing the Mamba2 hyper-parameters.
        model_config: vLLM model config, forwarded to the mixer.
        cache_config: vLLM cache config, forwarded to the mixer.
        quant_config: Optional quantization config, forwarded to the mixer.
        prefix: Parameter-name prefix; the mixer is registered under
            ``f"{prefix}.mixer"``.
    """
    super().__init__()
    self.config = config

    # An explicit `intermediate_size` on the config wins; otherwise the mixer
    # width falls back to `expand * hidden_size`.
    inner_width = getattr(config, "intermediate_size",
                          config.expand * config.hidden_size)
    mixer_prefix = f"{prefix}.mixer"
    self.mixer = MambaMixer2(
        hidden_size=config.hidden_size,
        ssm_state_size=config.state_size,
        conv_kernel_size=config.conv_kernel,
        intermediate_size=inner_width,
        use_conv_bias=config.use_conv_bias,
        use_bias=config.use_bias,
        n_groups=config.n_groups,
        num_heads=config.num_heads,
        head_dim=config.head_dim,
        rms_norm_eps=config.layer_norm_epsilon,
        activation=config.hidden_act,
        model_config=model_config,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=mixer_prefix,
    )

    self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

forward

forward(
    hidden_states: Tensor,
    residual: Optional[Tensor],
    mamba_cache_params: MambaCacheParams,
    mamba2_metadata: Mamba2Metadata,
    **kwargs,
)
Source code in vllm/model_executor/models/mamba2.py
def forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    mamba_cache_params: MambaCacheParams,
    mamba2_metadata: Mamba2Metadata,
    **kwargs,
):
    """Normalize the input, run the mixer and return ``(output, residual)``.

    When ``residual`` is None (e.g. the first layer), the raw input becomes the
    residual and only ``hidden_states`` is normalized. Otherwise the norm is
    called with both tensors and returns the normalized states together with
    the updated residual. Extra keyword arguments are accepted and ignored.
    """
    if residual is not None:
        normed, residual = self.norm(hidden_states, residual)
    else:
        residual = hidden_states
        normed = self.norm(hidden_states)

    # The mixer receives `mixer_out` as its output buffer; its own return
    # value is not used.
    mixer_out = torch.empty_like(normed)
    self.mixer(normed, mixer_out, mamba_cache_params, mamba2_metadata)
    return mixer_out, residual

Mamba2ForCausalLM

Bases: Module, HasInnerState, IsAttentionFree

Source code in vllm/model_executor/models/mamba2.py
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
    """Mamba2 causal LM: a ``Mamba2Model`` backbone plus an LM head."""

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the pair of dtypes used for the Mamba state caches.

        Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype`` with the
        model dtype and the ``mamba_cache_dtype`` / ``mamba_ssm_cache_dtype``
        settings of ``vllm_config.cache_config``.
        """

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        # Must agree with the width Mamba2DecoderLayer passes to MambaMixer2:
        # an explicit `intermediate_size` on the config takes precedence and
        # `expand * hidden_size` is only the fallback. Always using the
        # fallback would size the caches wrongly for such configs.
        intermediate_size = getattr(hf_config, "intermediate_size",
                                    hf_config.expand * hf_config.hidden_size)

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.num_heads,
            head_dim=hf_config.head_dim,
            state_size=hf_config.state_size,
            conv_kernel=hf_config.conv_kernel,
            use_v1=use_v1,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Full vLLM config; the HF config is read from
                ``vllm_config.model_config.hf_config``.
            prefix: Parameter-name prefix; the backbone is registered under
                ``maybe_prefix(prefix, "backbone")``.
        """
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.backbone = Mamba2Model(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)

        # Mamba state cache manager, used on the V0 path only; it is created
        # lazily by `forward` and holds the Mamba state between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.backbone.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the backbone's embedding lookup."""
        return self.backbone.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone, creating the V0 Mamba cache on first use.

        When ``envs.VLLM_USE_V1`` is false, a ``MambaCacheManager`` is built on
        the first call, and per-step cache tensors are obtained from it with
        ``**kwargs``. On V1, ``None`` is passed as the cache params.
        """
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = (
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba))
                mamba_state_shape = \
                    self.get_mamba_state_shape_from_config(
                        self.vllm_config, use_v1=False)
                mamba_state_dtype = \
                    self.get_mamba_state_dtype_from_config(
                    self.vllm_config)
                self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                     num_mamba_layers,
                                                     *mamba_state_shape,
                                                     *mamba_state_dtype)

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        else:
            # NOTE: mamba_cache_params is not needed for v1
            mamba_cache_params = None

        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                      intermediate_tensors, inputs_embeds)

        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Forward to ``self.mamba_cache.copy_inputs_before_cuda_graphs``.

        NOTE: ``self.mamba_cache`` is only created by ``forward`` on the V0
        path, so this must not be called before that has happened.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Return the Mamba cache's CUDA-graph capture inputs.

        NOTE: relies on ``self.mamba_cache`` having been created by
        ``forward`` (V0 path).
        """
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits by running ``self.logits_processor`` with the head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs via ``AutoWeightsLoader``.

        Returns the set produced by the loader (presumably the names of the
        parameters that were loaded).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

backbone instance-attribute

backbone = Mamba2Model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "backbone"),
)

config instance-attribute

config = config

lm_head instance-attribute

lm_head = ParallelLMHead(
    unpadded_vocab_size,
    hidden_size,
    org_num_embeddings=vocab_size,
    padding_size=DEFAULT_VOCAB_PADDING_SIZE
    if not lora_config
    else lora_vocab_padding_size,
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

mamba_cache instance-attribute

mamba_cache: Optional[MambaCacheManager] = None

model_config instance-attribute

model_config = model_config

scheduler_config instance-attribute

scheduler_config = scheduler_config

unpadded_vocab_size instance-attribute

unpadded_vocab_size = vocab_size

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/mamba2.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone, LM head and logits processor.

    Args:
        vllm_config: Full vLLM config; the HF config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Parameter-name prefix; the backbone is registered under
            ``maybe_prefix(prefix, "backbone")``.
    """
    hf_config = vllm_config.model_config.hf_config
    lora_config = vllm_config.lora_config
    assert not vllm_config.cache_config.enable_prefix_caching, \
        "Mamba does not support prefix caching"

    super().__init__()
    self.config = hf_config
    self.vllm_config = vllm_config
    self.scheduler_config = vllm_config.scheduler_config
    self.model_config = vllm_config.model_config
    self.backbone = Mamba2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "backbone"))

    extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
    self.unpadded_vocab_size = hf_config.vocab_size + extra_vocab

    # LoRA kernels need a larger vocab padding than the default.
    vocab_padding = (lora_config.lora_vocab_padding_size
                     if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
    self.lm_head = ParallelLMHead(
        self.unpadded_vocab_size,
        hf_config.hidden_size,
        org_num_embeddings=hf_config.vocab_size,
        padding_size=vocab_padding,
    )
    if hf_config.tie_word_embeddings:
        self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)

    # Mamba state cache manager (V0 only); `forward` creates it lazily.
    self.mamba_cache: Optional[MambaCacheManager] = None

    self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                            hf_config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.backbone.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Tensor
Source code in vllm/model_executor/models/mamba2.py
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits by running ``self.logits_processor`` with the LM head."""
    processor = self.logits_processor
    return processor(self.lm_head, hidden_states, sampling_metadata)

copy_inputs_before_cuda_graphs

copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
Source code in vllm/model_executor/models/mamba2.py
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    """Forward to ``self.mamba_cache.copy_inputs_before_cuda_graphs``.

    NOTE: ``self.mamba_cache`` is only created by ``forward`` on the V0 path,
    so this must not be called before that has happened.
    """
    cache_manager = self.mamba_cache
    return cache_manager.copy_inputs_before_cuda_graphs(input_buffers,
                                                        **kwargs)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
)
Source code in vllm/model_executor/models/mamba2.py
def forward(self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs):
    """Run the backbone, creating the V0 Mamba cache on first use.

    When ``envs.VLLM_USE_V1`` is false, a ``MambaCacheManager`` is built on the
    first call, and per-step cache tensors are obtained from it with
    ``**kwargs``. On V1, ``None`` is passed to the backbone as cache params.
    """
    mamba_cache_params = None
    if not envs.VLLM_USE_V1:
        if self.mamba_cache is None:
            layer_count = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            state_shapes = self.get_mamba_state_shape_from_config(
                self.vllm_config, use_v1=False)
            state_dtypes = self.get_mamba_state_dtype_from_config(
                self.vllm_config)
            self.mamba_cache = MambaCacheManager(self.vllm_config,
                                                 layer_count, *state_shapes,
                                                 *state_dtypes)
        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

    return self.backbone(input_ids, positions, mamba_cache_params,
                         intermediate_tensors, inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/mamba2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` using the backbone's embedding lookup."""
    backbone = self.backbone
    return backbone.get_input_embeddings(input_ids)

get_mamba_state_dtype_from_config classmethod

get_mamba_state_dtype_from_config(
    vllm_config: VllmConfig,
) -> tuple[dtype, dtype]
Source code in vllm/model_executor/models/mamba2.py
@classmethod
def get_mamba_state_dtype_from_config(
    cls,
    vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
    """Return the pair of dtypes used for the Mamba state caches.

    Delegates to ``MambaStateDtypeCalculator.mamba2_state_dtype`` with the
    model dtype and the ``mamba_cache_dtype`` / ``mamba_ssm_cache_dtype``
    settings of ``vllm_config.cache_config``.
    """
    model_dtype = vllm_config.model_config.dtype
    cache_cfg = vllm_config.cache_config
    return MambaStateDtypeCalculator.mamba2_state_dtype(
        model_dtype,
        cache_cfg.mamba_cache_dtype,
        cache_cfg.mamba_ssm_cache_dtype,
    )

get_mamba_state_shape_from_config classmethod

get_mamba_state_shape_from_config(
    vllm_config: VllmConfig, use_v1: bool = True
) -> tuple[tuple[int, int], tuple[int, int, int]]

Calculate shapes for Mamba's convolutional and state caches.

Parameters:

Name Type Description Default
vllm_config VllmConfig

vLLM config

required
use_v1 bool

Get shapes for V1 (or V0)

True

Returns:

Type Description
tuple[tuple[int, int], tuple[int, int, int]]

Tuple containing:
  • conv_state_shape (tuple[int, int]): Shape for convolutional state cache
  • temporal_state_shape (tuple[int, int, int]): Shape for state space model cache
Source code in vllm/model_executor/models/mamba2.py
@classmethod
def get_mamba_state_shape_from_config(
    cls,
    vllm_config: "VllmConfig",
    use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
    """Calculate shapes for Mamba's convolutional and state caches.

    Args:
        vllm_config: vLLM config
        use_v1: Get shapes for V1 (or V0)

    Returns:
        Tuple containing:
        - conv_state_shape: Shape for convolutional state cache
        - temporal_state_shape: Shape for state space model cache
    """
    parallel_config = vllm_config.parallel_config
    hf_config = vllm_config.model_config.hf_config
    # Must agree with the width Mamba2DecoderLayer passes to MambaMixer2: an
    # explicit `intermediate_size` on the config takes precedence and
    # `expand * hidden_size` is only the fallback. Always using the fallback
    # would size the caches wrongly for such configs.
    intermediate_size = getattr(hf_config, "intermediate_size",
                                hf_config.expand * hf_config.hidden_size)

    return MambaStateShapeCalculator.mamba2_state_shape(
        intermediate_size=intermediate_size,
        tp_world_size=parallel_config.tensor_parallel_size,
        n_groups=hf_config.n_groups,
        num_heads=hf_config.num_heads,
        head_dim=hf_config.head_dim,
        state_size=hf_config.state_size,
        conv_kernel=hf_config.conv_kernel,
        use_v1=use_v1,
    )

get_seqlen_agnostic_capture_inputs

get_seqlen_agnostic_capture_inputs(batch_size: int)
Source code in vllm/model_executor/models/mamba2.py
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    """Return the Mamba cache's CUDA-graph capture inputs for ``batch_size``.

    NOTE: relies on ``self.mamba_cache`` having been created by ``forward``
    (V0 path).
    """
    cache_manager = self.mamba_cache
    return cache_manager.get_seqlen_agnostic_capture_inputs(batch_size)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/mamba2.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load ``(name, tensor)`` pairs via ``AutoWeightsLoader``.

    Returns the set produced by the loader (presumably the names of the
    parameters that were loaded).
    """
    return AutoWeightsLoader(self).load_weights(weights)

Mamba2Model

Bases: Module

Source code in vllm/model_executor/models/mamba2.py
@support_torch_compile
class Mamba2Model(nn.Module):
    """Mamba2 backbone: embeddings, a stack of Mamba2 layers, final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, the decoder-layer stack and the final norm.

        Args:
            vllm_config: Full vLLM config; the HF config is read from
                ``vllm_config.model_config.hf_config``.
            prefix: Parameter-name prefix; layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        # LoRA is rejected for this backbone.
        assert not is_lora_enabled

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embeddings = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(config,
                                              model_config=model_config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
                              eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's Mamba2 layers.

        On the first pipeline rank the input comes from ``inputs_embeds`` or
        the embedding lookup. Other ranks read ``hidden_states``/``residual``
        from ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors``. The last rank returns the ``norm_f`` output.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 get mamba2_metadata from forward_context
            mamba2_metadata = None

        # Only run the layers owned by this rank, i.e. [start_layer, end_layer)
        # as returned by make_layers. The per-layer cache index is relative to
        # start_layer, so iterating from 0 would call layers this rank does not
        # own and produce negative cache indices on later pipeline stages.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer) if mamba_cache_params else None,
                mamba2_metadata=mamba2_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Names containing ``A_log`` are loaded into the parameter named with
        ``A`` instead. ``.bias`` tensors without a matching parameter are
        skipped, as are parameters missing on this pipeline rank. Each
        parameter is loaded with its ``weight_loader`` attribute, falling back
        to ``default_weight_loader``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "A_log" in name:
                name = name.replace("A_log", "A")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embeddings instance-attribute

embeddings = VocabParallelEmbedding(
    vocab_size, hidden_size, org_num_embeddings=vocab_size
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm_f instance-attribute

norm_f = RMSNorm(hidden_size, eps=layer_norm_epsilon)

org_vocab_size instance-attribute

org_vocab_size = vocab_size

vocab_size instance-attribute

vocab_size = vocab_size + lora_vocab

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/mamba2.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build embeddings, the decoder-layer stack and the final norm.

    Args:
        vllm_config: Full vLLM config; the HF config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Parameter-name prefix; layers are created under
            ``f"{prefix}.layers"``.
    """
    super().__init__()

    model_config = vllm_config.model_config
    config = model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config
    is_lora_enabled = bool(lora_config)
    # LoRA is rejected for this backbone.
    assert not is_lora_enabled

    self.config = config
    lora_vocab = 0
    if lora_config:
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1))
    self.vocab_size = config.vocab_size + lora_vocab
    self.org_vocab_size = config.vocab_size

    self.embeddings = VocabParallelEmbedding(
        self.vocab_size,
        config.hidden_size,
        org_num_embeddings=config.vocab_size,
    )

    # make_layers calls this with a `prefix=` keyword for each layer.
    def build_layer(prefix: str):
        return Mamba2DecoderLayer(config,
                                  model_config=model_config,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

    self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size))

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    mamba_cache_params: MambaCacheParams,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/models/mamba2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    mamba_cache_params: MambaCacheParams,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run this rank's Mamba2 layers.

    On the first pipeline rank the input comes from ``inputs_embeds`` or the
    embedding lookup. Other ranks read ``hidden_states``/``residual`` from
    ``intermediate_tensors``. Non-last ranks return ``IntermediateTensors``.
    The last rank returns the ``norm_f`` output.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

    if not envs.VLLM_USE_V1:
        mamba2_metadata = prepare_mamba2_metadata(
            chunk_size=self.config.chunk_size,
            attn_metadata=attn_metadata,
        )
    else:
        # v1 get mamba2_metadata from forward_context
        mamba2_metadata = None

    # Only run the layers owned by this rank, i.e. [start_layer, end_layer) as
    # returned by make_layers. The per-layer cache index is relative to
    # start_layer, so iterating from 0 would call layers this rank does not
    # own and produce negative cache indices on later pipeline stages.
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]

        hidden_states, residual = layer(
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
            mamba_cache_params=mamba_cache_params.at_layer_idx(
                i - self.start_layer) if mamba_cache_params else None,
            mamba2_metadata=mamba2_metadata)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    hidden_states, _ = self.norm_f(hidden_states, residual)

    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/mamba2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings for ``input_ids``."""
    embed = self.embeddings
    return embed(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/mamba2.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load ``(name, tensor)`` checkpoint pairs into this module.

    Names containing ``A_log`` are loaded into the parameter named with ``A``
    instead. ``.bias`` tensors without a matching parameter (GPTQ extras) are
    skipped, as are parameters missing on this pipeline rank. Each parameter
    is loaded with its ``weight_loader`` attribute, falling back to
    ``default_weight_loader``.

    Returns:
        The set of parameter names that were loaded.
    """
    named_params = dict(self.named_parameters())
    done: set[str] = set()
    for raw_name, tensor in weights:
        # str.replace is a no-op when "A_log" is absent.
        target = raw_name.replace("A_log", "A")

        is_extra_bias = target.endswith(".bias") and target not in named_params
        if is_extra_bias or is_pp_missing_parameter(target, self):
            continue

        param = named_params[target]
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, tensor)
        done.add(target)
    return done