
vllm.v1.kv_cache_interface

logger module-attribute

logger = init_logger(__name__)

AttentionSpec dataclass

Bases: KVCacheSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        # For MLA we only store a single latent vector
        coef = 1 if self.use_mla else 2
        return coef * self.block_size * self.num_kv_heads * self.head_size \
                * get_dtype_size(self.dtype)
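
As a rough worked example of the formula above (a sketch only, assuming vLLM and PyTorch are importable; the numbers are illustrative, not defaults): a non-MLA layer stores both K and V, so with block_size=16, num_kv_heads=8, head_size=128, and float16 (2 bytes per element), one block occupies 2 * 16 * 8 * 128 * 2 = 65536 bytes.

import torch

from vllm.v1.kv_cache_interface import AttentionSpec

# Hypothetical layer configuration, chosen only for illustration.
spec = AttentionSpec(
    block_size=16,
    num_kv_heads=8,
    head_size=128,
    dtype=torch.float16,
    use_mla=False,
)
# coef=2 (K and V) * 16 tokens * 8 heads * 128 dims * 2 bytes per element.
assert spec.page_size_bytes == 2 * 16 * 8 * 128 * 2  # 65536 bytes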

dtype instance-attribute

dtype: dtype

head_size instance-attribute

head_size: int

num_kv_heads instance-attribute

num_kv_heads: int

page_size_bytes property

page_size_bytes: int

use_mla instance-attribute

use_mla: bool

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    use_mla: bool,
) -> None

ChunkedLocalAttentionSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
    attention_chunk_size: int

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for at most
        # `self.attention_chunk_size` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
                         max_model_len)

        return cdiv(num_tokens, self.block_size) * self.page_size_bytes

attention_chunk_size instance-attribute

attention_chunk_size: int

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    use_mla: bool,
    attention_chunk_size: int,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)

    # During chunked prefill, we allocate KV cache for at most
    # `self.attention_chunk_size` computed tokens plus the newly scheduled
    # tokens. And we won't allocate KV cache for more than `max_model_len`
    # tokens.
    num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
                     max_model_len)

    return cdiv(num_tokens, self.block_size) * self.page_size_bytes
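
A minimal sketch of this arithmetic with hypothetical numbers (not taken from any real configuration): with attention_chunk_size=8192, max_num_batched_tokens=2048, max_model_len=4096, block_size=16, and a 65536-byte page, the max_model_len cap is the binding term.

from math import ceil

# Hypothetical values, for illustration only.
attention_chunk_size = 8192
max_num_batched_tokens = 2048
max_model_len = 4096
block_size = 16
page_size_bytes = 65536

# Mirrors ChunkedLocalAttentionSpec.max_memory_usage_bytes above.
num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)
num_blocks = ceil(num_tokens / block_size)  # cdiv is ceiling division
print(num_blocks * page_size_bytes)  # 256 blocks * 65536 bytes = 16 MiB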

EncoderOnlyAttentionSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class EncoderOnlyAttentionSpec(AttentionSpec):

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # Encoder-only layers do not need KV cache
        return 0

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    use_mla: bool,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    # Encoder-only layers do not need KV cache
    return 0

FullAttentionSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
    sliding_window: Optional[int] = None
    attention_chunk_size: Optional[int] = None
    """
    When hybrid allocator is disabled and the model contains both full 
    attention layers and sliding window attention layers, sliding 
    window attention are regarded as full attention in KV cache manager 
    (blocks are allocated for all tokens), while computed as sliding window 
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    Default to None for not using sliding window attention.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        if len(window_sizes) == 0:
            return None
        elif len(window_sizes) == 1:
            return window_sizes.pop()
        else:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single 
        FullAttentionSpec object.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")

        sliding_window = set(spec.sliding_window for spec in specs
                             if spec.sliding_window is not None)
        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                                   if spec.attention_chunk_size is not None)
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            dtype=specs[0].dtype,
            use_mla=specs[0].use_mla,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec.")
        assert (
            (merged_spec.sliding_window is not None) +
            (merged_spec.attention_chunk_size is not None) <= 1
        ), ("Model with both sliding window layers and chunked local attention "
            "layers is not supported.")
        return merged_spec

attention_chunk_size class-attribute instance-attribute

attention_chunk_size: Optional[int] = None

When the hybrid allocator is disabled and the model contains both full attention layers and sliding window attention layers, sliding window attention is regarded as full attention in the KV cache manager (blocks are allocated for all tokens), while it is computed as sliding window attention in the model runner. In this case, we use FullAttentionSpec and record the sliding window size. Defaults to None when sliding window attention is not used.

sliding_window class-attribute instance-attribute

sliding_window: Optional[int] = None

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    use_mla: bool,
    sliding_window: Optional[int] = None,
    attention_chunk_size: Optional[int] = None,
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    return cdiv(max_model_len, self.block_size) * self.page_size_bytes
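
For example, with the hypothetical values max_model_len=8192, block_size=16, and page_size_bytes=65536, the bound is cdiv(8192, 16) * 65536 = 512 * 65536 bytes = 32 MiB, i.e. enough blocks to hold the full context of one sequence.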

merge classmethod

merge(specs: list[Self]) -> Self

Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object.

Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of FullAttentionSpec objects into a single 
    FullAttentionSpec object.
    """
    assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
        "All attention layers in the same KV cache group must be "
        "FullAttentionSpec.")

    sliding_window = set(spec.sliding_window for spec in specs
                         if spec.sliding_window is not None)
    attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                               if spec.attention_chunk_size is not None)
    merged_spec = cls(
        block_size=specs[0].block_size,
        num_kv_heads=specs[0].num_kv_heads,
        head_size=specs[0].head_size,
        dtype=specs[0].dtype,
        use_mla=specs[0].use_mla,
        sliding_window=cls.merge_window_sizes(sliding_window),
        attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
    )
    for spec in specs:
        for f in fields(AttentionSpec):
            assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                "All attention layers in the same KV cache group must have "
                "the same attention spec.")
    assert (
        (merged_spec.sliding_window is not None) +
        (merged_spec.attention_chunk_size is not None) <= 1
    ), ("Model with both sliding window layers and chunked local attention "
        "layers is not supported.")
    return merged_spec
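
A small usage sketch (assuming vLLM and PyTorch are installed; all sizes are hypothetical): merging two otherwise identical specs where only one records a sliding window keeps that single window size.

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec

common = dict(block_size=16, num_kv_heads=8, head_size=128,
              dtype=torch.float16, use_mla=False)
a = FullAttentionSpec(**common)
b = FullAttentionSpec(**common, sliding_window=4096)

merged = FullAttentionSpec.merge([a, b])
assert merged.sliding_window == 4096
assert merged.attention_chunk_size is None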

merge_window_sizes classmethod

merge_window_sizes(window_sizes: set[int]) -> Optional[int]
Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
    if len(window_sizes) == 0:
        return None
    elif len(window_sizes) == 1:
        return window_sizes.pop()
    else:
        raise ValueError(
            "All attention layers in the same KV cache group must have the "
            "same window size.")

KVCacheConfig dataclass

The KV cache configuration of a model.

Source code in vllm/v1/kv_cache_interface.py
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """
    num_blocks: int
    """The number of KV cache blocks."""
    kv_cache_tensors: list[KVCacheTensor]
    """How the model runner should initialize the KV cache tensors for each layer."""
    kv_cache_groups: list[KVCacheGroupSpec]
    """
    The kv cache groups of the model.
    For models with only one type of attention, there is only one group that
    contains all layers.
    For models with multiple types of attention, there will be multiple groups,
    see `_get_kv_cache_config_uniform_page_size` for more details.
    """

kv_cache_groups instance-attribute

kv_cache_groups: list[KVCacheGroupSpec]

The kv cache groups of the model. For models with only one type of attention, there is only one group that contains all layers. For models with multiple types of attention, there will be multiple groups, see _get_kv_cache_config_uniform_page_size for more details.

kv_cache_tensors instance-attribute

kv_cache_tensors: list[KVCacheTensor]

How the model runner should initialize the KV cache tensors for each layer.

num_blocks instance-attribute

num_blocks: int

The number of KV cache blocks.

__init__

__init__(
    num_blocks: int,
    kv_cache_tensors: list[KVCacheTensor],
    kv_cache_groups: list[KVCacheGroupSpec],
) -> None

KVCacheGroupSpec dataclass

Represents a group of model layers that share the same KV cache block table. These layers are regarded as one layer in the KV cache manager.

Source code in vllm/v1/kv_cache_interface.py
@dataclass
class KVCacheGroupSpec:
    """
    Represents a group of model layers that share the same KV cache block table.
    These layers are regarded as one layer in the KV cache manager.
    """
    # The names of model layers in this group
    layer_names: list[str]
    # The KV cache spec of this manager layer
    kv_cache_spec: KVCacheSpec

kv_cache_spec instance-attribute

kv_cache_spec: KVCacheSpec

layer_names instance-attribute

layer_names: list[str]

__init__

__init__(
    layer_names: list[str], kv_cache_spec: KVCacheSpec
) -> None

KVCacheSpec dataclass

A base class for specifying the KV cache format of one layer.

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class KVCacheSpec:
    """
    A base class for specifying the KV cache format of one layer.
    """

    # number of tokens in a block
    block_size: int

    @property
    def page_size_bytes(self) -> int:
        """
        The size of a page with `block_size` tokens in bytes.

        Returns:
            The page size
        """
        raise NotImplementedError

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """
        The maximum possible memory usage of this KV cache in bytes.

        Returns:
            The KV cache size in bytes
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
        """
        assert all(spec == specs[0] for spec in specs[1:]), (
            "All layers in the same KV cache group must be the same.")
        return copy.deepcopy(specs[0])

block_size instance-attribute

block_size: int

page_size_bytes property

page_size_bytes: int

The size of a page with block_size tokens in bytes.

Returns:

Type Description
int

The page size

__init__

__init__(block_size: int) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int

The maximum possible memory usage of this KV cache in bytes.

Returns:

Type Description
int

The KV cache size in bytes

Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    """
    The maximum possible memory usage of this KV cache in bytes.

    Returns:
        The KV cache size in bytes
    """
    raise NotImplementedError

merge classmethod

merge(specs: list[Self]) -> Self

Merge a list of KVCacheSpec objects into a single KVCacheSpec object.

Source code in vllm/v1/kv_cache_interface.py
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
    """
    assert all(spec == specs[0] for spec in specs[1:]), (
        "All layers in the same KV cache group must be the same.")
    return copy.deepcopy(specs[0])
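
To illustrate the contract a subclass fulfils, here is a hypothetical spec (a sketch only; ToyVectorSpec is not part of vLLM) that stores one 4-byte float32 vector of hidden_size per token and overrides the two members that KVCacheSpec leaves unimplemented.

from dataclasses import dataclass

from vllm.v1.kv_cache_interface import KVCacheSpec

@dataclass(frozen=True)
class ToyVectorSpec(KVCacheSpec):
    hidden_size: int

    @property
    def page_size_bytes(self) -> int:
        # One float32 (4-byte) vector of `hidden_size` per token in the block.
        return self.block_size * self.hidden_size * 4

    def max_memory_usage_bytes(self, vllm_config) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        # Ceiling division, equivalent to cdiv() used elsewhere in the module.
        return -(-max_model_len // self.block_size) * self.page_size_bytes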

KVCacheTensor dataclass

A class for specifying how the workers should initialize the KV cache.

Source code in vllm/v1/kv_cache_interface.py
@dataclass
class KVCacheTensor:
    """
    A class for specifying how the workers should initialize the KV cache.
    """
    size: int  # size of the KV cache tensor in bytes
    shared_by: list[str]  # layer names that share the same KV cache tensor

shared_by instance-attribute

shared_by: list[str]

size instance-attribute

size: int

__init__

__init__(size: int, shared_by: list[str]) -> None
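
Putting the pieces together, a minimal sketch (assuming vLLM and PyTorch are installed; the layer names, sizes, and block count are hypothetical) of how KVCacheTensor, KVCacheGroupSpec, and KVCacheConfig relate:

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor)

# One full-attention spec shared by two hypothetical decoder layers.
spec = FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.float16, use_mla=False)
layers = ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]

num_blocks = 1024  # illustrative; normally derived from available GPU memory
config = KVCacheConfig(
    num_blocks=num_blocks,
    kv_cache_tensors=[
        KVCacheTensor(size=num_blocks * spec.page_size_bytes, shared_by=[name])
        for name in layers
    ],
    kv_cache_groups=[KVCacheGroupSpec(layer_names=layers, kv_cache_spec=spec)],
)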

MambaSpec dataclass

Bases: KVCacheSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    shapes: tuple[tuple[int, ...], ...]
    dtypes: tuple[torch.dtype]
    page_size_padded: Optional[int] = None
    mamba_type: str = "mamba2"

    @property
    def page_size_bytes(self) -> int:
        page_size = sum(
            prod(shape) * get_dtype_size(dtype)
            for (shape, dtype) in zip(self.shapes, self.dtypes))
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
        # the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes
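
A worked sketch of page_size_bytes for a hypothetical pair of per-request states (the shapes are made up for illustration and do not correspond to a real model):

import torch

from vllm.v1.kv_cache_interface import MambaSpec

spec = MambaSpec(
    block_size=1,
    shapes=((8, 64), (16, 128, 64)),        # e.g. a conv state and an SSM state
    dtypes=(torch.float16, torch.float16),  # 2 bytes per element
)
# (8*64 + 16*128*64) elements * 2 bytes = 263168 bytes per page.
assert spec.page_size_bytes == (8 * 64 + 16 * 128 * 64) * 2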

dtypes instance-attribute

dtypes: tuple[dtype]

mamba_type class-attribute instance-attribute

mamba_type: str = 'mamba2'

page_size_bytes property

page_size_bytes: int

page_size_padded class-attribute instance-attribute

page_size_padded: Optional[int] = None

shapes instance-attribute

shapes: tuple[tuple[int, ...], ...]

__init__

__init__(
    block_size: int,
    shapes: tuple[tuple[int, ...], ...],
    dtypes: tuple[dtype],
    page_size_padded: Optional[int] = None,
    mamba_type: str = "mamba2",
) -> None

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    # We allocate 1 block for each request now, so max_memory_usage_bytes is
    # the same as page_size_bytes.
    # Need to update this when supporting prefix caching.
    return self.page_size_bytes

SlidingWindowSpec dataclass

Bases: AttentionSpec

Source code in vllm/v1/kv_cache_interface.py
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
    sliding_window: int

    def __post_init__(self):
        assert not self.use_mla, "MLA is not supported for sliding window"

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_model_len = vllm_config.model_config.max_model_len
        max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)

        # During chunked prefill, we allocate KV cache for the last
        # `self.sliding_window-1` computed tokens plus the newly scheduled
        # tokens. And we won't allocate KV cache for more than `max_model_len`
        # tokens.
        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)

        # +1 here because the sliding window may not start from the beginning
        # of the block. For example, if the block size is 4 and num_tokens
        # is 4, we need two blocks [XXCD] [EF] to store the sliding
        # window [CDEF] of 4 tokens.
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes

sliding_window instance-attribute

sliding_window: int

__init__

__init__(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: dtype,
    use_mla: bool,
    sliding_window: int,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/kv_cache_interface.py
def __post_init__(self):
    assert not self.use_mla, "MLA is not supported for sliding window"

max_memory_usage_bytes

max_memory_usage_bytes(vllm_config: VllmConfig) -> int
Source code in vllm/v1/kv_cache_interface.py
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)

    # During chunked prefill, we allocate KV cache for the last
    # `self.sliding_window-1` computed tokens plus the newly scheduled
    # tokens. And we won't allocate KV cache for more than `max_model_len`
    # tokens.
    num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
                     max_model_len)

    # +1 here because the sliding window may not start from the beginning
    # of the block. For example, if the block size is 4 and num_tokens
    # is 4, we need two blocks [XXCD] [EF] to store the sliding
    # window [CDEF] of 4 tokens.
    return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
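
To make the +1 concrete, a sketch with hypothetical numbers (sliding_window=4096, max_num_batched_tokens=2048, max_model_len=8192, block_size=16, 65536-byte pages):

from math import ceil

# Hypothetical values, for illustration only.
sliding_window = 4096
max_num_batched_tokens = 2048
max_model_len = 8192
block_size = 16
page_size_bytes = 65536

# Mirrors SlidingWindowSpec.max_memory_usage_bytes above.
num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)  # 6143
# +1 block because the window may straddle a block boundary.
num_blocks = ceil(num_tokens / block_size) + 1  # 384 + 1 = 385
print(num_blocks * page_size_bytes)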