Skip to content

vllm.v1.attention.backends.tree_attn

Attention layer with TreeAttention.

logger module-attribute

# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)

TreeAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/tree_attn.py
class TreeAttentionBackend(AttentionBackend):
    """Attention backend that wires together the tree-attention pieces.

    It exposes the implementation class, the metadata class and the
    metadata builder. It also reports the dtypes and head sizes it
    accepts, plus the KV-cache layout it expects.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the floating-point dtypes accepted by this backend."""
        return list((torch.float16, torch.bfloat16))

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the accepted head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` unless ``head_size`` is supported."""
        allowed_sizes = cls.get_supported_head_sizes()
        if head_size in allowed_sizes:
            return
        backend_label = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {backend_label}. "
            f"Supported head sizes are: {allowed_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        backend_name = "TREE_ATTN_VLLM_V1"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the attention implementation class."""
        impl_cls = TreeAttentionImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-batch metadata class."""
        metadata_cls = TreeAttentionMetadata
        return metadata_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return 2, num_blocks, block_size, num_kv_heads, head_size

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        builder_cls = TreeAttentionMetadataBuilder
        return builder_cls

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always report that cascade attention is not used."""
        del args, kwargs  # Accepted only for interface compatibility.
        return False

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[TreeAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/tree_attn.py
@staticmethod
def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
    """Return the metadata builder class used by the tree attention backend."""
    builder_cls = TreeAttentionMetadataBuilder
    return builder_cls

get_impl_cls staticmethod

get_impl_cls() -> type[TreeAttentionImpl]
Source code in vllm/v1/attention/backends/tree_attn.py
@staticmethod
def get_impl_cls() -> type["TreeAttentionImpl"]:
    """Return the attention implementation class of the tree backend."""
    impl_cls = TreeAttentionImpl
    return impl_cls

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/tree_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    """Return the shape of the combined key/value cache tensor.

    The leading dimension of size 2 is split into key and value caches
    by ``kv_cache.unbind(0)`` in ``forward``.

    Returns:
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

    Raises:
        ValueError: If ``block_size`` is not a multiple of 16.
    """
    if block_size % 16:
        raise ValueError("Block size must be a multiple of 16.")
    return 2, num_blocks, block_size, num_kv_heads, head_size

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/tree_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    """Return the per-batch metadata class of the tree backend."""
    metadata_cls = TreeAttentionMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/tree_attn.py
@staticmethod
def get_name() -> str:
    """Return the identifier string of the tree attention backend."""
    backend_name = "TREE_ATTN_VLLM_V1"
    return backend_name

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/tree_attn.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    """Return the floating-point dtypes accepted by the tree backend."""
    return list((torch.float16, torch.bfloat16))

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/tree_attn.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the accepted head sizes: every multiple of 32 from 32 to 256."""
    return list(range(32, 257, 32))

use_cascade_attention staticmethod

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/tree_attn.py
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    """Always report that cascade attention is not used.

    All arguments are ignored; they are accepted only so the signature
    matches callers that pass scheduling information.
    """
    del args, kwargs
    return False

validate_head_size classmethod

validate_head_size(head_size: int) -> None
Source code in vllm/v1/attention/backends/tree_attn.py
@classmethod
def validate_head_size(cls, head_size: int) -> None:
    """Check ``head_size`` against ``cls.get_supported_head_sizes()``.

    Raises:
        ValueError: If ``head_size`` is not supported. The message names
            the backend (class name without the ``Backend`` suffix) and
            points to FlexAttention as an alternative.
    """
    allowed_sizes = cls.get_supported_head_sizes()
    if head_size in allowed_sizes:
        return
    backend_label = cls.__name__.removesuffix("Backend")
    raise ValueError(
        f"Head size {head_size} is not supported by {backend_label}. "
        f"Supported head sizes are: {allowed_sizes}. "
        "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
        "FlexAttention backend which supports all head sizes.")

TreeAttentionImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/tree_attn.py
class TreeAttentionImpl(AttentionImpl):
    """Attention implementation for the tree attention backend.

    Decode-phase tokens are processed by ``unified_attention`` with the
    tree attention bias passed as ``qq_bias``. Prefill-phase tokens are
    processed by ``unified_attention`` with causal masking and no
    ``qq_bias``. Only decoder attention (``AttentionType.DECODER``) is
    supported.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Store the attention configuration and validate it.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size; validated by
                ``TreeAttentionBackend.validate_head_size``.
            scale: Softmax scale, stored as ``float``.
            num_kv_heads: Number of key/value heads;
                ``num_heads // num_kv_heads`` is stored as
                ``num_queries_per_kv``.
            alibi_slopes: Optional ALiBi slopes, converted to a float32
                tensor (no device is specified here).
            sliding_window: Optional window size. ``None`` is stored as
                ``(-1, -1)`` and ``w`` as ``(w - 1, 0)``.
            kv_cache_dtype: KV-cache dtype string forwarded to
                ``ops.reshape_and_cache_flash``.
            logits_soft_cap: Optional soft cap; ``None`` is stored as 0.
            attn_type: Must be ``AttentionType.DECODER``.
            kv_sharing_target_layer_name: When set, ``forward`` skips
                writing this layer's keys/values into the KV cache.

        Raises:
            ValueError: If ``head_size`` is not supported.
            NotImplementedError: If ``attn_type`` is not decoder attention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        # NOTE(review): presumably (left, right) window extents in the
        # kernel's convention, with (-1, -1) meaning "no window" — confirm.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Writes the new keys/values into the KV cache (unless this layer
        shares another layer's cache). It then runs ``unified_attention``
        separately over the prefill and decode portions of the batch,
        writing results into ``output`` in place.

        Args:
            layer: The attention layer; its ``_k_scale``/``_v_scale`` are
                used for the cache write and as K/V descales.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention; ``None`` marks a
                profiling run, and ``output`` is returned untouched.
            output: Preallocated output buffer (required).
            output_scale: Must be ``None``.
            output_block_scale: Must be ``None``.
        Returns:
            ``output``; per the original docstring,
            shape = [num_tokens, num_heads * head_size].

        Raises:
            NotImplementedError: If ``output_scale`` or
                ``output_block_scale`` is provided.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        # Decode tokens occupy rows [0, num_decode_tokens) of query/output;
        # prefill tokens follow in [num_decode_tokens, num_actual_tokens).
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        # (number of requests in the whole batch, num_kv_heads).
        # NOTE(review): the same full-batch shape is used for both the
        # prefill and decode calls below — presumably the kernel only reads
        # a single per-tensor scale; confirm against unified_attention.
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prefill requests: causal attention, no tree bias.
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decode (draft-tree) requests: the tree attention bias is
            # passed as qq_bias.
            unified_attention(
                q=query[:num_decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        return output

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = (-1, -1)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
) -> None
Source code in vllm/v1/attention/backends/tree_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
    """Record the attention configuration for TreeAttentionImpl.

    ``None`` for ``alibi_slopes`` stays ``None``; otherwise the slopes
    become a float32 tensor. ``None`` for ``logits_soft_cap`` becomes 0
    (no soft cap). ``None`` for ``sliding_window`` becomes ``(-1, -1)``;
    a window ``w`` becomes ``(w - 1, 0)``.

    Raises:
        ValueError: If ``head_size`` is rejected by
            ``TreeAttentionBackend.validate_head_size``.
        NotImplementedError: If ``attn_type`` is not
            ``AttentionType.DECODER``.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.num_queries_per_kv = num_heads // num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
    self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.float32))
    # A soft cap of 0 disables soft-capping.
    self.logits_soft_cap = (0 if logits_soft_cap is None else
                            logits_soft_cap)
    self.sliding_window = ((-1, -1) if sliding_window is None else
                           (sliding_window - 1, 0))

    TreeAttentionBackend.validate_head_size(head_size)

    if attn_type == AttentionType.DECODER:
        return
    raise NotImplementedError("Encoder self-attention and "
                              "encoder/decoder cross-attention "
                              "are not implemented for "
                              "TreeAttentionImpl.")

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: TreeAttentionMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
    output_block_scale: Optional[Tensor] = None,
) -> Tensor

Forward pass with TreeAttention.

Parameters:

- query (Tensor, required): shape = [num_tokens, num_heads, head_size]
- key (Tensor, required): shape = [num_tokens, num_kv_heads, head_size]
- value (Tensor, required): shape = [num_tokens, num_kv_heads, head_size]
- attn_metadata (TreeAttentionMetadata, required): Metadata for attention.

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm/v1/attention/backends/tree_attn.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: TreeAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with TreeAttention.

    Writes the new keys/values into the KV cache (unless this layer
    shares another layer's cache). It then runs ``unified_attention``
    separately over the prefill and decode portions of the batch, writing
    results into ``output`` in place. Only the decode call receives the
    tree attention bias (as ``qq_bias``).

    Args:
        layer: The attention layer; its ``_k_scale``/``_v_scale`` are used
            for the cache write and as K/V descales.
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
            head_size]
        attn_metadata: Metadata for attention; ``None`` marks a profiling
            run, and ``output`` is returned untouched.
        output: Preallocated output buffer (required).
        output_scale: Must be ``None``.
        output_block_scale: Must be ``None``.
    Returns:
        ``output``; per the original docstring,
        shape = [num_tokens, num_heads * head_size].

    Raises:
        NotImplementedError: If ``output_scale`` or ``output_block_scale``
            is provided.
    """
    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for TreeAttentionImpl")

    if attn_metadata is None:
        # Profiling run.
        return output

    # Cache the input KVs.
    key_cache, value_cache = kv_cache.unbind(0)
    if self.kv_sharing_target_layer_name is None:
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    # Decode tokens occupy rows [0, num_decode_tokens) of query/output;
    # prefill tokens follow in [num_decode_tokens, num_actual_tokens).
    num_actual_tokens = attn_metadata.num_actual_tokens
    num_decode_tokens = attn_metadata.num_decode_tokens
    # (number of requests in the whole batch, num_kv_heads).
    # NOTE(review): the same full-batch shape is used for both the prefill
    # and decode calls below — presumably the kernel only reads a single
    # per-tensor scale; confirm against unified_attention.
    descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                     key.shape[1])
    if prefill_meta := attn_metadata.prefill_metadata:
        # Prefill requests: causal attention, no tree bias.
        unified_attention(
            q=query[num_decode_tokens:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[num_decode_tokens:num_actual_tokens],
            cu_seqlens_q=prefill_meta.query_start_loc,
            max_seqlen_q=prefill_meta.max_query_len,
            seqused_k=prefill_meta.seq_lens,
            max_seqlen_k=prefill_meta.max_seq_len,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=prefill_meta.block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )

    if decode_meta := attn_metadata.decode_metadata:
        # Decode (draft-tree) requests: the tree attention bias is passed
        # as qq_bias.
        unified_attention(
            q=query[:num_decode_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_decode_tokens],
            cu_seqlens_q=decode_meta.query_start_loc,
            max_seqlen_q=decode_meta.max_query_len,
            seqused_k=decode_meta.seq_lens,
            max_seqlen_k=decode_meta.max_seq_len,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            qq_bias=decode_meta.tree_attn_bias,
            window_size=self.sliding_window,
            block_table=decode_meta.block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )
    return output

TreeAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/tree_attn.py
@dataclass
class TreeAttentionMetadata:
    """Per-batch attention metadata for the tree attention backend.

    The per-request tensors (``query_start_loc``, ``seq_lens``,
    ``block_table``) and the per-token ``slot_mapping`` are expected to list
    decode requests first, then prefill requests. The ``prefill_metadata``
    and ``decode_metadata`` views slice them at ``num_decodes`` /
    ``num_decode_tokens``, and each view is built once and then cached.
    """
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Tree attention bias; only propagated into the decode view.
    tree_attn_bias: Optional[torch.Tensor] = None

    # Lazily built prefill/decode views.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """View over the prefill requests, or ``None`` if there are none."""
        if not self.num_prefills:
            return None
        if self._cached_prefill_metadata is None:
            first_req = self.num_decodes
            starts = self.query_start_loc[first_req:]
            kv_lens = self.seq_lens[first_req:]
            self._cached_prefill_metadata = TreeAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(torch.diff(starts).max().item()),
                # Rebase offsets so the first prefill request starts at 0.
                query_start_loc=starts - starts[0],
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[first_req:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """View over the decode requests, or ``None`` if there are none."""
        if not self.num_decode_tokens:
            return None
        if self._cached_decode_metadata is None:
            num_reqs = self.num_decodes
            starts = self.query_start_loc[:num_reqs + 1]
            kv_lens = self.seq_lens[:num_reqs]
            self._cached_decode_metadata = TreeAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(torch.diff(starts).max().item()),
                query_start_loc=starts,
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[:num_reqs],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                tree_attn_bias=self.tree_attn_bias,
            )
        return self._cached_decode_metadata

_cached_decode_metadata class-attribute instance-attribute

_cached_decode_metadata: Optional[TreeAttentionMetadata] = (
    None
)

_cached_prefill_metadata class-attribute instance-attribute

_cached_prefill_metadata: Optional[
    TreeAttentionMetadata
] = None

block_table instance-attribute

block_table: Tensor

decode_metadata property

decode_metadata: Optional[TreeAttentionMetadata]

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens class-attribute instance-attribute

num_decode_tokens: int = 0

num_decodes class-attribute instance-attribute

num_decodes: int = 0

num_prefill_tokens class-attribute instance-attribute

num_prefill_tokens: int = 0

num_prefills class-attribute instance-attribute

num_prefills: int = 0

prefill_metadata property

prefill_metadata: Optional[TreeAttentionMetadata]

query_start_loc instance-attribute

query_start_loc: Tensor

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

tree_attn_bias class-attribute instance-attribute

tree_attn_bias: Optional[Tensor] = None

__init__

__init__(
    num_actual_tokens: int,
    max_query_len: int,
    query_start_loc: Tensor,
    max_seq_len: int,
    seq_lens: Tensor,
    block_table: Tensor,
    slot_mapping: Tensor,
    num_prefill_tokens: int = 0,
    num_decode_tokens: int = 0,
    num_prefills: int = 0,
    num_decodes: int = 0,
    tree_attn_bias: Optional[Tensor] = None,
    _cached_prefill_metadata: Optional[
        TreeAttentionMetadata
    ] = None,
    _cached_decode_metadata: Optional[
        TreeAttentionMetadata
    ] = None,
) -> None

TreeAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[TreeAttentionMetadata]

Source code in vllm/v1/attention/backends/tree_attn.py
class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds :class:`TreeAttentionMetadata` for the tree attention backend.

    The tree attention bias is built once from the speculative token tree.
    Its first dimension (number of tree nodes plus the root) is used as the
    decode threshold when reordering and splitting batches into decodes and
    prefills.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Parse the speculative token tree and build its attention bias.

        Args:
            kv_cache_spec: KV-cache spec; its ``block_size`` is recorded.
            layer_names: Not used by this builder.
            vllm_config: Engine config. If a speculative config with a
                ``speculative_token_tree`` is present, that value is parsed
                with ``ast.literal_eval``. Otherwise the tree defaults to
                ``[(0,)]``, a single child of the root.
            device: Device on which the bias tensor is allocated.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec_config.speculative_token_tree
                           if spec_config else None)
        # NOTE(review): the parsed choices are assumed to be sorted by depth,
        # as _get_depth_counts/_prepare_tree_attn_bias expect — confirm
        # against the config that produces speculative_token_tree.
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests ahead of prefills in ``input_batch``.

        Requests are classified with a decode threshold equal to
        ``self.tree_attn_bias.shape[0]``. Returns whatever
        ``reorder_batch_to_split_decodes_and_prefills`` returns.
        """
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Build metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility but are not used here. The decode/prefill split uses
        ``self.tree_attn_bias.shape[0]`` as the decode threshold, and the
        current ``self.tree_attn_bias`` is attached to the result.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        num_actual_tokens = common_attn_metadata.num_actual_tokens
        q_start_loc = common_attn_metadata.query_start_loc
        max_query_len = common_attn_metadata.max_query_len
        kv_seqlens = common_attn_metadata.seq_lens
        max_seq_len = common_attn_metadata.max_seq_len
        block_table = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        return TreeAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=max_query_len,
            query_start_loc=q_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=kv_seqlens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step of the speculative tree.

        At ``draft_index == 0`` an empty bias is installed, so ``build``
        uses a decode threshold of 0 (prefill for drafting at the root
        level). Otherwise the bias is sliced to rows/columns
        ``[1, 1 + max_query_len)``, which excludes the root.

        The original bias is always restored afterwards, even if ``build``
        raises, so the builder never stays stuck with a temporary bias.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias
        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting. Exclude
                # the root level.
                start, end = 1, 1 + common_attn_metadata.max_query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end].contiguous()

            # Build attention metadata with the temporary bias.
            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Reset the tree attention bias to the original value.
            self.tree_attn_bias = orig_tree_attn_bias

block_size instance-attribute

block_size = block_size

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

tree_attn_bias instance-attribute

tree_attn_bias = _prepare_tree_attn_bias(
    tree_choices, depth_counts, dtype=float32, device=device
)

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/tree_attn.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Parse the speculative token tree and build its attention bias.

    Args:
        kv_cache_spec: KV-cache spec; its ``block_size`` is recorded.
        layer_names: Not used by this builder.
        vllm_config: Engine config. If a speculative config with a
            ``speculative_token_tree`` is present, that value is parsed with
            ``ast.literal_eval``. Otherwise the tree defaults to ``[(0,)]``,
            a single child of the root.
        device: Device on which the bias tensor is allocated.
    """
    self.kv_cache_spec = kv_cache_spec
    self.block_size = kv_cache_spec.block_size

    spec_config = vllm_config.speculative_config
    spec_token_tree = (spec_config.speculative_token_tree
                       if spec_config else None)
    # NOTE(review): the parsed choices are assumed to be sorted by depth, as
    # _get_depth_counts/_prepare_tree_attn_bias expect — confirm against the
    # config that produces speculative_token_tree.
    tree_choices: list[tuple[int, ...]] = (
        ast.literal_eval(spec_token_tree)
        if spec_token_tree is not None else [(0, )])
    # Construct the tree attention bias.
    depth_counts = _get_depth_counts(tree_choices)
    self.tree_attn_bias = _prepare_tree_attn_bias(
        tree_choices,
        depth_counts,
        dtype=torch.float32,
        device=device,
    )

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> TreeAttentionMetadata
Source code in vllm/v1/attention/backends/tree_attn.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> TreeAttentionMetadata:
    """Assemble TreeAttentionMetadata for one batch.

    ``common_prefix_len`` and ``fast_build`` are accepted for interface
    compatibility but are not read. Requests are split into decodes and
    prefills with ``self.tree_attn_bias.shape[0]`` as the threshold, and the
    current ``self.tree_attn_bias`` is attached to the result.
    """
    common = common_attn_metadata
    split = split_decodes_and_prefills(
        common, decode_threshold=self.tree_attn_bias.shape[0])
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    return TreeAttentionMetadata(
        num_actual_tokens=common.num_actual_tokens,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_decodes=num_decodes,
        max_query_len=common.max_query_len,
        query_start_loc=common.query_start_loc,
        max_seq_len=common.max_seq_len,
        seq_lens=common.seq_lens,
        block_table=common.block_table_tensor,
        slot_mapping=common.slot_mapping,
        tree_attn_bias=self.tree_attn_bias,
    )

build_for_drafting

build_for_drafting(
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> TreeAttentionMetadata
Source code in vllm/v1/attention/backends/tree_attn.py
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> TreeAttentionMetadata:
    """Build metadata for one drafting step of the speculative tree.

    At ``draft_index == 0`` an empty bias is installed, so ``build`` uses a
    decode threshold of 0 (prefill for drafting at the root level).
    Otherwise the bias is sliced to rows/columns ``[1, 1 + max_query_len)``,
    which excludes the root.

    The original bias is always restored afterwards, even if ``build``
    raises, so the builder never stays stuck with a temporary bias.
    """
    # Cache the original tree attention bias.
    orig_tree_attn_bias = self.tree_attn_bias
    try:
        if draft_index == 0:
            # Use prefill for drafting at the root level.
            self.tree_attn_bias = torch.empty(0)
        else:
            # Slice the tree attention bias for drafting. Exclude
            # the root level.
            start, end = 1, 1 + common_attn_metadata.max_query_len
            self.tree_attn_bias = orig_tree_attn_bias[start:end,
                                                      start:end].contiguous()

        # Build attention metadata with the temporary bias.
        return self.build(0, common_attn_metadata, fast_build=True)
    finally:
        # Reset the tree attention bias to the original value.
        self.tree_attn_bias = orig_tree_attn_bias

reorder_batch

reorder_batch(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
) -> bool
Source code in vllm/v1/attention/backends/tree_attn.py
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    """Reorder ``input_batch`` so decode requests precede prefills.

    The decode threshold is the number of rows in the tree attention bias.
    The result of ``reorder_batch_to_split_decodes_and_prefills`` is
    returned unchanged.
    """
    threshold = self.tree_attn_bias.shape[0]
    return reorder_batch_to_split_decodes_and_prefills(
        input_batch, scheduler_output, decode_threshold=threshold)

_get_depth_counts

_get_depth_counts(
    sorted_tree_choices: list[tuple[int, ...]],
) -> list[int]
Source code in vllm/v1/attention/backends/tree_attn.py
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
    # Count the number of choices at each depth of the tree.
    depth_counts = []
    prev_depth = 0
    for path in sorted_tree_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        prev_depth = depth
    return depth_counts

_prepare_tree_attn_bias

_prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: Optional[dtype],
    device: Optional[device],
) -> Tensor
Source code in vllm/v1/attention/backends/tree_attn.py
def _prepare_tree_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    depth_counts: list[int],
    dtype: Optional[torch.dtype],
    device: Optional[torch.device],
) -> torch.Tensor:
    # +1 comes from the additional root node.
    tree_len = len(sorted_tree_choices) + 1
    tree_attn_mask = torch.full((tree_len, tree_len),
                                -torch.inf,
                                device=device,
                                dtype=dtype)

    # Set diagonal to all zeros. Each token should
    # attend to itself.
    mask_val = 0
    for i in range(tree_len):
        tree_attn_mask[i, i] = mask_val

    # Set root to all zeros. All tokens attend to it.
    tree_attn_mask[:, 0] = mask_val

    # Set all ancestors to zeros.
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            # Retrieve ancestor position.
            if len(cur_tree_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(
                    sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
            tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
        start += depth_counts[i]
    return tree_attn_mask