Skip to content

vllm.v1.attention.backends.mamba_attn

M module-attribute

# Generic parameter for the concrete attention-metadata type a builder
# produces (build_for_cudagraph_capture is annotated to return M).
M = TypeVar('M')

BaseMambaAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[M], ABC

Source code in vllm/v1/attention/backends/mamba_attn.py
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    """Common base for Mamba attention-metadata builders.

    Stores the builder configuration and preallocates an int32
    ``state_indices_tensor`` sized for the largest decode batch that can be
    captured in a full CUDA graph. ``build`` is not defined in this class;
    NOTE(review): presumably it is declared on ``AttentionMetadataBuilder``
    and implemented by concrete Mamba backends — confirm.
    """

    # NOTE(review): semantics are defined by AttentionMetadataBuilder (not
    # shown here); presumably requests with <= 1 query token count as decodes.
    reorder_batch_threshold: ClassVar[int] = 1
    # Full CUDA graphs are only supported for uniform single-token decode
    # batches; build_for_cudagraph_capture enforces this at capture time.
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialize the builder.

        Args:
            kv_cache_spec: Cache spec for this builder; must be a
                ``MambaSpec``.
            layer_names: Layer names associated with this builder (stored
                as-is).
            vllm_config: Global vLLM configuration; its scheduler and
                compilation sub-configs size ``state_indices_tensor``.
            device: Device on which ``state_indices_tensor`` is allocated.

        Raises:
            AssertionError: If ``kv_cache_spec`` is not a ``MambaSpec``
                (only when assertions are enabled).
        """
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names

        self.compilation_config = vllm_config.compilation_config
        # A captured decode batch can hold no more requests than the
        # scheduler allows and no more than the largest capture size.
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_capture_size)
        # Uninitialized buffer allocated once per builder.
        # NOTE(review): presumably holds one Mamba state index per request so
        # captured graphs see a stable address — confirm in subclasses.
        self.state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """Build metadata for full CUDA graph capture.

        Only decode-only batches are supported for full CUDA graphs with
        Mamba: the token count must equal the request count. On success,
        ``common_attn_metadata.max_query_len`` is set to 1 *in place* before
        delegating to ``self.build``.

        Args:
            common_attn_metadata: Batch metadata for the capture run.

        Returns:
            The result of ``self.build(0, common_attn_metadata)``.

        Raises:
            AssertionError: If ``num_reqs != num_actual_tokens`` (typically
                a capture size larger than ``max_num_seqs``).
        """
        m = common_attn_metadata

        # Explicit raise instead of `assert` so the guard survives `python -O`;
        # the exception type is kept so existing handlers still match.
        if m.num_reqs != m.num_actual_tokens:
            raise AssertionError(
                "Mamba only supports decode-only full CUDAGraph capture. "
                "Make sure all cudagraph capture sizes <= max_num_seqs.")

        m.max_query_len = 1  # decode-only: one query token per request

        # NOTE(review): the leading 0 is presumably common_prefix_len.
        return self.build(0, m)

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = UNIFORM_SINGLE_TOKEN_DECODE

decode_cudagraph_max_bs instance-attribute

decode_cudagraph_max_bs = min(
    max_num_seqs, max_capture_size
)

device instance-attribute

device = device

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

layer_names instance-attribute

layer_names = layer_names

reorder_batch_threshold class-attribute

reorder_batch_threshold: int = 1

state_indices_tensor instance-attribute

state_indices_tensor = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mamba_attn.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    """Record the builder's configuration and preallocate decode indices.

    Args:
        kv_cache_spec: Cache spec for this builder; must be a ``MambaSpec``.
        layer_names: Layer names associated with this builder (stored as-is).
        vllm_config: Global vLLM configuration.
        device: Device that receives the preallocated index buffer.

    Raises:
        AssertionError: If ``kv_cache_spec`` is not a ``MambaSpec`` (only
            when assertions are enabled).
    """
    assert isinstance(kv_cache_spec, MambaSpec)

    # Plain configuration passthroughs.
    self.kv_cache_spec = kv_cache_spec
    self.device = device
    self.vllm_config = vllm_config
    self.layer_names = layer_names

    compilation_cfg = vllm_config.compilation_config
    self.compilation_config = compilation_cfg

    # Largest decode batch a full CUDA graph may hold: bounded by both the
    # scheduler's request limit and the largest capture size.
    seq_limit = self.vllm_config.scheduler_config.max_num_seqs
    capture_limit = compilation_cfg.max_capture_size
    self.decode_cudagraph_max_bs = min(seq_limit, capture_limit)

    # Uninitialized int32 buffer with one slot per possible decode request.
    buffer_shape = (self.decode_cudagraph_max_bs, )
    self.state_indices_tensor = torch.empty(buffer_shape,
                                            dtype=torch.int32,
                                            device=device)

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
) -> M

Builds the attention metadata for full CUDA graph capture. Currently, Mamba supports full CUDA graphs only for decode-only batches, where every request contributes exactly one token.

Source code in vllm/v1/attention/backends/mamba_attn.py
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata) -> M:
    """Build metadata for full CUDA graph capture.

    Only decode-only batches are supported for full CUDA graphs with Mamba:
    the token count must equal the request count. On success,
    ``common_attn_metadata.max_query_len`` is set to 1 *in place* before
    delegating to ``self.build``.

    Args:
        common_attn_metadata: Batch metadata for the capture run.

    Returns:
        The result of ``self.build(0, common_attn_metadata)``.

    Raises:
        AssertionError: If ``num_reqs != num_actual_tokens`` (typically a
            capture size larger than ``max_num_seqs``).
    """
    m = common_attn_metadata

    # Explicit raise instead of `assert` so the guard survives `python -O`;
    # the exception type is kept so existing handlers still match.
    if m.num_reqs != m.num_actual_tokens:
        raise AssertionError(
            "Mamba only supports decode-only full CUDAGraph capture. "
            "Make sure all cudagraph capture sizes <= max_num_seqs.")

    m.max_query_len = 1  # decode-only: one query token per request

    # NOTE(review): the leading 0 is presumably common_prefix_len.
    return self.build(0, m)