
vllm.v1.attention.backends.short_conv_attn

ShortConvAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/short_conv_attn.py
class ShortConvAttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
        return ShortConvAttentionMetadataBuilder

get_builder_cls staticmethod

get_builder_cls() -> type[
    ShortConvAttentionMetadataBuilder
]
Source code in vllm/v1/attention/backends/short_conv_attn.py
@staticmethod
def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
    return ShortConvAttentionMetadataBuilder
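
A minimal usage sketch: a model runner resolves the builder class from the backend and would then construct it with the arguments described under ShortConvAttentionMetadataBuilder.__init__ below. The constructor arguments in the commented line are placeholders, not real vLLM objects.

# Hedged sketch: resolve the metadata builder from the backend. The builder
# arguments (a MambaSpec, layer names, VllmConfig, device) are placeholders
# a real model runner would supply.
builder_cls = ShortConvAttentionBackend.get_builder_cls()
assert builder_cls is ShortConvAttentionMetadataBuilder
# builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)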

ShortConvAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/short_conv_attn.py
@dataclass
class ShortConvAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor
    has_initial_states: torch.Tensor
    state_indices_tensor: torch.Tensor  # shape: [batch,]

    # For causal_conv1d
    nums_dict: Optional[dict] = None
    cu_seqlen: Optional[int] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None

batch_ptr class-attribute instance-attribute

batch_ptr: Optional[Tensor] = None

cu_seqlen class-attribute instance-attribute

cu_seqlen: Optional[int] = None

has_initial_states instance-attribute

has_initial_states: Tensor

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

nums_dict class-attribute instance-attribute

nums_dict: Optional[dict] = None

query_start_loc instance-attribute

query_start_loc: Tensor

state_indices_tensor instance-attribute

state_indices_tensor: Tensor

token_chunk_offset_ptr class-attribute instance-attribute

token_chunk_offset_ptr: Optional[Tensor] = None

__init__

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    query_start_loc: Tensor,
    has_initial_states: Tensor,
    state_indices_tensor: Tensor,
    nums_dict: Optional[dict] = None,
    cu_seqlen: Optional[int] = None,
    batch_ptr: Optional[Tensor] = None,
    token_chunk_offset_ptr: Optional[Tensor] = None,
) -> None
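
For illustration, a hand-built instance under assumed values: the builder below relies on decode requests being ordered before prefills, so a batch of one single-token decode followed by a four-token prefill chunk looks like this. The state slot indices are arbitrary.

import torch

# Hypothetical batch: one decode request (1 token) ordered before one
# prefill request (4 tokens). query_start_loc holds num_reqs + 1 cumulative
# token offsets; has_initial_states covers only the prefill requests;
# state_indices_tensor maps each request to its conv-state cache slot.
metadata = ShortConvAttentionMetadata(
    num_prefills=1,
    num_prefill_tokens=4,
    num_decodes=1,
    num_decode_tokens=1,
    query_start_loc=torch.tensor([0, 1, 5]),
    has_initial_states=torch.tensor([False]),
    state_indices_tensor=torch.tensor([3, 7]),
)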

ShortConvAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[ShortConvAttentionMetadata]

Source code in vllm/v1/attention/backends/short_conv_attn.py
class ShortConvAttentionMetadataBuilder(
        AttentionMetadataBuilder[ShortConvAttentionMetadata]):

    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> ShortConvAttentionMetadata:
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=1))
        has_initial_states = None
        if num_prefills > 0:
            # shape: [batch,]
            has_initial_states_cpu = (
                common_attn_metadata.
                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

        attn_metadata = ShortConvAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

reorder_batch_threshold class-attribute

reorder_batch_threshold: int = 1

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/short_conv_attn.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    assert isinstance(kv_cache_spec, MambaSpec)
    self.kv_cache_spec = kv_cache_spec

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> ShortConvAttentionMetadata
Source code in vllm/v1/attention/backends/short_conv_attn.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> ShortConvAttentionMetadata:
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc

    state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(common_attn_metadata,
                                   decode_threshold=1))
    has_initial_states = None
    if num_prefills > 0:
        # shape: [batch,]
        has_initial_states_cpu = (
            common_attn_metadata.
            num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
        has_initial_states = has_initial_states_cpu.to(
            query_start_loc.device)

    attn_metadata = ShortConvAttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        query_start_loc=query_start_loc,
        has_initial_states=has_initial_states,
        state_indices_tensor=state_indices_tensor,
    )
    return attn_metadata
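
To make the control flow concrete, here is a toy walk-through of the has_initial_states computation, under the same assumption build() relies on: the batch has been reordered (reorder_batch_threshold = 1, matching decode_threshold=1 above) so decode requests precede prefills, and a prefill request carries initial conv state exactly when it already has computed tokens from an earlier chunk. The tensor values are hypothetical.

import torch

# Toy batch, decodes first: two 1-token decodes, then one 3-token prefill
# that resumes after 8 previously computed tokens.
query_start_loc = torch.tensor([0, 1, 2, 5])
num_computed_tokens_cpu = torch.tensor([10, 10, 8])

num_reqs = query_start_loc.numel() - 1        # 3
num_decodes, num_prefills = 2, 1              # what split_decodes_and_prefills
num_decode_tokens, num_prefill_tokens = 2, 3  # would return for this batch

# Only the trailing prefill slice is inspected; a nonzero computed-token
# count means the conv state carries over from a previous chunk.
has_initial_states = num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0
print(has_initial_states)  # tensor([True])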