vllm.model_executor.layers.mamba.abstract

MambaBase

Bases: AttentionLayerBase

Base class for Mamba-like layers that support the v1 engine. Inherit from this class when implementing a custom Mamba-like layer.

Source code in vllm/model_executor/layers/mamba/abstract.py
class MambaBase(AttentionLayerBase):
    """
    Base class for Mamba-like layers which support the v1 engine.
    Inherit from this class if you implement a custom layer.
    """

    # Contains the KV cache (mamba state) for the layer
    # in the shape specified by `self.get_state_shape`.
    # The outer list is for v0 PP virtual engine. Though this code path
    # only runs for v1, we have to do this to unify with the interface
    # of Attention + v0 PP.
    kv_cache: list[Iterable[torch.Tensor]]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Defines the shape of the state.
        For mamba layers this is usually a (conv_state, ssm_state) tuple.
        In this case, returns (conv_state_shape, ssm_state_shape).
        """
        pass

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        pass

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""
        pass
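
The abstract members above can be illustrated with a minimal sketch of a custom layer. To keep it runnable without vLLM installed, `MambaBaseSketch` and `AttentionBackend` below are stand-ins that only mirror the abstract interface shown in the source. `ToyMambaLayer`, `ToyBackend`, the constructor parameters, the shape formulas, and the `"toy_mamba"` string are all hypothetical names and values chosen for illustration, not part of vLLM.

```python
from abc import ABC, abstractmethod
from collections.abc import Iterable


class AttentionBackend:
    """Stand-in for vLLM's AttentionBackend (placeholder for this sketch)."""


class MambaBaseSketch(ABC):
    """Stand-in that mirrors the abstract members of MambaBase."""

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]: ...

    @property
    @abstractmethod
    def mamba_type(self) -> str: ...

    @abstractmethod
    def get_attn_backend(self) -> type[AttentionBackend]: ...


class ToyBackend(AttentionBackend):
    """Hypothetical backend class, for illustration only."""


class ToyMambaLayer(MambaBaseSketch):
    """Hypothetical custom layer implementing all three abstract members."""

    def __init__(self, conv_dim: int, conv_kernel: int, state_size: int) -> None:
        self.conv_dim = conv_dim
        self.conv_kernel = conv_kernel
        self.state_size = state_size

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        # Returns (conv_state_shape, ssm_state_shape).
        # These formulas are assumptions made for this sketch.
        conv_state_shape = (self.conv_dim, self.conv_kernel - 1)
        ssm_state_shape = (self.conv_dim, self.state_size)
        return conv_state_shape, ssm_state_shape

    @property
    def mamba_type(self) -> str:
        return "toy_mamba"  # hypothetical identifier

    def get_attn_backend(self) -> type[AttentionBackend]:
        return ToyBackend


layer = ToyMambaLayer(conv_dim=8, conv_kernel=4, state_size=16)
shapes = layer.get_state_shape()
```

Because every member is abstract, a subclass that omits any of the three cannot be instantiated; Python raises `TypeError` at construction time.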

kv_cache instance-attribute

kv_cache: list[Iterable[Tensor]]
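
The nesting in this annotation can be sketched as follows. According to the comment in the class source, the outer list exists for the v0 pipeline-parallel virtual engine and the inner iterable holds the layer's state tensors. The example below uses plain nested lists as stand-ins for `torch.Tensor` objects, and the index `0`, the shapes, and the variable names are assumptions made for illustration.

```python
# Plain nested lists stand in for torch.Tensor objects in this sketch.
conv_state = [[0.0] * 3 for _ in range(8)]   # placeholder with shape (8, 3)
ssm_state = [[0.0] * 16 for _ in range(8)]   # placeholder with shape (8, 16)

# Outer list: one entry per virtual engine.
# Inner tuple: the state tensors of this layer.
kv_cache = [(conv_state, ssm_state)]

virtual_engine = 0  # assumed index for this sketch
conv, ssm = kv_cache[virtual_engine]
```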

mamba_type abstractmethod property

mamba_type: str

get_attn_backend abstractmethod

get_attn_backend() -> type[AttentionBackend]

Get the attention backend class for this Mamba layer.

Source code in vllm/model_executor/layers/mamba/abstract.py
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
    """Get the attention backend class for this Mamba layer."""
    pass

get_state_shape abstractmethod

get_state_shape() -> Iterable[tuple[int, ...]]

Defines the shape of each state tensor held by the layer. For Mamba layers the state is usually a (conv_state, ssm_state) tuple, in which case this method returns (conv_state_shape, ssm_state_shape).
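
Since the return type is an iterable of shapes rather than a fixed pair, a caller can process the result without assuming exactly two states. The sketch below shows one possible use: counting the elements in each state. The helper `state_num_elements` and the example shapes are hypothetical and not part of vLLM.

```python
import math
from collections.abc import Iterable


def state_num_elements(state_shape: Iterable[tuple[int, ...]]) -> list[int]:
    """Return the number of elements in each state, given its shape."""
    return [math.prod(shape) for shape in state_shape]


# Hypothetical (conv_state_shape, ssm_state_shape) pair.
example_state_shape = ((8, 3), (8, 16))

per_state = state_num_elements(example_state_shape)  # [24, 128]
total = sum(per_state)                               # 152
```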

Source code in vllm/model_executor/layers/mamba/abstract.py
@abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
    """
    Defines the shape of the state.
    For mamba layers this is usually a (conv_state, ssm_state) tuple.
    In this case, returns (conv_state_shape, ssm_state_shape).
    """
    pass