Skip to content

vllm.forward_context

_forward_context module-attribute

_forward_context: Optional[ForwardContext] = None

batchsize_forward_time module-attribute

batchsize_forward_time: defaultdict = defaultdict(list)

batchsize_logging_interval module-attribute

batchsize_logging_interval: float = (
    VLLM_LOG_BATCHSIZE_INTERVAL
)

forward_start_time module-attribute

forward_start_time: float = 0

last_logging_time module-attribute

last_logging_time: float = 0

logger module-attribute

logger = init_logger(__name__)

track_batchsize module-attribute

track_batchsize: bool = VLLM_LOG_BATCHSIZE_INTERVAL >= 0

BatchDescriptor

Bases: NamedTuple

Batch descriptor for cudagraph dispatching. It should contain as few fields as possible while still describing the padded batch for cudagraph properly and uniquely.

Source code in vllm/forward_context.py
class BatchDescriptor(NamedTuple):
    """
    Key used to dispatch a batch to a cudagraph. It should contain as few
    fields as possible while still describing the padded batch for
    cudagraph properly and uniquely.
    """
    num_tokens: int
    uniform_decode: bool = False
    """
    False may also be used for a uniform decode batch, so that it is
    dispatched to the cudagraph that supports non-uniform batches.
    """

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a BatchDescriptor with the same `num_tokens` and
        `uniform_decode` set to False.
        """
        # uniform_decode defaults to False, so only num_tokens is passed on.
        return BatchDescriptor(num_tokens=self.num_tokens)

non_uniform property

non_uniform: BatchDescriptor

Return a non-uniform version of current batch descriptor.

num_tokens instance-attribute

num_tokens: int

uniform_decode class-attribute instance-attribute

uniform_decode: bool = False

False can also be used for a uniform decode batch, to dispatch it to the cudagraph that supports non-uniform batches.

DPMetadata dataclass

Source code in vllm/forward_context.py
@dataclass
class DPMetadata:
    """Token-count metadata shared across data-parallel (DP) ranks for one
    forward pass."""
    # Largest per-rank token count across DP ranks (the result of
    # `torch.max` over the per-rank counts in `make`).
    max_tokens_across_dp_cpu: torch.Tensor
    # Inclusive cumulative sum of the per-rank token counts: entry i is the
    # total number of tokens on ranks 0..i.
    cu_tokens_across_dp_cpu: torch.Tensor
    # Per-rank token counts for the current chunk; set only while inside
    # `chunked_sizes`, None otherwise.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.

        Each rank writes its own count at index `dp_rank` of a zero-filled
        int32 vector; an all-reduce over the DP group's CPU group (default
        reduce op: sum) then fills in every rank's entry.
        """
        counts = torch.zeros(dp_size, dtype=torch.int32, device="cpu")
        counts[dp_rank] = num_tokens
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(counts, group=get_dp_group().cpu_group)
        return counts

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """
        Build the DPMetadata for the current forward pass.

        The local batch size comes from v0-style attention metadata
        (`num_prefill_tokens + num_decode_tokens`) when present, otherwise
        from `num_tokens`. If `num_tokens_across_dp` is None, the per-rank
        counts are gathered collectively via `num_tokens_across_dp`; if it is
        given, its entry for this rank must equal the local batch size.

        Requires `parallel_config.data_parallel_size > 1`.
        """
        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        is_v0_metadata = (attn_metadata is not None
                          and hasattr(attn_metadata, "num_prefill_tokens"))
        if is_v0_metadata:
            # v0 attention backends report prefill and decode tokens apart.
            batchsize = (attn_metadata.num_prefill_tokens +
                         attn_metadata.num_decode_tokens)
        else:
            # v1 attention backends, or no attn_metadata at all.
            batchsize = num_tokens

        if num_tokens_across_dp is None:
            # Not supplied by the caller: gather it with an all-reduce.
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        else:
            # A caller-supplied vector must agree with this rank's count.
            assert num_tokens_across_dp[dp_rank] == batchsize

        return DPMetadata(torch.max(num_tokens_across_dp),
                          torch.cumsum(num_tokens_across_dp, dim=0))

    @contextmanager
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local token
        sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
        to determine the chunk-wise split.

        `self.local_sizes` is only valid inside the context.

        Args:
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        cu_sizes = self.cu_tokens_across_dp_cpu
        # Undo the cumulative sum (prepending 0) to recover each rank's token
        # count as plain Python ints.
        num_tokens_across_dp_cpu = torch.diff(
            cu_sizes, prepend=cu_sizes.new_zeros(1)).tolist()
        self.local_sizes = _compute_chunked_local_num_tokens(
            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return the per-rank chunk sizes set by `chunked_sizes` (None when
        called outside that context)."""
        return self.local_sizes

cu_tokens_across_dp_cpu instance-attribute

cu_tokens_across_dp_cpu: Tensor

local_sizes class-attribute instance-attribute

local_sizes: Optional[list[int]] = None

max_tokens_across_dp_cpu instance-attribute

max_tokens_across_dp_cpu: Tensor

__init__

__init__(
    max_tokens_across_dp_cpu: Tensor,
    cu_tokens_across_dp_cpu: Tensor,
    local_sizes: Optional[list[int]] = None,
) -> None

chunked_sizes

chunked_sizes(max_chunk_size_per_rank: int, chunk_idx: int)

Context manager to compute and temporarily set the per-rank local token sizes for a specific chunk during chunked forward execution.

This is necessary to ensure each DP (data parallel) rank processes its designated portion of tokens in lockstep with others, even when the token counts are uneven or some ranks have completed their input early.

For chunked execution, we break up the total tokens on each rank into multiple chunks (of at most max_chunk_size_per_rank), and for a given chunk_idx, this context manager sets self.local_sizes to the number of tokens to process in that chunk on each rank.

It uses cumulative sizes (cu_tokens_across_dp_cpu) to derive the number of tokens per rank, and calls _compute_chunked_local_num_tokens to determine the chunk-wise split.

self.local_sizes is only valid inside the context.

Parameters:

Name Type Description Default
max_chunk_size_per_rank int

The max number of tokens each rank is allowed to process in this chunk.

required
chunk_idx int

The index of the chunk to compute sizes for.

required
Source code in vllm/forward_context.py
@contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
    """
    Context manager to compute and temporarily set the per-rank local token
    sizes for a specific chunk during chunked forward execution.

    This is necessary to ensure each DP (data parallel) rank processes its
    designated portion of tokens in lockstep with others, even when the
    token counts are uneven or some ranks have completed their input early.

    For chunked execution, we break up the total tokens on each rank into
    multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
    `chunk_idx`, this context manager sets `self.local_sizes` to the number
    of tokens to process in that chunk on each rank.

    It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
    number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
    to determine the chunk-wise split.

    `self.local_sizes` is only valid inside the context.

    Args:
        max_chunk_size_per_rank: The max number of tokens each rank is
                                 allowed to process in this chunk.
        chunk_idx: The index of the chunk to compute sizes for.
    """
    cu_sizes = self.cu_tokens_across_dp_cpu
    # Undo the cumulative sum (prepending 0) to recover each rank's token
    # count as plain Python ints.
    num_tokens_across_dp_cpu = torch.diff(
        cu_sizes, prepend=cu_sizes.new_zeros(1)).tolist()
    self.local_sizes = _compute_chunked_local_num_tokens(
        num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
    try:
        yield self.local_sizes
    finally:
        self.local_sizes = None

get_chunk_sizes_across_dp_rank

get_chunk_sizes_across_dp_rank() -> Optional[list[int]]
Source code in vllm/forward_context.py
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
    """Return the per-rank chunk sizes set by `chunked_sizes` (None when
    called outside that context)."""
    chunk_sizes = self.local_sizes
    return chunk_sizes

make staticmethod

make(
    parallel_config: ParallelConfig,
    attn_metadata: Any,
    num_tokens: int,
    num_tokens_across_dp: Optional[Tensor] = None,
) -> DPMetadata
Source code in vllm/forward_context.py
@staticmethod
def make(
        parallel_config: ParallelConfig,
        attn_metadata: Any,
        num_tokens: int,
        num_tokens_across_dp: Optional[torch.Tensor] = None
) -> "DPMetadata":
    """
    Build the DPMetadata for the current forward pass.

    The local batch size comes from v0-style attention metadata
    (`num_prefill_tokens + num_decode_tokens`) when present, otherwise
    from `num_tokens`. If `num_tokens_across_dp` is None, the per-rank
    counts are gathered collectively via `num_tokens_across_dp`; if it is
    given, its entry for this rank must equal the local batch size.

    Requires `parallel_config.data_parallel_size > 1`.
    """
    assert parallel_config.data_parallel_size > 1
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank

    is_v0_metadata = (attn_metadata is not None
                      and hasattr(attn_metadata, "num_prefill_tokens"))
    if is_v0_metadata:
        # v0 attention backends report prefill and decode tokens apart.
        batchsize = (attn_metadata.num_prefill_tokens +
                     attn_metadata.num_decode_tokens)
    else:
        # v1 attention backends, or no attn_metadata at all.
        batchsize = num_tokens

    if num_tokens_across_dp is None:
        # Not supplied by the caller: gather it with an all-reduce.
        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            batchsize, dp_size, dp_rank)
    else:
        # A caller-supplied vector must agree with this rank's count.
        assert num_tokens_across_dp[dp_rank] == batchsize

    return DPMetadata(torch.max(num_tokens_across_dp),
                      torch.cumsum(num_tokens_across_dp, dim=0))

num_tokens_across_dp staticmethod

num_tokens_across_dp(
    num_tokens: int, dp_size: int, dp_rank: int
) -> Tensor

Gather the num_tokens across all DP ranks and return results in a CPU tensor of size dp_size.

Source code in vllm/forward_context.py
@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
                         dp_rank: int) -> torch.Tensor:
    """
    Gather the num_tokens across all DP ranks and return results in a
    CPU tensor of size dp_size.

    Each rank writes its own count at index `dp_rank` of a zero-filled
    int32 vector; an all-reduce over the DP group's CPU group (default
    reduce op: sum) then fills in every rank's entry.
    """
    counts = torch.zeros(dp_size, dtype=torch.int32, device="cpu")
    counts[dp_rank] = num_tokens
    from vllm.distributed.parallel_state import get_dp_group
    dist.all_reduce(counts, group=get_dp_group().cpu_group)
    return counts

ForwardContext dataclass

Source code in vllm/forward_context.py
@dataclass
class ForwardContext:
    """State for a single model forward pass.

    An instance is installed by `set_forward_context` and retrieved with
    `get_forward_context`.
    """
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    """
    Attention metadata, set dynamically for each forward pass.
    Type AttentionMetadata for v0; type dict[str, AttentionMetadata] for v1,
    mapping the layer_name of each attention layer to its attention metadata.
    """
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    def __post_init__(self):
        # Only NONE, PIECEWISE and FULL are accepted as runtime modes.
        assert self.cudagraph_runtime_mode in (
            CUDAGraphMode.NONE,
            CUDAGraphMode.PIECEWISE,
            CUDAGraphMode.FULL,
        ), f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

attn_metadata instance-attribute

batch_descriptor class-attribute instance-attribute

batch_descriptor: Optional[BatchDescriptor] = None

cudagraph_runtime_mode class-attribute instance-attribute

cudagraph_runtime_mode: CUDAGraphMode = NONE

dp_metadata class-attribute instance-attribute

dp_metadata: Optional[DPMetadata] = None

no_compile_layers instance-attribute

no_compile_layers: dict[str, Any]

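Copied from vllm_config.compilation_config.static_forward_context. (The following text describes attn_metadata, not this attribute: it has type AttentionMetadata for v0 and Dict[str, AttentionMetadata] for v1, mapping the layer_name of each attention layer to its attention metadata, and it is set dynamically for each forward pass.)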

virtual_engine instance-attribute

virtual_engine: int

__init__

__init__(
    no_compile_layers: dict[str, Any],
    attn_metadata: Union[
        AttentionMetadata, dict[str, AttentionMetadata]
    ],
    virtual_engine: int,
    dp_metadata: Optional[DPMetadata] = None,
    cudagraph_runtime_mode: CUDAGraphMode = NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
) -> None

__post_init__

__post_init__()
Source code in vllm/forward_context.py
def __post_init__(self):
    # Only NONE, PIECEWISE and FULL are accepted as runtime modes.
    assert self.cudagraph_runtime_mode in (
        CUDAGraphMode.NONE,
        CUDAGraphMode.PIECEWISE,
        CUDAGraphMode.FULL,
    ), f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"

_compute_chunked_local_num_tokens

_compute_chunked_local_num_tokens(
    num_tokens_across_dp_cpu: list[int],
    max_num_tokens: int,
    chunk_idx: int,
) -> list[int]
Source code in vllm/forward_context.py
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    dp_size = len(num_tokens_across_dp_cpu)

    local_size = [-1] * dp_size
    for i in range(dp_size):
        dp_tokens = num_tokens_across_dp_cpu[i]
        local_size[i] = min(max_num_tokens,
                            dp_tokens - (max_num_tokens * chunk_idx))
        if local_size[i] <= 0:
            local_size[i] = 1  # ensure lockstep even if done
    return local_size

get_forward_context

get_forward_context() -> ForwardContext

Get the current forward context.

Source code in vllm/forward_context.py
def get_forward_context() -> ForwardContext:
    """Return the forward context installed by `set_forward_context`.

    Raises:
        AssertionError: if no forward context is currently set (when
            assertions are enabled).
    """
    current = _forward_context
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return current

set_forward_context

set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[Tensor] = None,
    cudagraph_runtime_mode: CUDAGraphMode = NONE,
    batch_descriptor: Optional[BatchDescriptor] = None,
)

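A context manager that sets the current forward context (attention metadata, DP metadata, cudagraph runtime mode, etc.) for the duration of a model forward pass and restores the previous one afterwards. It is also where logic common to every model forward pass can be injected.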

Source code in vllm/forward_context.py
def _record_batchsize_forward_time(attn_metadata: Any,
                                   num_tokens: Optional[int]) -> None:
    """Record the duration of the forward pass that just finished, keyed by
    batch size, and periodically log the median time per batch size.

    The duration is measured from the module-level `forward_start_time`
    set by `set_forward_context`. Stats are logged only when more than
    `batchsize_logging_interval` seconds have passed since the last log.
    """
    global last_logging_time
    if hasattr(attn_metadata, "num_prefill_tokens"):
        # for v0 attention backends
        batchsize = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends
        batchsize = num_tokens
    # we use synchronous scheduling right now,
    # adding a sync point here should not affect
    # scheduling of the next batch
    from vllm.platforms import current_platform
    synchronize = current_platform.synchronize
    if synchronize is not None:
        synchronize()
    now = time.perf_counter()
    # time measurement is in milliseconds
    batchsize_forward_time[batchsize].append(
        (now - forward_start_time) * 1000)
    if now - last_logging_time > batchsize_logging_interval:
        last_logging_time = now
        forward_stats = []
        for bs, times in batchsize_forward_time.items():
            if len(times) <= 1:
                # can be cudagraph / profiling run
                continue
            median = torch.quantile(torch.tensor(times), q=0.5).item()
            forward_stats.append((bs, len(times), round(median, 2)))
        # Most frequently seen batch sizes first.
        forward_stats.sort(key=lambda x: x[1], reverse=True)
        if forward_stats:
            logger.info(("Batchsize forward time stats "
                         "(batchsize, count, median_time(ms)): %s"),
                        forward_stats)


@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None):
    """A context manager that installs the forward context for one model
    forward pass and restores the previous context on exit.

    Inside the `with` block, `get_forward_context()` returns a
    `ForwardContext` holding the attention metadata, DP metadata (built
    when data parallelism is enabled), cudagraph runtime mode and batch
    descriptor. This is also where logic common to every model forward
    pass is injected, e.g. optional per-batch-size timing stats.

    Args:
        attn_metadata: Attention metadata for this pass, or None.
        vllm_config: Supplies the parallel config and the static forward
            context (`no_compile_layers`).
        virtual_engine: Virtual engine index for this pass.
        num_tokens: Token count for this pass; used for DP metadata and
            batch-size tracking when the metadata is not v0-style.
        num_tokens_across_dp: Optional precomputed per-DP-rank token counts,
            forwarded to `DPMetadata.make`.
        cudagraph_runtime_mode: Cudagraph mode to use at runtime.
        batch_descriptor: Descriptor of the padded batch for cudagraph
            dispatching.
    """
    global forward_start_time, _forward_context
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()

    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    prev_context = _forward_context
    _forward_context = ForwardContext(
        no_compile_layers=vllm_config.compilation_config.
        static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
    )

    try:
        yield
    finally:
        # Restore the previous context before anything else, so that an
        # exception from the optional stats code below (e.g. raised by
        # synchronize()) cannot leave this pass's context installed.
        _forward_context = prev_context
        if need_to_track_batchsize:
            _record_batchsize_forward_time(attn_metadata, num_tokens)