Skip to content

vllm.v1.attention.backends.utils

KV_SHARING_FAST_PREFILL_METADATA_FIELDS module-attribute

# Extra fields passed as `fields=` to subclass_attention_metadata by
# make_kv_sharing_fast_prefill_attention_metadata.
# Each entry appears to be (field name, field type, default value) —
# TODO confirm against subclass_attention_metadata.
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ("logits_indices_padded", Optional[Tensor], None),
    ("num_logits_indices", int, 0),
]

M module-attribute

M = TypeVar('M')

_KV_CACHE_LAYOUT_OVERRIDE module-attribute

_KV_CACHE_LAYOUT_OVERRIDE = None

logger module-attribute

logger = init_logger(__name__)

AttentionCGSupport

Bases: Enum

Constants for the cudagraph support of the attention backend Here we do not consider the cascade attention, as currently it is never cudagraph supported.

Source code in vllm/v1/attention/backends/utils.py
class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend.

    Cascade attention is not considered here, as currently it is never
    cudagraph supported.

    The integer values decrease from the most permissive level (ALWAYS = 3)
    down to no support at all (NEVER = 0).
    """

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    # UNIFORM_BATCH: batches that only contain requests whose query lengths
    # are all the same (e.g. spec-decode "decodes" of
    # 1 + num_speculative_tokens tokens).
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches the only contain query lengths that are
    the same, this can be used for spec-decode 
        i.e. "decodes" are 1 + num_speculative_tokens"""
    # UNIFORM_SINGLE_TOKEN_DECODE: batches that only contain query_len == 1
    # decodes.
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches the only contain query_len==1 decodes"""
    NEVER = 0
    """NO cudagraph support"""

ALWAYS class-attribute instance-attribute

ALWAYS = 3

Cudagraph always supported; supports mixed-prefill-decode

NEVER class-attribute instance-attribute

NEVER = 0

NO cudagraph support

UNIFORM_BATCH class-attribute instance-attribute

UNIFORM_BATCH = 2

Cudagraph supported for batches that only contain query lengths that are the same; this can be used for spec-decode, i.e. "decodes" are 1 + num_speculative_tokens

UNIFORM_SINGLE_TOKEN_DECODE class-attribute instance-attribute

UNIFORM_SINGLE_TOKEN_DECODE = 1

Cudagraph supported for batches that only contain query_len==1 decodes

AttentionMetadataBuilder

Bases: ABC, Generic[M]

Source code in vllm/v1/attention/backends/utils.py
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Abstract base for builders that turn a CommonAttentionMetadata into
    attention metadata of type M.

    Subclasses must implement __init__ and build; the remaining methods have
    defaults that delegate to build or return a fixed value.
    """
    # Does this backend/builder support CUDA Graphs for attention (default: no).
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: ClassVar[Optional[int]] = None

    @abstractmethod
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """
        Store the KV cache spec on the builder.

        The base implementation only keeps kv_cache_spec; layer_names,
        vllm_config and device are accepted but not used here.
        """
        self.kv_cache_spec = kv_cache_spec

    @abstractmethod
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.

        Returns:
            The backend-specific metadata of type M.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default
        (with common_prefix_len=0 and the default fast_build).
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default
        (with common_prefix_len=0 and fast_build=True).

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
                The default implementation does not use this argument.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata,
                          fast_build=True)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
    ) -> bool:
        """
        Report whether cascade attention should be used for this batch.

        The default implementation ignores every argument and always returns
        False.
        """
        return False

cudagraph_support class-attribute

cudagraph_support: AttentionCGSupport = NEVER

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

reorder_batch_threshold class-attribute

reorder_batch_threshold: Optional[int] = None

__init__ abstractmethod

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/utils.py
@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             vllm_config: VllmConfig, device: torch.device):
    # The base implementation only stores kv_cache_spec; layer_names,
    # vllm_config and device are accepted but not used here.
    self.kv_cache_spec = kv_cache_spec

build abstractmethod

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> M

Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build.

Parameters:

Name Type Description Default
common_prefix_len int

The length of the common prefix of the batch.

required
common_attn_metadata CommonAttentionMetadata

The common attention metadata.

required
fast_build bool

The metadata will prioritize speed of building over the speed at execution. Can be used for spec-decode where the result of a build call may only be used for a few layers/iterations.

False
Source code in vllm/v1/attention/backends/utils.py
@abstractmethod
def build(self,
          common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata,
          fast_build: bool = False) -> M:
    """
    Central method that builds attention metadata.
    Some builders (MLA) require reorder_batch to be called prior to build.

    Args:
        common_prefix_len: The length of the common prefix of the batch.
        common_attn_metadata: The common attention metadata.
        fast_build: The meta-data will prioritize speed of building over
            the speed at execution. Can be used for spec-decode where the
            result of a build call may only be used for few layers/iters.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
) -> M

Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture.

Source code in vllm/v1/attention/backends/utils.py
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata) -> M:
    """
    Build attention metadata for CUDA graph capture. Uses build by default
    (with common_prefix_len=0 and the default fast_build).
    Subclasses that override this method should call self.build or
    super().build_for_cudagraph_capture.
    """
    return self.build(common_prefix_len=0,
                      common_attn_metadata=common_attn_metadata)

build_for_drafting

build_for_drafting(
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M

Build attention metadata for draft model. Uses build by default.

Parameters:

Name Type Description Default
common_attn_metadata CommonAttentionMetadata

The common attention metadata.

required
draft_index int

The index of the current draft operation. When speculating a chain of tokens, this index refers to the draft attempt for the i-th token. For tree-based attention, this index instead refers to the draft attempt for the i-th level in the tree of tokens.

required
Source code in vllm/v1/attention/backends/utils.py
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M:
    """
    Build attention metadata for draft model. Uses build by default
    (with common_prefix_len=0 and fast_build=True).

    Args:
        common_attn_metadata: The common attention metadata.
        draft_index: The index of the current draft operation.
            When speculating a chain of tokens, this index refers to the
            draft attempt for the i-th token.
            For tree-based attention, this index instead refers to the
            draft attempt for the i-th level in the tree of tokens.
            The default implementation does not use this argument.
    """
    return self.build(common_prefix_len=0,
                      common_attn_metadata=common_attn_metadata,
                      fast_build=True)

use_cascade_attention

use_cascade_attention(
    common_prefix_len: int,
    query_lens: ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool
Source code in vllm/v1/attention/backends/utils.py
def use_cascade_attention(
    self,
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """
    Report whether cascade attention should be used for this batch.

    This implementation ignores every argument and always returns False.
    """
    return False

CommonAttentionMetadata dataclass

Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata.

For many of the tensors we keep both GPU and CPU versions.

Source code in vllm/v1/attention/backends/utils.py
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions; the CPU copies
    carry a `_cpu` suffix.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length in batch"""

    # Indexed by request along its first dimension (it is sliced with a
    # request slice when the batch is split into micro-batches).
    block_table_tensor: torch.Tensor
    # Indexed by token (it is sliced with a token slice when the batch is
    # split into micro-batches).
    slot_mapping: torch.Tensor

    # NOTE(review): presumably whether attention is causally masked —
    # confirm with the backends that consume it.
    causal: bool = True

block_table_tensor instance-attribute

block_table_tensor: Tensor

causal class-attribute instance-attribute

causal: bool = True

max_query_len instance-attribute

max_query_len: int

Longest query in batch

max_seq_len instance-attribute

max_seq_len: int

Longest context length in batch

num_actual_tokens instance-attribute

num_actual_tokens: int

Total number of tokens in batch

num_computed_tokens_cpu instance-attribute

num_computed_tokens_cpu: Tensor

(batch_size,), the number of computed tokens for each request

num_reqs instance-attribute

num_reqs: int

Number of requests

query_start_loc instance-attribute

query_start_loc: Tensor

query_start_loc_cpu instance-attribute

query_start_loc_cpu: Tensor

(batch_size + 1,), the start location of each request in query Tensor

seq_lens instance-attribute

seq_lens: Tensor

seq_lens_cpu instance-attribute

seq_lens_cpu: Tensor

(batch_size,), the length of each request including both computed tokens and newly scheduled tokens

slot_mapping instance-attribute

slot_mapping: Tensor

__init__

__init__(
    query_start_loc: Tensor,
    query_start_loc_cpu: Tensor,
    seq_lens: Tensor,
    seq_lens_cpu: Tensor,
    num_computed_tokens_cpu: Tensor,
    num_reqs: int,
    num_actual_tokens: int,
    max_query_len: int,
    max_seq_len: int,
    block_table_tensor: Tensor,
    slot_mapping: Tensor,
    causal: bool = True,
) -> None

PerLayerParameters dataclass

Currently, the FlashInfer backend only supports models in which all layers share the same values for the following hyperparameters. Should not be used for the trtllm-gen backend, since it supports different values for the following hyperparameters.

Source code in vllm/v1/attention/backends/utils.py
@dataclass
class PerLayerParameters:
    """
    Currently, the FlashInfer backend only supports models in which all layers
    share the same values for the following hyperparameters. Should not be
    used for the trtllm-gen backend since it supports different values for the
    following hyperparameters.
    """

    # First element of the attention impl's `sliding_window`, or -1 when the
    # impl has no sliding window (see get_per_layer_parameters).
    window_left: int
    # The attention impl's `logits_soft_cap`, or None when it defines none.
    logits_soft_cap: Optional[float]
    # Taken from the attention impl's `scale` attribute.
    sm_scale: float
    # True when the attention impl has a non-None `sinks` attribute.
    has_sinks: bool = False

has_sinks class-attribute instance-attribute

has_sinks: bool = False

logits_soft_cap instance-attribute

logits_soft_cap: Optional[float]

sm_scale instance-attribute

sm_scale: float

window_left instance-attribute

window_left: int

__init__

__init__(
    window_left: int,
    logits_soft_cap: Optional[float],
    sm_scale: float,
    has_sinks: bool = False,
) -> None

UbatchSlice dataclass

Source code in vllm/v1/attention/backends/utils.py
@dataclass
class UbatchSlice:
    """
    Pair of slices selecting one micro-batch out of a full batch.

    _make_metadata_with_slice reads `.start` and `.stop` of both slices, so
    both bounds are expected to be concrete integers.
    """
    # Selects the requests of the micro-batch (applied to per-request tensors
    # such as seq_lens and block_table_tensor).
    request_slice: slice
    # Selects the tokens of the micro-batch (applied to slot_mapping).
    token_slice: slice

request_slice instance-attribute

request_slice: slice

token_slice instance-attribute

token_slice: slice

__init__

__init__(request_slice: slice, token_slice: slice) -> None

_make_metadata_with_slice

_make_metadata_with_slice(
    ubatch_slice: UbatchSlice,
    attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata

This function creates a new CommonAttentionMetadata that corresponds to the requests included in ubatch_slice

Source code in vllm/v1/attention/backends/utils.py
def _make_metadata_with_slice(
        ubatch_slice: UbatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    Create a new CommonAttentionMetadata that corresponds to the requests
    included in ubatch_slice.

    Args:
        ubatch_slice: Request and token slices selecting the micro-batch.
            Both slices must have integer `start` and `stop`.
        attn_metadata: Metadata of the full batch. It is not modified.

    Returns:
        Metadata restricted to the selected requests/tokens. Per-request
        tensors are sliced with the request slice, `slot_mapping` with the
        token slice, and the `causal` flag is carried over from
        attn_metadata.
    """

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    # Re-base the cumulative query offsets so the micro-batch starts at 0.
    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                             request_slice)
    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, "
        f"got {len(query_start_loc)}")
    query_start_loc_cpu = slice_query_start_locs(
        attn_metadata.query_start_loc_cpu, request_slice)

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
        request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    # Per-request query lengths are the differences of consecutive offsets;
    # the longest one is computed from the CPU copy.
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] -
                            query_start_loc_cpu[:-1])).item())

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        # Preserve the source batch's setting instead of silently falling
        # back to the dataclass default (True).
        causal=attn_metadata.causal,
    )

get_kv_cache_layout cached

get_kv_cache_layout()
Source code in vllm/v1/attention/backends/utils.py
@functools.lru_cache
def get_kv_cache_layout():
    """
    Resolve the KV cache layout to use.

    Precedence: the module-level `_KV_CACHE_LAYOUT_OVERRIDE`, then
    `envs.VLLM_KV_CACHE_LAYOUT`, then `get_kv_connector_cache_layout()`.

    The result is memoized by `functools.lru_cache`, so changes to the
    override or the environment made after the first call are not observed
    unless the cache is cleared.
    """
    # 1) Layout specified by the code.
    override = _KV_CACHE_LAYOUT_OVERRIDE
    if override is not None:
        logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
                         "Setting KV cache layout to %s.", override)
        return override

    # 2) Layout specified by the user.
    env_layout = envs.VLLM_KV_CACHE_LAYOUT
    if env_layout is not None:
        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
        "detected. Setting KV cache layout to %s.", env_layout)
        return env_layout

    # 3) Neither the override nor the user specified a layout: use the
    #    default.
    return get_kv_connector_cache_layout()

get_per_layer_parameters

get_per_layer_parameters(
    vllm_config: VllmConfig,
    layer_names: list[str],
    cls_: type[AttentionImpl],
) -> dict[str, PerLayerParameters]

Scan layers in layer_names and determine some hyperparameters to use during plan.

Source code in vllm/v1/attention/backends/utils.py
def get_per_layer_parameters(
        vllm_config: VllmConfig, layer_names: list[str],
        cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
    """
    Collect, for every attention layer named in `layer_names`, the
    hyperparameters to use during `plan`.

    Each layer's impl must be an instance of `cls_`. The result maps the
    layer key to a PerLayerParameters built from the impl's attributes.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention,
                                              layer_names)
    collected: dict[str, PerLayerParameters] = {}

    for name, attn_layer in attn_layers.items():
        attn_impl = attn_layer.impl
        assert isinstance(attn_impl, cls_)

        # Optional attributes fall back to "not set" when the impl lacks them.
        sliding_window = getattr(attn_impl, "sliding_window", None)
        if sliding_window is None:
            window_left = -1
        else:
            window_left = sliding_window[0]

        collected[name] = PerLayerParameters(
            window_left=window_left,
            logits_soft_cap=getattr(attn_impl, "logits_soft_cap", None),
            sm_scale=attn_impl.scale,
            has_sinks=getattr(attn_impl, "sinks", None) is not None,
        )

    return collected

infer_global_hyperparameters

infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters

Currently, FlashInfer backends other than trtllm-gen only support models in which all layers share the same values for the following hyperparameters: `window_left`, `logits_soft_cap`, `sm_scale`.

So this function asserts that all layers share the same values for these hyperparameters and returns the global values.

Source code in vllm/v1/attention/backends/utils.py
def infer_global_hyperparameters(
        per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Currently, FlashInfer backends other than trtllm-gen
    only support models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values. The comparison uses
    PerLayerParameters equality, so every field of the dataclass takes part
    in it.

    The check is skipped when `envs.VLLM_USE_TRTLLM_ATTENTION` is truthy; in
    that case the first layer's parameters are returned unchecked.

    Args:
        per_layer_params: Mapping from layer name to its parameters.
            Must not be empty.

    Returns:
        The parameters of the first layer in `per_layer_params`.

    Raises:
        AssertionError: If `per_layer_params` is empty, or layers disagree
            on any hyperparameter other than `window_left`.
        ValueError: If layers disagree on `window_left`.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]

    # trtllm attention doesn't need global hyper params so disable the check
    if not envs.VLLM_USE_TRTLLM_ATTENTION:
        for params in param_sets:
            if params.window_left != global_params.window_left:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True")
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own separating space.
            assert params == global_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values "
                "for the following hyperparameters: "
                "`window_left`, `logits_soft_cap`, `sm_scale`.")

    return global_params

make_kv_sharing_fast_prefill_attention_metadata

make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls: Any,
) -> Any

Return a new subclass of metadata_cls for fast prefill

Source code in vllm/v1/attention/backends/utils.py
def make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls: Any, ) -> Any:
    """
    Return a new subclass of `metadata_cls` for fast prefill.

    Delegates to `subclass_attention_metadata` with the name prefix
    "KVSharingFastPrefill" and the fields listed in
    KV_SHARING_FAST_PREFILL_METADATA_FIELDS.
    """
    return subclass_attention_metadata(
        name_prefix="KVSharingFastPrefill",
        metadata_cls=metadata_cls,
        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
    )

make_local_attention_virtual_batches

make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata
Source code in vllm/v1/attention/backends/utils.py
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """
    Split every request into "virtual" requests, one per local-attention
    chunk of `attn_chunk_size` tokens touched by the request's query tokens,
    and build metadata (including a per-chunk block table) for them.

    Args:
        attn_chunk_size: Size of a local attention chunk, in tokens.
        common_attn_metadata: Metadata of the real batch. All query lengths
            are assumed to be > 0.
        block_size: KV cache block size in tokens. `attn_chunk_size` must be
            a multiple of it; the default of 0 is not usable and raises
            ZeroDivisionError.

    Returns:
        A CommonAttentionMetadata describing the virtual batch.
        `num_actual_tokens` and `slot_mapping` are carried over unchanged and
        `causal` is set to True.
    """
    # Validate the block size before it is used as a divisor below;
    # otherwise NumPy would first perform a (warning-only) division by zero.
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = (block_starts[:, None] +
                     np.arange(pages_per_local_batch, dtype=np.int32))
    # Clamp so indices never run past the last column of the block table.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] -
                                                   1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device,
                                               non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # The field is declared as a plain int; convert the NumPy scalar.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )

reorder_batch_to_split_decodes_and_prefills

reorder_batch_to_split_decodes_and_prefills(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
    decode_threshold: int = 1,
) -> bool

Reorders the batch to split into prefill and decode requests; places all requests with <= decode_threshold tokens at the front of the batch.

Returns:

Type Description
bool

True if the batch was modified, False otherwise.

Source code in vllm/v1/attention/backends/utils.py
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    Args:
        input_batch: Batch to reorder in place. Its `req_ids` are read and
            `swap_states(i, j)` is called for every pair to exchange.
        scheduler_output: Provides `num_scheduled_tokens[req_id]` for each
            request in the batch.
        decode_threshold: Largest number of scheduled tokens for which a
            request is still treated as a "decode".

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We now want to reorder the batch so that the "decode" requests are at
    # the front and the "prefill" requests are at the back using the least
    # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
    # requests where attention is likely memory-bound and "prefill" to mean
    # requests where attention is likely compute-bound, TODO(lucas): figure out
    # a better naming here)
    decodes = []
    prefills = []
    for i, req_id in enumerate(input_batch.req_ids):
        if scheduler_output.num_scheduled_tokens[req_id] <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # We hope that this is fairly minimal since decodes
    # should be around for a number of iterations so hopefully they are
    # relatively stationary (and new requests are generally appended to the
    # persistent batch so already should be at the back).
    # `decodes` and `prefills` are in ascending order, so pairing the
    # front-most prefills with the back-most decodes moves misplaced decodes
    # into the slots that prefills currently occupy at the front.
    num_decodes = len(decodes)
    modified_batch = False

    for prefill_idx, decode_idx in zip(prefills, reversed(decodes)):
        # Once a decode already sits inside the first `num_decodes` slots,
        # every remaining (smaller-index) decode does too.
        if decode_idx < num_decodes:
            break

        input_batch.swap_states(prefill_idx, decode_idx)
        modified_batch = True

    return modified_batch

set_kv_cache_layout

set_kv_cache_layout(cache_layout: str)
Source code in vllm/v1/attention/backends/utils.py
def set_kv_cache_layout(cache_layout: str):
    """
    Override the KV cache layout returned by get_kv_cache_layout.

    get_kv_cache_layout is memoized with functools.lru_cache, so its cache is
    cleared here; otherwise an override set after the first lookup would be
    ignored.

    Args:
        cache_layout: The layout to report from now on.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    # Drop any previously memoized result so the new override takes effect.
    get_kv_cache_layout.cache_clear()

slice_query_start_locs

slice_query_start_locs(
    query_start_loc: Tensor, request_slice: slice
) -> Tensor

Creates a new query_start_loc that corresponds to the requests in request_slice.

Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility.

Source code in vllm/v1/attention/backends/utils.py
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for the requests selected by request_slice.

    The returned offsets are re-based so that the first selected request
    starts at 0; there is one more entry than selected requests.

    Note: the subtraction allocates a new tensor for the result, which will
    break cudagraph compatibility.
    """
    first_req = request_slice.start
    last_req = request_slice.stop
    # Include the end offset of the last selected request.
    selected = query_start_loc[first_req:last_req + 1]
    base = query_start_loc[first_req]
    return selected - base

split_attn_metadata

split_attn_metadata(
    ubatch_slices: list[UbatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]

Creates a new CommonAttentionMetadata instance that corresponds to the requests for each UbatchSlice in ubatch_slices.

Note: This function does not modify common_attn_metadata

Source code in vllm/v1/attention/backends/utils.py
def split_attn_metadata(
    ubatch_slices: list[UbatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Build one CommonAttentionMetadata per UbatchSlice in ubatch_slices.

    Each entry is produced by `_make_metadata_with_slice` from the matching
    slice and the shared `common_attn_metadata`; the output list follows the
    order of `ubatch_slices`.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(piece, common_attn_metadata)
        for piece in ubatch_slices
    ]

split_decodes_and_prefills

split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]

Assuming a reordered batch, finds the boundary between prefill and decode requests.

Parameters:

Name Type Description Default
common_attn_metadata CommonAttentionMetadata

CommonAttentionMetadata object containing the batch metadata.

required
decode_threshold int

The maximum query length to be considered a decode.

1

Returns:

Name Type Description
num_decodes int

The number of decode requests.

num_prefills int

The number of prefill requests.

num_decode_tokens int

The number of tokens in the decode requests.

num_prefill_tokens int

The number of tokens in the prefill requests.

Source code in vllm/v1/attention/backends/utils.py
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Locate the decode/prefill boundary in a batch that has already been
    reordered so that all decodes come before all prefills.

    A request counts as a decode when its query length is at most
    `decode_threshold`, and as a prefill otherwise.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    longest_query = common_attn_metadata.max_query_len
    total_reqs = common_attn_metadata.num_reqs
    total_tokens = common_attn_metadata.num_actual_tokens
    start_locs = common_attn_metadata.query_start_loc_cpu

    # Result shared by both "no prefills at all" exits.
    decode_only = (total_reqs, 0, total_tokens, 0)

    # Cheap check first: if even the longest query is a decode, all are.
    if longest_query <= decode_threshold:
        return decode_only

    # Per-request query lengths from consecutive start offsets.
    lens = start_locs[1:] - start_locs[:-1]
    prefill_mask = lens > decode_threshold
    if not torch.any(prefill_mask):
        return decode_only

    # argmax over a 0/1 mask yields the index of the first prefill.
    boundary = prefill_mask.int().argmax(dim=-1).item()

    # The batch must be partitioned: prefills from the boundary on, decodes
    # strictly before it.
    assert torch.all(lens[boundary:] > decode_threshold)
    assert torch.all(lens[:boundary] <= decode_threshold)

    decode_tokens = start_locs[boundary].item()
    return (
        boundary,
        total_reqs - boundary,
        decode_tokens,
        total_tokens - decode_tokens,
    )

subclass_attention_backend

subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]

Return a new subclass where get_builder_cls returns builder_cls.

Source code in vllm/v1/attention/backends/utils.py
def subclass_attention_backend(
        name_prefix: str, attention_backend_cls: type[AttentionBackend],
        builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    The subclass is named `name_prefix + attention_backend_cls.__name__` and
    inherits everything else from `attention_backend_cls`.

    Args:
        name_prefix: String prepended to the base class name.
        attention_backend_cls: The backend class to derive from.
        builder_cls: The builder class the new `get_builder_cls` returns.

    Returns:
        The dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    # Install the override as a staticmethod: a bare zero-argument function
    # in the class dict would be bound on instance access and then fail with
    # a TypeError because it cannot accept `self`.
    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": staticmethod(lambda: builder_cls)})

subclass_attention_metadata

subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any

Return a new subclass of metadata_cls with additional fields

Source code in vllm/v1/attention/backends/utils.py
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields

    The subclass is a dataclass built with `make_dataclass`; it is named
    `name_prefix + metadata_cls.__name__` and has `metadata_cls` as its only
    base.

    Args:
        name_prefix: String prepended to the base class name.
        metadata_cls: The class to derive from.
        fields: Field specs passed straight to `make_dataclass`, each a
            3-tuple whose first two items are the field name and type.

    Returns:
        The newly created dataclass.
    """
    return make_dataclass(
        name_prefix + metadata_cls.__name__,  # type: ignore
        fields,
        bases=(metadata_cls, ),
    )