vllm.attention.ops.flashmla
flash_mla_with_kvcache ¶
flash_mla_with_kvcache(
q: Tensor,
k_cache: Tensor,
block_table: Tensor,
cache_seqlens: Tensor,
head_dim_v: int,
tile_scheduler_metadata: Tensor,
num_splits: Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[Tensor] = None,
descale_k: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Tensor | (batch_size, seq_len_q, num_heads_q, head_dim). | required |
k_cache | Tensor | (num_blocks, page_block_size, num_heads_k, head_dim). | required |
block_table | Tensor | (batch_size, max_num_blocks_per_seq), torch.int32. | required |
cache_seqlens | Tensor | (batch_size), torch.int32. | required |
head_dim_v | int | Head dimension of v. | required |
tile_scheduler_metadata | Tensor | (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. | required |
num_splits | Tensor | (batch_size + 1), torch.int32, returned by get_mla_metadata. | required |
softmax_scale | Optional[float] | The scale applied to QK^T before the softmax. Defaults to 1 / sqrt(head_dim). | None |
causal | bool | Whether to apply a causal attention mask. | False |
descale_q | Optional[Tensor] | (batch_size), torch.float32. Descaling factors for Q. | None |
descale_k | Optional[Tensor] | (batch_size), torch.float32. Descaling factors for K. | None |
Return
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
Source code in vllm/attention/ops/flashmla.py
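A minimal decode-time sketch of how this kernel is driven; the batch size, head counts, head dimensions, and cache layout below are illustrative assumptions rather than values taken from vLLM, and the scheduling tensors come from get_mla_metadata (documented below).

```python
import torch
from vllm.attention.ops.flashmla import (
    flash_mla_with_kvcache,
    get_mla_metadata,
)

# Illustrative sizes only (not taken from the vLLM source).
batch_size, seq_len_q = 4, 1          # decode: one query token per sequence
num_heads_q, num_heads_k = 16, 1      # MLA typically uses a single latent KV head
head_dim, head_dim_v = 576, 512       # assumed QK / V head dims
num_blocks, page_block_size = 64, 64
max_num_blocks_per_seq = num_blocks // batch_size

q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                      dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(
    batch_size * max_num_blocks_per_seq, dtype=torch.int32, device="cuda"
).view(batch_size, max_num_blocks_per_seq)
cache_seqlens = torch.full((batch_size,), 512, dtype=torch.int32, device="cuda")

# The scheduling metadata depends only on the cached sequence lengths and the
# head counts, so it can be computed once per batch and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k
)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits, causal=True,
)
# out:         (batch_size, seq_len_q, num_heads_q, head_dim_v)
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32
```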
get_mla_metadata ¶
get_mla_metadata(
cache_seqlens: Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[Tensor, Tensor]
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cache_seqlens | Tensor | (batch_size), dtype torch.int32. | required |
num_heads_per_head_k | int | Equal to seq_len_q * num_heads_q // num_heads_k. | required |
num_heads_k | int | The number of key/value heads. | required |
Return
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
Source code in vllm/attention/ops/flashmla.py
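A short sketch of the metadata call on its own; the sequence lengths and head counts are made-up values, chosen to show how num_heads_per_head_k is derived from the query and key head counts.

```python
import torch
from vllm.attention.ops.flashmla import get_mla_metadata

cache_seqlens = torch.tensor([1024, 512, 2048], dtype=torch.int32, device="cuda")
seq_len_q, num_heads_q, num_heads_k = 1, 16, 1   # assumed decode-style shapes

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,  # num_heads_per_head_k
    num_heads_k,
)
# tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32
# num_splits:              (batch_size + 1,), torch.int32
```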
is_flashmla_supported ¶
is_flashmla_supported() -> Tuple[bool, Optional[str]]
Return
is_supported_flag: whether FlashMLA is supported on the current device.
unsupported_reason: the reason FlashMLA is unsupported, or None if it is supported.
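A capability-check sketch; only the is_flashmla_supported call comes from this module, and the fallback branch is a placeholder for whatever backend selection the caller does.

```python
from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # Fall back to another attention backend; `reason` describes why
    # FlashMLA cannot be used on this platform/build.
    print(f"FlashMLA unavailable: {reason}")
```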