vllm.v1.attention.backends.mamba1_attn

Mamba1AttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mamba1_attn.py
class Mamba1AttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
        return Mamba1AttentionMetadataBuilder

get_builder_cls staticmethod

get_builder_cls() -> type[Mamba1AttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/mamba1_attn.py
@staticmethod
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
    return Mamba1AttentionMetadataBuilder
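
The static get_builder_cls hook lets callers resolve the metadata builder class without instantiating the backend. A minimal sketch, assuming only that vllm is importable (builder construction arguments are not shown on this page):

from vllm.v1.attention.backends.mamba1_attn import (
    Mamba1AttentionBackend,
    Mamba1AttentionMetadataBuilder,
)

# Resolve the builder class from the backend; no backend instance is needed.
builder_cls = Mamba1AttentionBackend.get_builder_cls()
assert builder_cls is Mamba1AttentionMetadataBuilder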

Mamba1AttentionMetadata dataclass

Source code in vllm/v1/attention/backends/mamba1_attn.py
@dataclass
class Mamba1AttentionMetadata:
    query_start_loc: torch.Tensor
    context_lens_tensor: torch.Tensor
    state_indices_tensor: torch.Tensor
    has_initial_states: Optional[torch.Tensor]
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_padded_decodes: int

context_lens_tensor instance-attribute

context_lens_tensor: Tensor

has_initial_states instance-attribute

has_initial_states: Optional[Tensor]

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_padded_decodes instance-attribute

num_padded_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

query_start_loc instance-attribute

query_start_loc: Tensor

state_indices_tensor instance-attribute

state_indices_tensor: Tensor

__init__

__init__(
    query_start_loc: Tensor,
    context_lens_tensor: Tensor,
    state_indices_tensor: Tensor,
    has_initial_states: Optional[Tensor],
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    num_padded_decodes: int,
) -> None
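
A minimal sketch of constructing the dataclass by hand for a decode-only batch of two single-token requests. The tensor values are illustrative assumptions; in practice Mamba1AttentionMetadataBuilder.build (below) produces this object:

import torch

from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata

# Two decode requests, one new token each, so query offsets are [0, 1, 2].
meta = Mamba1AttentionMetadata(
    query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32),
    context_lens_tensor=torch.tensor([7, 3], dtype=torch.int32),   # tokens already computed
    state_indices_tensor=torch.tensor([4, 9], dtype=torch.int32),  # per-request state slots
    has_initial_states=None,  # only populated when the batch contains prefills
    num_prefills=0,
    num_prefill_tokens=0,
    num_decodes=2,
    num_decode_tokens=2,
    num_padded_decodes=2,
)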

Mamba1AttentionMetadataBuilder

Bases: BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]

Source code in vllm/v1/attention/backends/mamba1_attn.py
class Mamba1AttentionMetadataBuilder(
        BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba1AttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
        context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
            query_start_loc.device)

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=1))

        has_initial_states = None
        padded_decodes = num_decodes

        if num_prefills > 0:
            has_initial_states = context_lens_tensor > 0
        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
              and self.compilation_config.full_cuda_graph):
            state_indices_for_decode = state_indices_tensor[:num_decodes]
            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(
                state_indices_for_decode, non_blocking=True)
            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_padded_decodes=padded_decodes,
        )

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> Mamba1AttentionMetadata
Source code in vllm/v1/attention/backends/mamba1_attn.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> Mamba1AttentionMetadata:
    query_start_loc = common_attn_metadata.query_start_loc

    state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
    context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
        query_start_loc.device)

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(common_attn_metadata,
                                   decode_threshold=1))

    has_initial_states = None
    padded_decodes = num_decodes

    if num_prefills > 0:
        has_initial_states = context_lens_tensor > 0
    elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
          and self.compilation_config.full_cuda_graph):
        state_indices_for_decode = state_indices_tensor[:num_decodes]
        padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
        self.state_indices_tensor[:num_decodes].copy_(
            state_indices_for_decode, non_blocking=True)
        state_indices_tensor = self.state_indices_tensor[:padded_decodes]
        state_indices_tensor[num_decodes:] = PAD_SLOT_ID

    return Mamba1AttentionMetadata(
        query_start_loc=query_start_loc,
        context_lens_tensor=context_lens_tensor,
        has_initial_states=has_initial_states,
        state_indices_tensor=state_indices_tensor,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_padded_decodes=padded_decodes,
    )
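
build splits the scheduled batch into prefills and decodes (decode_threshold=1). When any prefills are present, has_initial_states is derived from the context lengths. When the batch is decode-only, fits within decode_cudagraph_max_bs, and full CUDA graph capture is enabled, the per-request state indices are copied into the builder's persistent buffer, the batch is padded up to the CUDA-graph size, and the padding slots are filled with PAD_SLOT_ID. A minimal sketch of just that padding step, using plain tensors in place of the builder's persistent buffer; the padded size and sentinel value are illustrative assumptions, not taken from vllm:

import torch

PAD_SLOT_ID = -1   # assumed sentinel for this sketch; vllm defines its own constant
num_decodes = 3
padded_size = 4    # stands in for vllm_config.pad_for_cudagraph(num_decodes)

state_indices = torch.tensor([12, 7, 30], dtype=torch.int32)
buffer = torch.empty(8, dtype=torch.int32)  # persistent decode-time buffer

# Copy the live state indices, then mask the padded tail so the CUDA graph
# replays over a fixed batch size without touching real state slots.
buffer[:num_decodes].copy_(state_indices)
padded = buffer[:padded_size]
padded[num_decodes:] = PAD_SLOT_ID
# padded -> tensor([12,  7, 30, -1], dtype=torch.int32)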