vllm.v1.attention.backends.mla.flashmla
FlashMLABackend
Bases: MLACommonBackend
Source code in vllm/v1/attention/backends/mla/flashmla.py
get_builder_cls staticmethod
get_builder_cls() -> type[FlashMLAMetadataBuilder]
get_impl_cls staticmethod
get_impl_cls() -> type[FlashMLAImpl]
get_metadata_cls staticmethod
get_metadata_cls() -> type[FlashMLAMetadata]
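These three staticmethods form a small factory pattern: the backend class only names which builder, impl, and metadata classes belong together, so callers never hard-code the concrete types. A minimal sketch of that wiring (the class bodies here are hypothetical stand-ins, not vLLM's implementations):

```python
# Hypothetical stand-ins for the real vLLM classes; only the
# staticmethod-returns-a-class wiring mirrors the backend above.
class FlashMLAMetadata:            # per-batch attention metadata
    pass

class FlashMLAMetadataBuilder:     # builds FlashMLAMetadata each step
    pass

class FlashMLAImpl:                # runs the actual attention kernel
    pass

class FlashMLABackend:
    """Ties the three pieces together behind one name."""

    @staticmethod
    def get_builder_cls() -> type:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type:
        return FlashMLAImpl

    @staticmethod
    def get_metadata_cls() -> type:
        return FlashMLAMetadata

# A scheduler can then instantiate whatever the backend advertises:
builder_cls = FlashMLABackend.get_builder_cls()
print(builder_cls.__name__)
```

This indirection is what lets vLLM swap attention backends without touching the code that builds metadata or calls the kernel.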
FlashMLADecodeMetadata dataclass
Bases: MLACommonDecodeMetadata
FlashMLAImpl
Bases: MLACommonImpl[FlashMLAMetadata]
__init__
__init__(
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**mla_args,
) -> None
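Several of these parameters exist to satisfy the shared attention-impl interface rather than FlashMLA itself, and MLA-style backends typically validate up front that unsupported features are left unset. A hedged sketch of that validation pattern (the checks and messages are illustrative, not vLLM's exact error handling):

```python
from typing import Optional

class ToyMLAImpl:
    """Illustrative constructor mirroring part of the signature above."""

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 num_kv_heads: int, alibi_slopes: Optional[list[float]],
                 sliding_window: Optional[int], kv_cache_dtype: str) -> None:
        # MLA-style kernels generally do not support these features, so
        # reject them explicitly instead of silently ignoring them.
        if alibi_slopes is not None:
            raise NotImplementedError("ALiBi is not supported here")
        if sliding_window is not None:
            raise NotImplementedError("sliding window is not supported here")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

impl = ToyMLAImpl(16, 576, 0.13, 1, None, None, "auto")
print(impl.num_heads)
```

Failing fast here keeps misconfiguration errors at model-construction time rather than deep inside a kernel call.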
_forward_decode
_forward_decode(
q_nope: Tensor,
q_pe: Tensor,
kv_c_and_k_pe_cache: Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> Tensor
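The signature reflects MLA's decode-time split: the query is divided into a no-position part (q_nope, matched against the compressed KV latent cache) and a rotary part (q_pe, matched against the cached k_pe), and the value side reuses the same latent cache. A toy NumPy sketch of that score/output structure for a single request and head (shapes and names are illustrative, not the FlashMLA kernel):

```python
import numpy as np

def toy_mla_decode(q_nope, q_pe, kv_c, k_pe, scale):
    """Toy MLA decode attention for one new query token.

    q_nope: (d_latent,)   content query, matched against the latent cache
    q_pe:   (d_rope,)     rotary query, matched against cached k_pe
    kv_c:   (T, d_latent) compressed KV latent cache (doubles as values)
    k_pe:   (T, d_rope)   cached rotary key part
    """
    # Attention logits are the sum of the two dot products.
    logits = (kv_c @ q_nope + k_pe @ q_pe) * scale
    # Numerically stable softmax over the T cached positions.
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    # The latent cache also serves as the value tensor.
    return weights @ kv_c  # shape (d_latent,)

rng = np.random.default_rng(0)
out = toy_mla_decode(rng.normal(size=8), rng.normal(size=4),
                     rng.normal(size=(5, 8)), rng.normal(size=(5, 4)),
                     scale=0.125)
print(out.shape)
```

Because the cache stores only the latent vectors plus a small rotary slice, decode reads far less KV memory than a standard attention cache; the real kernel additionally handles paging, heads, and batching.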
FlashMLAMetadata dataclass
Bases: MLACommonMetadata[FlashMLADecodeMetadata]
__init__
__init__(
num_reqs: int,
max_query_len: int,
num_actual_tokens: int,
query_start_loc: Tensor,
slot_mapping: Tensor,
num_decodes: int,
num_decode_tokens: int,
num_prefills: int,
head_dim: Optional[int] = None,
decode: Optional[D] = None,
prefill: Optional[
Union[
MLACommonPrefillMetadata,
FlashInferPrefillMetadata,
CudnnPrefillMetadata,
]
] = None,
) -> None
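The counts in this signature reflect vLLM v1's convention of ordering decode requests (one new token each) ahead of prefill requests within a batch. Under that assumption the fields can all be derived from the per-request query lengths; a small pure-Python sketch of that derivation (illustrative only, not vLLM's actual builder code):

```python
from itertools import accumulate

def derive_counts(query_lens):
    """query_lens: new-token count per request, decodes (length 1) first."""
    # Inclusive prefix sums, with a leading 0, give each request's
    # start offset into the flat token batch plus the total length.
    query_start_loc = [0] + list(accumulate(query_lens))
    num_reqs = len(query_lens)
    # Assumes the decodes-first ordering: every length-1 request is a decode.
    num_decodes = sum(1 for n in query_lens if n == 1)
    num_prefills = num_reqs - num_decodes
    return {
        "num_reqs": num_reqs,
        "max_query_len": max(query_lens),
        "num_actual_tokens": query_start_loc[-1],
        "query_start_loc": query_start_loc,
        "num_decodes": num_decodes,
        "num_decode_tokens": num_decodes,  # one token per decode request
        "num_prefills": num_prefills,
    }

print(derive_counts([1, 1, 1, 7, 12]))
```

Splitting the batch this way lets the backend route decode tokens to the FlashMLA decode kernel while prefill tokens take the prefill path described by the `prefill` metadata field.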
FlashMLAMetadataBuilder
Bases: MLACommonMetadataBuilder[FlashMLAMetadata]
__init__
__init__(
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: device,
)
_build_decode
_build_decode(
block_table_tensor: Tensor, seq_lens: Tensor
) -> FlashMLADecodeMetadata
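_build_decode receives the decode-request slice of the block table plus per-request sequence lengths and packages them into FlashMLADecodeMetadata. A hypothetical sketch of the paged-KV bookkeeping such a builder performs, assuming a fixed block size (field names and layout are illustrative, not the real dataclass, and the actual builder also prepares kernel-specific scheduling state):

```python
from dataclasses import dataclass

@dataclass
class ToyDecodeMetadata:
    # Illustrative fields only; not FlashMLADecodeMetadata's real layout.
    block_table: list[list[int]]
    seq_lens: list[int]
    num_blocks_per_req: list[int]

def build_decode(block_table, seq_lens, block_size=64):
    # Each request's KV cache occupies ceil(seq_len / block_size) pages.
    num_blocks = [-(-s // block_size) for s in seq_lens]
    # Trim each padded block-table row to the pages actually in use.
    trimmed = [row[:n] for row, n in zip(block_table, num_blocks)]
    return ToyDecodeMetadata(trimmed, seq_lens, num_blocks)

md = build_decode([[3, 7, 9, 0], [5, 2, 0, 0]], seq_lens=[130, 40])
print(md.num_blocks_per_req)
```

Precomputing this per-step, per-request page bookkeeping once in the builder keeps the hot decode kernel free of such host-side logic.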