vllm.v1.attention.backends.flashinfer

Attention layer with FlashInfer.

FLASHINFER_WORKSPACE_BUFFER_SIZE module-attribute

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
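
For scale, 256 * 1024 * 1024 bytes is 256 MiB. As a minimal sketch (device chosen purely for illustration), this is the uint8 workspace that FlashInferMetadataBuilder._get_workspace_buffer() lazily allocates and shares across the prefill, decode, and cascade wrappers:

import torch

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024  # 256 MiB

# Mirrors FlashInferMetadataBuilder._get_workspace_buffer(): allocated once
# per builder and reused by every FlashInfer wrapper it creates.
workspace_buffer = torch.zeros(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                               dtype=torch.uint8,
                               device="cuda")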

FP4_DTYPE module-attribute

FP4_DTYPE = uint8

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

logger module-attribute

logger = init_logger(__name__)

FlashInferBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type[FlashInferImpl]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> type[FlashInferMetadata]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> type[FlashInferMetadataBuilder]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[FlashInferMetadataBuilder]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_builder_cls() -> type[FlashInferMetadataBuilder]:
    return FlashInferMetadataBuilder

get_fp8_dtype_for_flashinfer staticmethod

get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> dtype
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

get_impl_cls staticmethod

get_impl_cls() -> type[FlashInferImpl]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_impl_cls() -> type[FlashInferImpl]:
    return FlashInferImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    return (num_blocks, 2, block_size, num_kv_heads, head_size)

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order() -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    # `stride_order` indicates the permutation that gets us from
    # `get_kv_cache_shape` to the actual memory layout we want.
    cache_layout = get_kv_cache_layout()
    if cache_layout == "NHD":
        stride_order = (0, 1, 2, 3, 4)
    elif cache_layout == "HND":
        stride_order = (0, 1, 3, 2, 4)
    else:
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
    return stride_order
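
As a minimal sketch (sizes are hypothetical), the stride order is applied with permute() to map the logical shape from get_kv_cache_shape() onto the configured layout, which is what forward() does before calling the FlashInfer wrappers:

import torch

num_blocks, block_size, num_kv_heads, head_size = 16, 32, 8, 128

# Logical shape from FlashInferBackend.get_kv_cache_shape():
kv_cache = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

# With the HND layout, the stride order (0, 1, 3, 2, 4) swaps the
# block_size and num_kv_heads axes.
kv_cache_hnd = kv_cache.permute(0, 1, 3, 2, 4)
print(kv_cache_hnd.shape)  # torch.Size([16, 2, 8, 32, 128])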

get_metadata_cls staticmethod

get_metadata_cls() -> type[FlashInferMetadata]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_metadata_cls() -> type[FlashInferMetadata]:
    return FlashInferMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_name() -> str:
    return "FLASHINFER_VLLM_V1"

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    return [torch.float16, torch.bfloat16]

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
    return [64, 128, 256]

validate_head_size classmethod

validate_head_size(head_size: int) -> None
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def validate_head_size(cls, head_size: int) -> None:
    supported_head_sizes = cls.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

FlashInferImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[int] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.window_left = (self.sliding_window[0]
                            if self.sliding_window is not None else -1)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

        self.sinks: Optional[torch.Tensor] = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}.")
            self.sinks = sinks

        self.support_trtllm_attn = (supports_trtllm_attention()
                                    and num_heads % num_kv_heads == 0)
        self.bmm1_scale: Optional[float] = None
        self.bmm2_scale: Optional[float] = None
        self.o_sf_scale: Optional[float] = None

    def fused_output_quant_supported(self, quant_key: QuantKey):
        return (self.support_trtllm_attn
                and self.kv_cache_dtype.startswith("fp8")
                and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        if self.bmm1_scale is None:
            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
                               self.scale)

        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
            assert attn_metadata.q_data_type != FP8_DTYPE, \
                "Query can only be FP8 if output fusion happened."
            assert output_block_scale is None, "output_block_scale "\
                "is not supported when fusion has not happened"
        else:
            assert attn_metadata.q_data_type == FP8_DTYPE, \
                "Query must be FP8 when attn+quant fusion happened."
            assert (attn_metadata.prefill_use_trtllm and
                    attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"

            if output.dtype == FP8_DTYPE:
                assert output_block_scale is None, \
                    "output_block_scale should not be provided for fp8 output"
            elif output.dtype == FP4_DTYPE:
                assert output_block_scale is not None, \
                    "output_block_scale is required for nvfp4 output"
            else:
                raise ValueError(f"Unsupported output dtype: {output.dtype}")

            # The TRT-LLM attn kernel requires the o scale as a host scalar,
            # so store it as a host scalar during the warmup run, while CUDA
            # graphs are not enabled.
            if layer._o_scale_float is None:
                layer._o_scale_float = output_scale.cpu().item()
                if output.dtype == FP8_DTYPE:
                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
                elif output.dtype == FP4_DTYPE:
                    self.o_sf_scale = layer._o_scale_float

            # Insert FP8 quant for query
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if self.kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.kv_cache_dtype)
                kv_cache = kv_cache.view(torch_dtype)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        if num_prefill_tokens > 0:
            prefill_wrapper = attn_metadata.prefill_wrapper
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None

            if not attn_metadata.prefill_use_trtllm:
                assert prefill_wrapper._causal
                assert prefill_wrapper._window_left == self.window_left
                assert prefill_wrapper._logits_soft_cap == (
                    self.logits_soft_cap or 0.0)
                assert prefill_wrapper._sm_scale == self.scale
                prefill_wrapper.run(
                    prefill_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[num_decode_tokens:],
                )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
                workspace_buffer = prefill_wrapper._float_workspace_buffer
                block_tables_prefill = attn_metadata.block_table_tensor[
                    num_decode_tokens:]
                seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert prefill_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_prefill.is_contiguous()
                assert seq_lens_prefill.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(data=output[num_decode_tokens:],
                                    scale=output_block_scale,
                                    scale_start_index=num_decode_tokens,
                                    original_shape=prefill_query.shape)
                else:
                    assert self.o_sf_scale is None
                    out = output[num_decode_tokens:]

                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_prefill,
                    seq_lens=seq_lens_prefill,
                    max_q_len=attn_metadata.max_q_len,
                    max_kv_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    batch_size=attn_metadata.num_prefills,
                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                )

        if num_decode_tokens > 0:
            decode_wrapper = attn_metadata.decode_wrapper
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None

            if not attn_metadata.decode_use_trtllm:
                assert decode_wrapper._window_left == self.window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                           or 0.0)
                assert decode_wrapper._sm_scale == self.scale
                decode_wrapper.run(
                    decode_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                workspace_buffer = decode_wrapper._float_workspace_buffer
                block_tables_decode = attn_metadata.\
                        block_table_tensor[:num_decode_tokens]
                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert decode_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_decode.is_contiguous()
                assert seq_lens_decode.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(data=output[:num_decode_tokens],
                                    scale=output_block_scale,
                                    scale_start_index=0,
                                    original_shape=decode_query.shape)
                else:
                    assert self.o_sf_scale is None
                    out = output[:num_decode_tokens]

                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
                    max_seq_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                )
        return output_padded

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

bmm1_scale instance-attribute

bmm1_scale: Optional[float] = None

bmm2_scale instance-attribute

bmm2_scale: Optional[float] = None

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

o_sf_scale instance-attribute

o_sf_scale: Optional[float] = None

scale instance-attribute

scale = float(scale)

sinks instance-attribute

sinks: Optional[Tensor] = None

sliding_window instance-attribute

sliding_window = (-1, -1)

support_trtllm_attn instance-attribute

support_trtllm_attn = (
    supports_trtllm_attention()
    and num_heads % num_kv_heads == 0
)

window_left instance-attribute

window_left = (
    sliding_window[0] if sliding_window is not None else -1
)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = DECODER,
    kv_sharing_target_layer_name: Optional[int] = None,
    sinks: Optional[Tensor] = None,
) -> None
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[int] = None,
    sinks: Optional[torch.Tensor] = None,
) -> None:
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    self.alibi_slopes = alibi_slopes
    if sliding_window is None:
        self.sliding_window = (-1, -1)
    else:
        self.sliding_window = (sliding_window - 1, 0)
    self.window_left = (self.sliding_window[0]
                        if self.sliding_window is not None else -1)
    self.kv_cache_dtype = kv_cache_dtype
    self.logits_soft_cap = logits_soft_cap
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashInferImpl")

    self.sinks: Optional[torch.Tensor] = None
    if sinks is not None:
        if sinks.shape[0] != num_heads:
            raise ValueError(
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Expected {num_heads}, but got "
                f"{sinks.shape[0]}.")
        self.sinks = sinks

    self.support_trtllm_attn = (supports_trtllm_attention()
                                and num_heads % num_kv_heads == 0)
    self.bmm1_scale: Optional[float] = None
    self.bmm2_scale: Optional[float] = None
    self.o_sf_scale: Optional[float] = None

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashInferMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
    output_block_scale: Optional[Tensor] = None,
) -> Tensor

Forward pass with FlashInfer.

Parameters:

    query (Tensor, required):
        shape = [num_tokens, num_heads, head_size]
    key (Tensor, required):
        shape = [num_tokens, num_kv_heads, head_size]
    value (Tensor, required):
        shape = [num_tokens, num_kv_heads, head_size]
    kv_cache (Tensor, required): shape =
        NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
        HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
    attn_metadata (FlashInferMetadata, required):
        Metadata for attention.

Returns: shape = [num_tokens, num_heads * head_size]
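
For shape intuition only (hypothetical sizes; this is not the internal call path), the inputs are per-token, head-major tensors and the caller-provided output buffer uses the flattened head dimension:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 32, 8, 128

query = torch.empty(num_tokens, num_heads, head_size)
key = torch.empty(num_tokens, num_kv_heads, head_size)
value = torch.empty(num_tokens, num_kv_heads, head_size)

# accept_output_buffer = True: the caller allocates the output and
# forward() writes into it.
output = torch.empty(num_tokens, num_heads * head_size)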

Source code in vllm/v1/attention/backends/flashinfer.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashInferMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with FlashInfer.

    Args:
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
            HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    assert output is not None, "Output tensor must be provided."

    if attn_metadata is None:
        # Profiling run.
        return output

    if self.bmm1_scale is None:
        self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
                           self.scale)

    if self.bmm2_scale is None:
        self.bmm2_scale = layer._v_scale_float

    # The attn+quant fusion happens when output_scale is provided.
    if output_scale is None:
        assert attn_metadata.q_data_type != FP8_DTYPE, \
            "Query can only be FP8 if output fusion happened."
        assert output_block_scale is None, "output_block_scale "\
            "is not supported when fusion has not happened"
    else:
        assert attn_metadata.q_data_type == FP8_DTYPE, \
            "Query must be FP8 when attn+quant fusion happened."
        assert (attn_metadata.prefill_use_trtllm and
                attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"

        if output.dtype == FP8_DTYPE:
            assert output_block_scale is None, \
                "output_block_scale should not be provided for fp8 output"
        elif output.dtype == FP4_DTYPE:
            assert output_block_scale is not None, \
                "output_block_scale is required for nvfp4 output"
        else:
            raise ValueError(f"Unsupported output dtype: {output.dtype}")

        # The TRT-LLM attn kernel requires the o scale as a host scalar,
        # so store it as a host scalar during the warmup run, while CUDA
        # graphs are not enabled.
        if layer._o_scale_float is None:
            layer._o_scale_float = output_scale.cpu().item()
            if output.dtype == FP8_DTYPE:
                self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
            elif output.dtype == FP4_DTYPE:
                self.o_sf_scale = layer._o_scale_float

        # Insert FP8 quant for query
        num_tokens, num_heads, head_size = query.shape
        query, _ = ops.scaled_fp8_quant(
            query.reshape(
                (num_tokens, num_heads * head_size)).contiguous(),
            layer._q_scale)
        query = query.reshape((num_tokens, num_heads, head_size))

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens

    if self.kv_sharing_target_layer_name is None:
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            kv_cache[:, 0],
            kv_cache[:, 1],
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype)
            kv_cache = kv_cache.view(torch_dtype)

    # Inputs and outputs may be padded for CUDA graphs
    query = query[:num_actual_tokens]
    output_padded = output
    output = output[:num_actual_tokens]

    if attn_metadata.use_cascade:
        # Cascade attention (rare case).
        assert attn_metadata.cascade_wrapper is not None
        output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
        return output

    num_decode_tokens = attn_metadata.num_decode_tokens
    num_prefill_tokens = attn_metadata.num_prefill_tokens

    stride_order = FlashInferBackend.get_kv_cache_stride_order()
    kv_cache_permute = kv_cache.permute(*stride_order)
    # Regular attention (common case).
    # Decodes are at the front and prefills are at the back,
    # according to reorder_batch()
    if num_prefill_tokens > 0:
        prefill_wrapper = attn_metadata.prefill_wrapper
        prefill_query = query[num_decode_tokens:]
        assert prefill_query.shape[0] == num_prefill_tokens
        assert prefill_wrapper is not None

        if not attn_metadata.prefill_use_trtllm:
            assert prefill_wrapper._causal
            assert prefill_wrapper._window_left == self.window_left
            assert prefill_wrapper._logits_soft_cap == (
                self.logits_soft_cap or 0.0)
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
                kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
            )
        else:
            # prefill_query may be non-contiguous
            prefill_query = prefill_query.contiguous()
            workspace_buffer = prefill_wrapper._float_workspace_buffer
            block_tables_prefill = attn_metadata.block_table_tensor[
                num_decode_tokens:]
            seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:]

            # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
            assert get_kv_cache_layout() == "HND"
            assert prefill_query.is_contiguous()
            assert kv_cache_permute.is_contiguous()
            assert workspace_buffer.is_contiguous()
            assert block_tables_prefill.is_contiguous()
            assert seq_lens_prefill.is_contiguous()

            if output.dtype == FP4_DTYPE:
                assert self.o_sf_scale is not None
                out = FP4Tensor(data=output[num_decode_tokens:],
                                scale=output_block_scale,
                                scale_start_index=num_decode_tokens,
                                original_shape=prefill_query.shape)
            else:
                assert self.o_sf_scale is None
                out = output[num_decode_tokens:]

            trtllm_batch_context_with_kv_cache(
                query=prefill_query,
                kv_cache=kv_cache_permute,
                workspace_buffer=workspace_buffer,
                block_tables=block_tables_prefill,
                seq_lens=seq_lens_prefill,
                max_q_len=attn_metadata.max_q_len,
                max_kv_len=attn_metadata.max_seq_len,
                bmm1_scale=self.bmm1_scale,
                bmm2_scale=self.bmm2_scale,
                batch_size=attn_metadata.num_prefills,
                cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                window_left=self.window_left,
                sinks=self.sinks,
                o_sf_scale=self.o_sf_scale,
                out=out,
            )

    if num_decode_tokens > 0:
        decode_wrapper = attn_metadata.decode_wrapper
        decode_query = query[:num_decode_tokens]
        assert decode_query.shape[0] == num_decode_tokens
        assert decode_wrapper is not None

        if not attn_metadata.decode_use_trtllm:
            assert decode_wrapper._window_left == self.window_left
            assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                       or 0.0)
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
                kv_cache_permute,
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
            )
        else:
            # decode_query may be non-contiguous
            decode_query = decode_query.contiguous()
            workspace_buffer = decode_wrapper._float_workspace_buffer
            block_tables_decode = attn_metadata.\
                    block_table_tensor[:num_decode_tokens]
            seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

            # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
            assert get_kv_cache_layout() == "HND"
            assert decode_query.is_contiguous()
            assert kv_cache_permute.is_contiguous()
            assert workspace_buffer.is_contiguous()
            assert block_tables_decode.is_contiguous()
            assert seq_lens_decode.is_contiguous()

            if output.dtype == FP4_DTYPE:
                assert self.o_sf_scale is not None
                out = FP4Tensor(data=output[:num_decode_tokens],
                                scale=output_block_scale,
                                scale_start_index=0,
                                original_shape=decode_query.shape)
            else:
                assert self.o_sf_scale is None
                out = output[:num_decode_tokens]

            trtllm_batch_decode_with_kv_cache(
                query=decode_query,
                kv_cache=kv_cache_permute,
                workspace_buffer=workspace_buffer,
                block_tables=block_tables_decode,
                seq_lens=seq_lens_decode,
                max_seq_len=attn_metadata.max_seq_len,
                bmm1_scale=self.bmm1_scale,
                bmm2_scale=self.bmm2_scale,
                window_left=self.window_left,
                sinks=self.sinks,
                o_sf_scale=self.o_sf_scale,
                out=out,
            )
    return output_padded

fused_output_quant_supported

fused_output_quant_supported(quant_key: QuantKey)
Source code in vllm/v1/attention/backends/flashinfer.py
def fused_output_quant_supported(self, quant_key: QuantKey):
    return (self.support_trtllm_attn
            and self.kv_cache_dtype.startswith("fp8")
            and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))
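
A standalone restatement of the three conditions combined above (the values here are hypothetical stand-ins for supports_trtllm_attention(), the layer's KV cache dtype, and the quant_key membership test):

# Hypothetical inputs for illustration only.
support_trtllm_attn = True      # supports_trtllm_attention() and heads divisible
kv_cache_dtype = "fp8_e4m3"
quant_key_supported = True      # quant_key in (kFp8StaticTensorSym, kNvfp4Quant)

fusion_supported = (support_trtllm_attn
                    and kv_cache_dtype.startswith("fp8")
                    and quant_key_supported)
print(fusion_supported)  # True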

FlashInferMetadata dataclass

Source code in vllm/v1/attention/backends/flashinfer.py
@dataclass
class FlashInferMetadata:

    num_actual_tokens: int  # Number of tokens excluding padding.

    # The data type of the query
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # For flashinfer trtllm batch decode
    max_q_len: int
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    prefill_use_trtllm: bool
    decode_use_trtllm: bool

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention (CPU for planning).
    use_cascade: bool

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
    cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

    qo_indptr_gpu: Optional[torch.Tensor] = None
    paged_kv_indptr_gpu: Optional[torch.Tensor] = None

block_table_tensor instance-attribute

block_table_tensor: Tensor

cascade_wrapper class-attribute instance-attribute

cascade_wrapper: Optional[
    MultiLevelCascadeAttentionWrapper
] = None

decode_use_trtllm instance-attribute

decode_use_trtllm: bool

decode_wrapper class-attribute instance-attribute

decode_wrapper: Optional[
    BatchDecodeWithPagedKVCacheWrapper
] = None

max_q_len instance-attribute

max_q_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

paged_kv_indptr_gpu class-attribute instance-attribute

paged_kv_indptr_gpu: Optional[Tensor] = None

prefill_use_trtllm instance-attribute

prefill_use_trtllm: bool

prefill_wrapper class-attribute instance-attribute

prefill_wrapper: Optional[
    BatchPrefillWithPagedKVCacheWrapper
] = None

q_data_type instance-attribute

q_data_type: dtype

qo_indptr_gpu class-attribute instance-attribute

qo_indptr_gpu: Optional[Tensor] = None

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

use_cascade instance-attribute

use_cascade: bool

__init__

__init__(
    num_actual_tokens: int,
    q_data_type: dtype,
    slot_mapping: Tensor,
    max_q_len: int,
    max_seq_len: int,
    seq_lens: Tensor,
    block_table_tensor: Tensor,
    prefill_use_trtllm: bool,
    decode_use_trtllm: bool,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    num_prefill_tokens: int,
    use_cascade: bool,
    prefill_wrapper: Optional[
        BatchPrefillWithPagedKVCacheWrapper
    ] = None,
    decode_wrapper: Optional[
        BatchDecodeWithPagedKVCacheWrapper
    ] = None,
    cascade_wrapper: Optional[
        MultiLevelCascadeAttentionWrapper
    ] = None,
    qo_indptr_gpu: Optional[Tensor] = None,
    paged_kv_indptr_gpu: Optional[Tensor] = None,
) -> None
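
A minimal, decode-only construction sketch with hypothetical sizes; in practice FlashInferMetadataBuilder.build() creates this object and also attaches the planned prefill/decode/cascade wrappers:

import torch
from vllm.v1.attention.backends.flashinfer import FlashInferMetadata

meta = FlashInferMetadata(
    num_actual_tokens=2,  # two decode tokens, no CUDA-graph padding
    q_data_type=torch.bfloat16,
    slot_mapping=torch.tensor([5, 37], dtype=torch.int64),
    max_q_len=1,
    max_seq_len=40,
    seq_lens=torch.tensor([6, 40], dtype=torch.int32),
    block_table_tensor=torch.zeros(2, 4, dtype=torch.int32),
    prefill_use_trtllm=False,
    decode_use_trtllm=False,
    num_decodes=2,
    num_decode_tokens=2,
    num_prefills=0,
    num_prefill_tokens=0,
    use_cascade=False,
)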

FlashInferMetadataBuilder

Bases: AttentionMetadataBuilder[FlashInferMetadata]

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.device = device
        self.vllm_config = vllm_config
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        self.kv_cache_spec = kv_cache_spec
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                     self.kv_cache_spec.block_size)
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req
        self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
            decode_mode() == CUDAGraphMode.FULL
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper] = {}
            self._decode_cudagraph_max_bs = min(
                max_num_reqs, self.compilation_config.max_capture_size)

        self.num_qo_heads = self.model_config.get_num_attention_heads(
            self.vllm_config.parallel_config)
        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        FlashInferBackend.validate_head_size(self.head_dim)
        self.page_size = self.kv_cache_spec.block_size

        self.enable_fusion = (
            self.compilation_config.pass_config.enable_attn_fusion)
        self.q_data_type = self.model_config.dtype
        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = (
                FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.cache_dtype))
            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
            if self.enable_fusion:
                self.q_data_type = self.kv_cache_dtype
        else:
            self.kv_cache_dtype = self.kv_cache_spec.dtype

        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        # TODO: discard this for trtllm-gen backend
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

        # Preparing persistent buffers (device-side)
        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.paged_kv_indices = torch.zeros(
            max_num_pages,  # max num pages possible
            dtype=torch.int32,
            device=self.device)
        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                  dtype=torch.int32,
                                                  device=self.device)
        # host-side buffer
        pin_memory = is_pin_memory_available()
        self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=pin_memory)
        self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=pin_memory)
        self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
                                                      dtype=torch.int32,
                                                      device="cpu",
                                                      pin_memory=pin_memory)

        self.block_table_arange = torch.arange(max_num_pages_per_req,
                                               dtype=torch.int32,
                                               device=self.device)

    def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self,
                            batch_size: int,
                            use_cudagraph: bool = False):
        if use_cudagraph:
            decode_wrapper = self._decode_wrappers_cudagraph.get(
                batch_size, None)
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            if use_cudagraph:
                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
                paged_kv_indices = self.paged_kv_indices
                paged_kv_last_page_len = self.paged_kv_last_page_len[:
                                                                     batch_size]
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                # Tensor cores are enabled by default because their perf is
                # at least as good as CUDA cores for all attention ops on
                # recent GPUs.
                use_tensor_cores=True,
            )

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper

    def _get_cascade_wrapper(self):
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout())
        return self._cascade_wrapper

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashInferMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
            split_decodes_and_prefills(common_attn_metadata)

        page_size = self.page_size
        max_q_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor

        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size

            # Create CPU versions directly for cascade (no GPU versions needed)
            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device='cpu')
            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
                                                     dtype=torch.int32,
                                                     device='cpu')
            shared_kv_page_indices_cpu = block_table_tensor[
                0, :num_common_kv_blocks]
            shared_kv_last_page_len_cpu = torch.tensor([page_size],
                                                       dtype=torch.int32,
                                                       device='cpu')

            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            block_table_bounds_cpu -= num_common_kv_blocks
        else:
            shared_qo_indptr_cpu = None
            shared_kv_page_indptr_cpu = None
            shared_kv_page_indices_cpu = None
            shared_kv_last_page_len_cpu = None

        max_num_blocks = block_table_bounds_cpu.max().item()
        block_table_bounds = block_table_bounds_cpu.to(self.device,
                                                       non_blocking=True)
        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        # write self.paged_kv_indices inplace
        num_actual_pages = torch.sum(mask)
        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
        torch.masked_select(block_table_tensor[:, :max_num_blocks],
                            mask,
                            out=paged_kv_indices)

        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
        torch.cumsum(block_table_bounds_cpu,
                     dim=0,
                     dtype=torch.int32,
                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])

        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
        # write self.paged_kv_last_page_len_cpu inplace
        torch.where(paged_kv_last_page_len_cpu == 0,
                    torch.tensor(page_size),
                    paged_kv_last_page_len_cpu,
                    out=self.paged_kv_last_page_len_cpu[:num_reqs])

        # Check if any layer uses sinks (requires TRTLLM attention)
        has_sinks = self.global_hyperparameters.has_sinks

        prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_prefill_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=True,
                                                  has_sinks=has_sinks)
        decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                 self.num_kv_heads,
                                                 num_decode_tokens,
                                                 max_seq_len,
                                                 self.cache_dtype,
                                                 self.q_data_type,
                                                 is_prefill=False,
                                                 has_sinks=has_sinks)

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            q_data_type=self.q_data_type,
            slot_mapping=common_attn_metadata.slot_mapping,
            max_q_len=max_q_len,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table_tensor=block_table_tensor,
            prefill_use_trtllm=prefill_use_trtllm,
            decode_use_trtllm=decode_use_trtllm,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
        )

        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
                sm_scale=self.global_hyperparameters.sm_scale,
                window_left=self.global_hyperparameters.window_left,
                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back,
            # according to reorder_batch()
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert qo_indptr_cpu[prefill_start:].shape[
                    0] == num_prefills + 1
                assert paged_kv_indptr_cpu[prefill_start:].shape[
                    0] == num_prefills + 1
                assert paged_kv_last_page_len_cpu[prefill_start:].shape[
                    0] == num_prefills
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                    prefill_start]
                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
                if not attn_metadata.prefill_use_trtllm:
                    attn_metadata.prefill_wrapper.plan(
                        qo_indptr_cpu,
                        paged_kv_indptr_cpu,
                        paged_kv_indices,
                        paged_kv_last_page_len_cpu[prefill_start:],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        causal=True,
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                    )
                else:
                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                        self.device)

            if num_decodes > 0:
                pure_decode = num_prefills == 0
                # Padding may be required for cudagraph replay.
                use_cudagraph = (self.enable_cuda_graph and pure_decode and
                                 num_decodes <= self._decode_cudagraph_max_bs)
                if use_cudagraph:
                    num_input_tokens = (
                        self.vllm_config.pad_for_cudagraph(num_decodes))
                    # Carefully fill the padding region with reasonable values
                    # on the CPU.
                    # Make sure paged_kv_indptr_cpu is non-decreasing.
                    self.paged_kv_indptr_cpu[1 + num_decodes:1 +
                                             num_input_tokens].fill_(
                                                 paged_kv_indptr_cpu[-1])
                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
                    # This is because flashinfer treats 0 as a full page
                    # rather than an empty one.
                    self.paged_kv_last_page_len_cpu[
                        num_decodes:num_input_tokens].fill_(1)

                else:
                    num_input_tokens = num_decodes

                attn_metadata.decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph)
                if not attn_metadata.decode_use_trtllm:
                    # Use the persistent buffer with the padded length,
                    # instead of the chunked version at the same address
                    # in attn_metadata, when using cudagraph.
                    fast_plan_decode(
                        attn_metadata.decode_wrapper,
                        self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                        paged_kv_indices,
                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
                        seq_lens_cpu[:num_input_tokens],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
                        sm_scale=self.global_hyperparameters.sm_scale,
                        window_left=self.global_hyperparameters.window_left,
                        logits_soft_cap=self.global_hyperparameters.
                        logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                    )
        return attn_metadata

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with FlashInfer.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "FlashInfer only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        m.max_query_len = 1  # decode-only

        return self.build(0, m)

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        return use_cascade_attention(*args, **kwargs)

_cascade_wrapper instance-attribute

_cascade_wrapper = None

_decode_cudagraph_max_bs instance-attribute

_decode_cudagraph_max_bs = min(
    max_num_reqs, max_capture_size
)

_decode_wrapper instance-attribute

_decode_wrapper = None

_decode_wrappers_cudagraph instance-attribute

_decode_wrappers_cudagraph: dict[
    int, BatchDecodeWithPagedKVCacheWrapper
] = {}

_prefill_wrapper instance-attribute

_prefill_wrapper = None

_workspace_buffer instance-attribute

_workspace_buffer = None

block_table_arange instance-attribute

block_table_arange = arange(
    max_num_pages_per_req, dtype=int32, device=device
)

cache_config instance-attribute

cache_config = cache_config

cache_dtype instance-attribute

cache_dtype = cache_dtype

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_support class-attribute

device instance-attribute

device = device

enable_cuda_graph instance-attribute

enable_cuda_graph = decode_mode() == FULL

enable_fusion instance-attribute

enable_fusion = enable_attn_fusion

global_hyperparameters instance-attribute

global_hyperparameters = infer_global_hyperparameters(
    get_per_layer_parameters(
        vllm_config, layer_names, FlashInferImpl
    )
)

head_dim instance-attribute

head_dim = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = get_fp8_dtype_for_flashinfer(cache_dtype)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

model_config instance-attribute

model_config = model_config

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_qo_heads instance-attribute

num_qo_heads = get_num_attention_heads(parallel_config)

page_size instance-attribute

page_size = block_size

paged_kv_indices instance-attribute

paged_kv_indices = zeros(
    max_num_pages, dtype=int32, device=device
)

paged_kv_indices_cpu instance-attribute

paged_kv_indices_cpu = zeros(
    max_num_pages,
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

paged_kv_indptr instance-attribute

paged_kv_indptr = zeros(
    max_num_reqs + 1, dtype=int32, device=device
)

paged_kv_indptr_cpu instance-attribute

paged_kv_indptr_cpu = zeros(
    max_num_reqs + 1,
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

paged_kv_last_page_len instance-attribute

paged_kv_last_page_len = zeros(
    max_num_reqs, dtype=int32, device=device
)

paged_kv_last_page_len_cpu instance-attribute

paged_kv_last_page_len_cpu = zeros(
    max_num_reqs,
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

q_data_type instance-attribute

q_data_type = dtype

reorder_batch_threshold class-attribute

reorder_batch_threshold: int = 1
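
With a threshold of 1, requests that add at most one query token in a step are treated as decodes; reorder_batch() moves them to the front of the batch and build() splits the batch via split_decodes_and_prefills. The sketch below is illustrative only (it is not the actual helper) and assumes the batch has already been reordered, using hypothetical query_start_loc values.

# Illustrative sketch of the decode/prefill split implied by the threshold.
import torch

query_start_loc = torch.tensor([0, 1, 2, 3, 7, 15])           # 3 decodes, 2 prefills
query_lens = query_start_loc[1:] - query_start_loc[:-1]       # [1, 1, 1, 4, 8]
threshold = 1                                                  # reorder_batch_threshold

num_decodes = int((query_lens <= threshold).sum())             # 3
num_prefills = query_lens.numel() - num_decodes                # 2
num_decode_tokens = int(query_start_loc[num_decodes])          # 3
num_prefill_tokens = int(query_start_loc[-1]) - num_decode_tokens  # 12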

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    self.device = device
    self.vllm_config = vllm_config
    self.cache_config = vllm_config.cache_config
    self.model_config = vllm_config.model_config
    self.kv_cache_spec = kv_cache_spec
    self._workspace_buffer = None
    self._prefill_wrapper = None  # Wrapper for prefill/append
    self._decode_wrapper = None  # Wrapper for decode (general shape)

    self.compilation_config = vllm_config.compilation_config
    max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                 self.kv_cache_spec.block_size)
    max_num_reqs = vllm_config.scheduler_config.max_num_seqs
    max_num_pages = max_num_reqs * max_num_pages_per_req
    self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
        decode_mode() == CUDAGraphMode.FULL
    if self.enable_cuda_graph:
        # For full cudagraph capture, one `decode_wrapper` for each batch
        # size is needed for FlashInfer.
        self._decode_wrappers_cudagraph: dict[
            int, BatchDecodeWithPagedKVCacheWrapper] = {}
        self._decode_cudagraph_max_bs = min(
            max_num_reqs, self.compilation_config.max_capture_size)

    self.num_qo_heads = self.model_config.get_num_attention_heads(
        self.vllm_config.parallel_config)
    self.num_kv_heads = self.kv_cache_spec.num_kv_heads
    self.head_dim = self.kv_cache_spec.head_size
    FlashInferBackend.validate_head_size(self.head_dim)
    self.page_size = self.kv_cache_spec.block_size

    self.enable_fusion = (
        self.compilation_config.pass_config.enable_attn_fusion)
    self.q_data_type = self.model_config.dtype
    self.cache_dtype = self.cache_config.cache_dtype
    if self.cache_dtype.startswith("fp8"):
        self.kv_cache_dtype = (
            FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.cache_dtype))
        # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
        if self.enable_fusion:
            self.q_data_type = self.kv_cache_dtype
    else:
        self.kv_cache_dtype = self.kv_cache_spec.dtype

    self._cascade_wrapper = None  # Wrapper for cascade attention

    # Global hyperparameters shared by all attention layers
    # TODO: discard this for trtllm-gen backend
    self.global_hyperparameters = infer_global_hyperparameters(
        get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

    # Preparing persistent buffers (device-side)
    self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                       dtype=torch.int32,
                                       device=self.device)
    self.paged_kv_indices = torch.zeros(
        max_num_pages,  # max num pages possible
        dtype=torch.int32,
        device=self.device)
    self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                              dtype=torch.int32,
                                              device=self.device)
    # host-side buffer
    pin_memory = is_pin_memory_available()
    self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device="cpu",
                                           pin_memory=pin_memory)
    self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
    self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
                                                  dtype=torch.int32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)

    self.block_table_arange = torch.arange(max_num_pages_per_req,
                                           dtype=torch.int32,
                                           device=self.device)
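
The persistent paged-KV buffers above are sized from the scheduler and cache configuration: max_num_pages_per_req = cdiv(max_model_len, block_size) and max_num_pages = max_num_reqs * max_num_pages_per_req. A worked sizing example with illustrative numbers (not vLLM defaults):

max_model_len = 8192        # model_config.max_model_len
block_size = 16             # kv_cache_spec.block_size
max_num_reqs = 256          # scheduler_config.max_num_seqs

max_num_pages_per_req = -(-max_model_len // block_size)    # cdiv -> 512
max_num_pages = max_num_reqs * max_num_pages_per_req       # 131072

# paged_kv_indices / paged_kv_indices_cpu store int32 page ids, so each
# buffer takes max_num_pages * 4 bytes = 512 KiB with these numbers.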

_get_cascade_wrapper

_get_cascade_wrapper()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_cascade_wrapper(self):
    if self._cascade_wrapper is None:
        self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
            2, self._get_workspace_buffer(), get_kv_cache_layout())
    return self._cascade_wrapper

_get_decode_wrapper

_get_decode_wrapper(
    batch_size: int, use_cudagraph: bool = False
)
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_decode_wrapper(self,
                        batch_size: int,
                        use_cudagraph: bool = False):
    if use_cudagraph:
        decode_wrapper = self._decode_wrappers_cudagraph.get(
            batch_size, None)
    else:
        decode_wrapper = self._decode_wrapper

    if decode_wrapper is None:
        if use_cudagraph:
            paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
            paged_kv_indices = self.paged_kv_indices
            paged_kv_last_page_len = self.paged_kv_last_page_len[:
                                                                 batch_size]
        else:
            paged_kv_indptr = None
            paged_kv_indices = None
            paged_kv_last_page_len = None
        decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self._get_workspace_buffer(),
            get_kv_cache_layout(),
            use_cuda_graph=use_cudagraph,
            paged_kv_indptr_buffer=paged_kv_indptr,
            paged_kv_indices_buffer=paged_kv_indices,
            paged_kv_last_page_len_buffer=paged_kv_last_page_len,
            # Tensor cores are enabled by default because the performance is
            # at least as good as CUDA cores for all attention ops on recent
            # GPUs.
            use_tensor_cores=True,
        )

        # save the decode wrapper
        if use_cudagraph:
            self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
        else:
            self._decode_wrapper = decode_wrapper

    return decode_wrapper
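
With cudagraphs enabled, one decode wrapper is cached per captured batch size and is backed by fixed slices of the persistent device buffers, so graph replay always sees the same buffer addresses; otherwise a single lazily created wrapper is reused for any shape. A minimal sketch of that caching pattern (not vLLM code; make_wrapper is a placeholder):

_cudagraph_wrappers: dict[int, object] = {}
_general_wrapper = None

def get_wrapper(batch_size, use_cudagraph, make_wrapper):
    global _general_wrapper
    if use_cudagraph:
        # one wrapper per captured batch size, created on first use
        if batch_size not in _cudagraph_wrappers:
            _cudagraph_wrappers[batch_size] = make_wrapper(batch_size)
        return _cudagraph_wrappers[batch_size]
    # general path: a single wrapper handles any runtime shape
    if _general_wrapper is None:
        _general_wrapper = make_wrapper(None)
    return _general_wrapper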

_get_prefill_wrapper

_get_prefill_wrapper()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_prefill_wrapper(self):
    if self._prefill_wrapper is None:
        self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
            self._get_workspace_buffer(), get_kv_cache_layout())
    return self._prefill_wrapper

_get_workspace_buffer

_get_workspace_buffer()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_workspace_buffer(self):
    if self._workspace_buffer is None:
        self._workspace_buffer = torch.zeros(
            FLASHINFER_WORKSPACE_BUFFER_SIZE,
            dtype=torch.uint8,
            device=self.device)
    return self._workspace_buffer
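
The workspace tensor is allocated lazily on first use and shared by the prefill, decode, and cascade wrappers created by this builder. As a rough footprint check (a sketch, not vLLM code):

workspace_bytes = 256 * 1024 * 1024          # FLASHINFER_WORKSPACE_BUFFER_SIZE
print(workspace_bytes / (1 << 20))           # 256.0 MiB of uint8 per builder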

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashInferMetadata
Source code in vllm/v1/attention/backends/flashinfer.py
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> FlashInferMetadata:
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
        split_decodes_and_prefills(common_attn_metadata)

    page_size = self.page_size
    max_q_len = common_attn_metadata.max_query_len
    max_seq_len = common_attn_metadata.max_seq_len
    seq_lens = common_attn_metadata.seq_lens
    seq_lens_cpu = common_attn_metadata.seq_lens_cpu
    block_table_tensor = common_attn_metadata.block_table_tensor

    block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

    use_cascade = common_prefix_len > 0
    if use_cascade:
        # Grab the blocks of the shared prefix from the first request.
        assert common_prefix_len % page_size == 0
        num_common_kv_blocks = common_prefix_len // page_size

        # Create CPU versions directly for cascade (no GPU versions needed)
        shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
                                            dtype=torch.int32,
                                            device='cpu')
        shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
                                                 dtype=torch.int32,
                                                 device='cpu')
        shared_kv_page_indices_cpu = block_table_tensor[
            0, :num_common_kv_blocks]
        shared_kv_last_page_len_cpu = torch.tensor([page_size],
                                                   dtype=torch.int32,
                                                   device='cpu')

        # Remove the blocks of the shared prefix from all requests.
        block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
        block_table_bounds_cpu -= num_common_kv_blocks
    else:
        shared_qo_indptr_cpu = None
        shared_kv_page_indptr_cpu = None
        shared_kv_page_indices_cpu = None
        shared_kv_last_page_len_cpu = None

    max_num_blocks = block_table_bounds_cpu.max().item()
    block_table_bounds = block_table_bounds_cpu.to(self.device,
                                                   non_blocking=True)
    mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
            < block_table_bounds.unsqueeze(1))
    # write self.paged_kv_indices inplace
    num_actual_pages = torch.sum(mask)
    paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
    torch.masked_select(block_table_tensor[:, :max_num_blocks],
                        mask,
                        out=paged_kv_indices)

    # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
    torch.cumsum(block_table_bounds_cpu,
                 dim=0,
                 dtype=torch.int32,
                 out=self.paged_kv_indptr_cpu[1:1 + num_reqs])

    paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
    # write self.paged_kv_last_page_len_cpu inplace
    torch.where(paged_kv_last_page_len_cpu == 0,
                torch.tensor(page_size),
                paged_kv_last_page_len_cpu,
                out=self.paged_kv_last_page_len_cpu[:num_reqs])

    # Check if any layer uses sinks (requires TRTLLM attention)
    has_sinks = self.global_hyperparameters.has_sinks

    prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                              self.num_kv_heads,
                                              num_prefill_tokens,
                                              max_seq_len,
                                              self.cache_dtype,
                                              self.q_data_type,
                                              is_prefill=True,
                                              has_sinks=has_sinks)
    decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                             self.num_kv_heads,
                                             num_decode_tokens,
                                             max_seq_len,
                                             self.cache_dtype,
                                             self.q_data_type,
                                             is_prefill=False,
                                             has_sinks=has_sinks)

    attn_metadata = FlashInferMetadata(
        num_actual_tokens=num_actual_tokens,
        q_data_type=self.q_data_type,
        slot_mapping=common_attn_metadata.slot_mapping,
        max_q_len=max_q_len,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        block_table_tensor=block_table_tensor,
        prefill_use_trtllm=prefill_use_trtllm,
        decode_use_trtllm=decode_use_trtllm,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        use_cascade=use_cascade,
    )

    qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
    paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
    paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

    if attn_metadata.use_cascade:
        attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
        attn_metadata.cascade_wrapper.plan(
            [shared_qo_indptr_cpu, qo_indptr_cpu],
            [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
            [shared_kv_page_indices_cpu, paged_kv_indices],
            [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            self.page_size,
            causal=True,
            sm_scale=self.global_hyperparameters.sm_scale,
            window_left=self.global_hyperparameters.window_left,
            logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
            q_data_type=self.q_data_type,
            kv_data_type=self.kv_cache_dtype,
        )
    else:
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
        num_prefills = attn_metadata.num_prefills
        num_decodes = attn_metadata.num_decodes
        if num_prefills > 0:
            # Decodes are first so prefills start after the last decode
            prefill_start = num_decodes
            attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
            assert qo_indptr_cpu[prefill_start:].shape[
                0] == num_prefills + 1
            assert paged_kv_indptr_cpu[prefill_start:].shape[
                0] == num_prefills + 1
            assert paged_kv_last_page_len_cpu[prefill_start:].shape[
                0] == num_prefills
            # Since prefill_wrapper.run() will be called with
            # query[num_decode_tokens:] we need to adjust the qo_indptr
            # to be relative to the start of the prefill queries.
            qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                prefill_start]
            paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
            if not attn_metadata.prefill_use_trtllm:
                attn_metadata.prefill_wrapper.plan(
                    qo_indptr_cpu,
                    paged_kv_indptr_cpu,
                    paged_kv_indices,
                    paged_kv_last_page_len_cpu[prefill_start:],
                    self.num_qo_heads,
                    self.num_kv_heads,
                    self.head_dim,
                    self.page_size,
                    causal=True,
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.kv_cache_dtype,
                )
            else:
                attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device)
                attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                    self.device)

        if num_decodes > 0:
            pure_decode = num_prefills == 0
            # Padding may be required for cudagraph replay.
            use_cudagraph = (self.enable_cuda_graph and pure_decode and
                             num_decodes <= self._decode_cudagraph_max_bs)
            if use_cudagraph:
                num_input_tokens = (
                    self.vllm_config.pad_for_cudagraph(num_decodes))
                # Carefully fill the padding region with reasonable values
                # on the CPU.
                # Make sure paged_kv_indptr_cpu is non-decreasing.
                self.paged_kv_indptr_cpu[1 + num_decodes:1 +
                                         num_input_tokens].fill_(
                                             paged_kv_indptr_cpu[-1])
                # Fill the remaining paged_kv_last_page_len_cpu with 1.
                # This is because flashinfer treats 0 as a full page
                # rather than an empty one.
                self.paged_kv_last_page_len_cpu[
                    num_decodes:num_input_tokens].fill_(1)

            else:
                num_input_tokens = num_decodes

            attn_metadata.decode_wrapper = self._get_decode_wrapper(
                num_input_tokens, use_cudagraph)
            if not attn_metadata.decode_use_trtllm:
                # Use the persistent buffer with the padded length,
                # instead of the chunked version at the same address
                # in attn_metadata, when using cudagraph.
                fast_plan_decode(
                    attn_metadata.decode_wrapper,
                    self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                    paged_kv_indices,
                    self.paged_kv_last_page_len_cpu[:num_input_tokens],
                    seq_lens_cpu[:num_input_tokens],
                    self.num_qo_heads,
                    self.num_kv_heads,
                    self.head_dim,
                    self.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.global_hyperparameters.sm_scale,
                    window_left=self.global_hyperparameters.window_left,
                    logits_soft_cap=self.global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.kv_cache_dtype,
                )
    return attn_metadata
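
A worked example (illustrative numbers) of the paged-KV bookkeeping done above: each request needs ceil(seq_len / page_size) pages, paged_kv_indptr is the cumulative sum of those counts with a leading 0, and paged_kv_last_page_len is seq_len % page_size with 0 remapped to page_size.

import torch

page_size = 16
seq_lens_cpu = torch.tensor([1, 17, 32])

block_table_bounds = (seq_lens_cpu + page_size - 1) // page_size        # [1, 2, 2]
paged_kv_indptr = torch.zeros(len(seq_lens_cpu) + 1, dtype=torch.int32)
torch.cumsum(block_table_bounds, dim=0, dtype=torch.int32,
             out=paged_kv_indptr[1:])                                    # [0, 1, 3, 5]

last_page_len = seq_lens_cpu % page_size
last_page_len = torch.where(last_page_len == 0,
                            torch.tensor(page_size), last_page_len)      # [1, 1, 16]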

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
)

This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with FlashInfer.

Source code in vllm/v1/attention/backends/flashinfer.py
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata):
    """
    This method builds the metadata for full cudagraph capture.
    Currently, only decode is supported for full cudagraphs with FlashInfer.
    """
    m = common_attn_metadata

    assert m.num_reqs == m.num_actual_tokens, \
        "FlashInfer only supports decode-only full CUDAGraph capture. " \
        "Make sure all cudagraph capture sizes <= max_num_seq."

    m.max_query_len = 1  # decode-only

    return self.build(0, m)
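
During full-cudagraph capture every request contributes exactly one query token, which is what the num_reqs == num_actual_tokens assertion encodes. A minimal sketch with hypothetical values:

import torch

num_reqs = 8                                                     # capture batch size
query_start_loc = torch.arange(num_reqs + 1, dtype=torch.int32)  # one token per request
num_actual_tokens = int(query_start_loc[-1])
assert num_reqs == num_actual_tokens                             # decode-only invariant
max_query_len = 1                                                # forced before build(0, m)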

use_cascade_attention

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
def use_cascade_attention(self, *args, **kwargs) -> bool:
    if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
        # TODO: The cascade wrapper currently does not support setting
        # kv cache dtype to something different from query dtype.
        return False
    return use_cascade_attention(*args, **kwargs)
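
In other words, cascade attention is skipped whenever the KV-cache dtype differs from the model (query) dtype, e.g. an FP8 KV cache with bf16 activations. A tiny illustrative check (not vLLM code):

import torch

kv_cache_dtype = torch.float8_e4m3fn     # hypothetical fp8 KV cache
model_dtype = torch.bfloat16
cascade_possible = kv_cache_dtype == model_dtype
print(cascade_possible)                   # False -> fall back to regular attention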

fast_plan_decode

fast_plan_decode(
    self,
    indptr_cpu: Tensor,
    indices: Tensor,
    last_page_len_cpu: Tensor,
    seq_lens_cpu: Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, dtype]] = "float16",
    kv_data_type: Optional[Union[str, dtype]] = None,
    data_type: Optional[Union[str, dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
) -> None

A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for cudagraph capture/replay; without cudagraphs it falls back to the original plan. Because host-side buffers are passed in, the fallback path needs only a host-to-device copy of the indptr and last_page_len buffers. Modifications for the cudagraph path: only a host-to-device copy of the indptr and last_page_len buffers, and no device-to-device copy of the indices buffer.

Parts of the code take inspiration from the original plan in the FlashInfer repo and from the fast_decode_plan implementation for FlashInfer in the SGLang repo.

Source code in vllm/v1/attention/backends/flashinfer.py
def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
) -> None:
    """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay; without cudagraphs it falls back to the
    original plan.
    Falling back to the original plan after passing host-side buffers:
    - only a host-to-device copy of the indptr and last_page_len buffers.
    Modifications for cudagraph:
    - only a host-to-device copy of the indptr and last_page_len buffers.
    - no device-to-device copy of the indices buffer.

    Parts of the code take inspiration from the original plan in the FlashInfer
    repo and from the fast_decode_plan implementation for FlashInfer in the
    SGLang repo.
    """
    # Warm up with the original plan on the first call, and always run the
    # original plan for dynamic shapes. For fixed shapes (cudagraph),
    # this warm-up generates the _cached_module for the decode wrapper.
    if not self.is_cuda_graph_enabled or \
        getattr(self, "vllm_first_call", True):
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
    q_data_type = getattr(torch, q_data_type) if isinstance(
        q_data_type, str) else q_data_type
    kv_data_type = getattr(torch, kv_data_type) if isinstance(
        kv_data_type, str) else kv_data_type

    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
            "initialization {}".format(batch_size, self._fixed_batch_size))
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
            "The size of indices should be less than or equal to the "
            "allocated buffer")

    # host-to-device copy for the indptr buffer
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
    # host-to-device copy for the last_page_len buffer
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                           non_blocking=True)

    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    try:
        # Make sure we pass exactly 15 arguments for tensor core version
        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            indptr_cpu,
            seq_lens_cpu,
            batch_size,  # total_num_rows
            batch_size,
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim,
            head_dim,
            False,  # causal
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
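
Usage note: the first call on a wrapper always goes through the original plan() (this also builds the cached kernel module); later calls in cudagraph mode take the fast path, which only refreshes the host-side metadata and copies the indptr and last_page_len buffers to the device. A hedged sketch of the call pattern, with placeholder arguments:

# decode_wrapper, cpu_buffers, and plan_kwargs below are placeholders.
# fast_plan_decode(decode_wrapper, *cpu_buffers, **plan_kwargs)      # 1st call: full plan()
# for _ in range(num_replay_steps):
#     ...update cpu_buffers in place for the new decode batch...
#     fast_plan_decode(decode_wrapper, *cpu_buffers, **plan_kwargs)  # fast path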