vllm.v1.attention.backends.linear_attn

LinearAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/linear_attn.py
class LinearAttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
        return LinearAttentionMetadataBuilder

get_builder_cls staticmethod

get_builder_cls() -> type[LinearAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/linear_attn.py
@staticmethod
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
    return LinearAttentionMetadataBuilder
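
A minimal usage sketch (not part of the module itself): the backend exposes its metadata builder class through get_builder_cls, which can be resolved like this.

from vllm.v1.attention.backends.linear_attn import (
    LinearAttentionBackend, LinearAttentionMetadataBuilder)

# The backend advertises which builder produces its attention metadata.
builder_cls = LinearAttentionBackend.get_builder_cls()
assert builder_cls is LinearAttentionMetadataBuilder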

LinearAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/linear_attn.py
@dataclass
class LinearAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    state_indices_tensor: torch.Tensor  # shape: [batch,]

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

query_start_loc instance-attribute

query_start_loc: Tensor

seq_lens instance-attribute

seq_lens: Tensor

state_indices_tensor instance-attribute

state_indices_tensor: Tensor

__init__

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    query_start_loc: Tensor,
    seq_lens: Tensor,
    state_indices_tensor: Tensor,
) -> None
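
A hand-rolled example (illustrative values only, not taken from vLLM): constructing the metadata for a hypothetical batch of two single-token decode requests followed by one 8-token prefill.

import torch

from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

# Two decodes (1 query token each) + one prefill (8 query tokens).
meta = LinearAttentionMetadata(
    num_prefills=1,
    num_prefill_tokens=8,
    num_decodes=2,
    num_decode_tokens=2,
    # Cumulative query-token offsets per request.
    query_start_loc=torch.tensor([0, 1, 2, 10], dtype=torch.int32),
    # Total sequence length (context + new tokens) per request.
    seq_lens=torch.tensor([17, 5, 8], dtype=torch.int32),
    # One state slot index per request, shape [batch,].
    state_indices_tensor=torch.tensor([3, 0, 7], dtype=torch.int32),
)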

LinearAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[LinearAttentionMetadata]

Source code in vllm/v1/attention/backends/linear_attn.py
class LinearAttentionMetadataBuilder(
        AttentionMetadataBuilder[LinearAttentionMetadata]):

    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> LinearAttentionMetadata:
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=1))

        attn_metadata = LinearAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

reorder_batch_threshold class-attribute

reorder_batch_threshold: int = 1

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/linear_attn.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    assert isinstance(kv_cache_spec, MambaSpec)
    self.kv_cache_spec = kv_cache_spec

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> LinearAttentionMetadata
Source code in vllm/v1/attention/backends/linear_attn.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> LinearAttentionMetadata:
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(common_attn_metadata,
                                   decode_threshold=1))

    attn_metadata = LinearAttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        state_indices_tensor=state_indices_tensor,
    )
    return attn_metadata
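
A standalone sketch of the key step in build (dummy tensors, assumed shapes): each request's linear-attention state occupies a single cache slot, so the builder keeps only column 0 of the block table as state_indices_tensor, and it splits the batch into decodes and prefills with decode_threshold=1 (requests with at most one query token count as decodes).

import torch

# Block table with one row per request; only the first column is used.
block_table_tensor = torch.tensor([[4, 0, 0],
                                   [9, 0, 0],
                                   [2, 0, 0]], dtype=torch.int32)
state_indices_tensor = block_table_tensor[:, 0]  # tensor([4, 9, 2], dtype=torch.int32)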