vllm.v1.cudagraph_dispatcher

logger module-attribute

logger = init_logger(__name__)

CudagraphDispatcher

Runtime cudagraph dispatcher to dispatch keys for multiple sets of cudagraphs.

The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one for FULL cudagraph runtime mode. The keys are initialized depending on attention support and on the cudagraph mode set in CompilationConfig. The keys stored in the dispatcher are the only source of truth for valid cudagraphs that can be dispatched at runtime.

At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) based on the input key. After dispatching (communicated via forward context), the cudagraph wrappers will trust the dispatch key to either capture or replay (if the mode matches), or pass through to the underlying runnable without cudagraph (if the mode does not match or the mode is NONE).
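
For intuition, the dispatch priority described above (exact FULL key, then the relaxed non-uniform key against FULL, then PIECEWISE, then no cudagraph) can be sketched with stand-in types. This is a minimal, self-contained illustration, not the vLLM classes themselves:

# Minimal sketch of the dispatch priority; Mode and Descriptor are stand-ins
# for CUDAGraphMode and BatchDescriptor, and the key sets are made up.
from dataclasses import dataclass, replace
from enum import Enum

class Mode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

@dataclass(frozen=True)
class Descriptor:
    num_tokens: int
    uniform_decode: bool = False

keys = {
    Mode.PIECEWISE: {Descriptor(8), Descriptor(16)},
    Mode.FULL: {Descriptor(8, uniform_decode=True)},
}

def dispatch(desc: Descriptor):
    if desc in keys[Mode.FULL]:
        return Mode.FULL, desc                      # exact full-graph match
    relaxed = replace(desc, uniform_decode=False)   # "non_uniform" fallback
    if relaxed in keys[Mode.FULL]:
        return Mode.FULL, relaxed
    if relaxed in keys[Mode.PIECEWISE]:
        return Mode.PIECEWISE, relaxed
    return Mode.NONE, None                          # no captured graph fits

print(dispatch(Descriptor(8, uniform_decode=True)))   # FULL, exact key
print(dispatch(Descriptor(16, uniform_decode=True)))  # PIECEWISE, relaxed key
print(dispatch(Descriptor(32)))                       # NONE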

Source code in vllm/v1/cudagraph_dispatcher.py
class CudagraphDispatcher:
    """
    Runtime cudagraph dispatcher to dispatch keys for multiple sets of cudagraphs.

    The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one
    for FULL cudagraph runtime mode. The keys are initialized depending on 
    attention support and what cudagraph mode is set in CompilationConfig. The 
    keys stored in dispatcher are the only source of truth for valid
    cudagraphs that can be dispatched at runtime.

    At runtime, the dispatch method generates the runtime cudagraph mode (FULL, 
    PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor)
    based on the input key. After dispatching (communicated via forward context),
    the cudagraph wrappers will trust the dispatch key to do either capturing
    or replaying (if mode matched), or pass through to the underlying runnable 
    without cudagraph (if mode does not match or mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # Dict to store valid cudagraph dispatching keys.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            CUDAGraphMode.PIECEWISE: set(),
            CUDAGraphMode.FULL: set(),
        }

        assert not self.cudagraph_mode.requires_piecewise_compilation() or \
            (self.compilation_config.level == CompilationLevel.PIECEWISE and
             self.compilation_config.splitting_ops_contain_attention()), \
            "Compilation level should be CompilationLevel.PIECEWISE when "\
            "cudagraph_mode piecewise cudagraphs is used, "\
            f"cudagraph_mode={self.cudagraph_mode}, "\
            f"compilation_level={self.compilation_config.level}, "\
            f"splitting_ops={self.compilation_config.splitting_ops}"

        self.keys_initialized = False

    def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                          batch_descriptor: BatchDescriptor):
        assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
            f"Invalid cudagraph runtime mode: {runtime_mode}"
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                  uniform_decode_query_len: int):
        # This should be called only after attention backend is initialized.

        # Note: we create all possible valid keys for cudagraphs, but do not
        # guarantee that every key will be used. For example, we always create
        # keys for piecewise cudagraphs under piecewise compilation, but if the
        # attention backend supports a unified routine, we may never trigger
        # capturing/replaying of the piecewise cudagraphs, depending on
        # CompilationConfig.cudagraph_mode. In addition, if lazy capturing is
        # allowed in a future PR, some keys may never be triggered.
        if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
            for bs in self.compilation_config.cudagraph_capture_sizes:
                self.add_cudagraph_key(
                    cudagraph_mode.mixed_mode(),
                    BatchDescriptor(num_tokens=bs, uniform_decode=False))

        # If the decode cudagraph mode is FULL and we don't already have
        # mixed-mode full cudagraphs, add them here.
        if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \
            and cudagraph_mode.separate_routine():
            max_num_tokens = uniform_decode_query_len * \
                self.vllm_config.scheduler_config.max_num_seqs
            cudagraph_capture_sizes_for_decode = [
                x for x in self.compilation_config.cudagraph_capture_sizes
                if x <= max_num_tokens and x >= uniform_decode_query_len
            ]
            for bs in cudagraph_capture_sizes_for_decode:
                self.add_cudagraph_key(
                    CUDAGraphMode.FULL,
                    BatchDescriptor(num_tokens=bs, uniform_decode=True))
        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a cudagraph mode.
        A new batch descriptor is returned as we might dispatch a uniform batch 
        to a graph that supports a more general batch (uniform to non-uniform).
        """
        # if not initialized, just skip dispatching.
        if not self.keys_initialized:
            logger.warning_once("cudagraph dispatching keys are not "
                                "initialized. No cudagraph will be used.")
            return CUDAGraphMode.NONE, None

        # check if key exists for full cudagraph
        if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, batch_descriptor

        # otherwise, check if non-uniform key exists
        non_uniform_key = batch_descriptor.non_uniform
        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, non_uniform_key

        # also check if non-uniform key exists for more "general"
        # piecewise cudagraph
        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, non_uniform_key

        # finally, just return no cudagraphs
        return CUDAGraphMode.NONE, None

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_keys instance-attribute

cudagraph_keys: dict[
    CUDAGraphMode, set[BatchDescriptor]
] = {PIECEWISE: set(), FULL: set()}

cudagraph_mode instance-attribute

cudagraph_mode = cudagraph_mode

keys_initialized instance-attribute

keys_initialized = False

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(vllm_config: VllmConfig)
Source code in vllm/v1/cudagraph_dispatcher.py
def __init__(self, vllm_config: VllmConfig):
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.cudagraph_mode = self.compilation_config.cudagraph_mode

    # Dict to store valid cudagraph dispatching keys.
    self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
        CUDAGraphMode.PIECEWISE: set(),
        CUDAGraphMode.FULL: set(),
    }

    assert not self.cudagraph_mode.requires_piecewise_compilation() or \
        (self.compilation_config.level == CompilationLevel.PIECEWISE and
         self.compilation_config.splitting_ops_contain_attention()), \
        "Compilation level should be CompilationLevel.PIECEWISE when "\
        "cudagraph_mode piecewise cudagraphs is used, "\
        f"cudagraph_mode={self.cudagraph_mode}, "\
        f"compilation_level={self.compilation_config.level}, "\
        f"splitting_ops={self.compilation_config.splitting_ops}"

    self.keys_initialized = False
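
The assertion above encodes a configuration invariant: piecewise cudagraphs are only usable when the model is compiled at CompilationLevel.PIECEWISE with attention among the splitting ops. A standalone restatement of that check, where the booleans are stand-ins for the corresponding CUDAGraphMode/CompilationConfig queries:

# Standalone restatement of the __init__ invariant; the booleans below are
# stand-ins, not vLLM config objects.
def piecewise_config_ok(requires_piecewise_compilation: bool,
                        level_is_piecewise: bool,
                        splitting_ops_contain_attention: bool) -> bool:
    # Either the cudagraph mode does not need piecewise compilation at all,
    # or the compilation config actually produces piecewise graphs that are
    # split around attention ops.
    return (not requires_piecewise_compilation
            or (level_is_piecewise and splitting_ops_contain_attention))

assert piecewise_config_ok(False, False, False)    # e.g. NONE / FULL-only modes
assert piecewise_config_ok(True, True, True)       # piecewise mode, valid config
assert not piecewise_config_ok(True, True, False)  # attention not split out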

add_cudagraph_key

add_cudagraph_key(
    runtime_mode: CUDAGraphMode,
    batch_descriptor: BatchDescriptor,
)
Source code in vllm/v1/cudagraph_dispatcher.py
def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                      batch_descriptor: BatchDescriptor):
    assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
        f"Invalid cudagraph runtime mode: {runtime_mode}"
    self.cudagraph_keys[runtime_mode].add(batch_descriptor)

dispatch

dispatch(
    batch_descriptor: BatchDescriptor,
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]

Given a batch descriptor, dispatch to a cudagraph mode. A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform).
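
The relaxation mentioned above turns a uniform-decode key into its general (non-uniform) form before lookup. A tiny stand-in sketch of that behavior, assuming BatchDescriptor.non_uniform simply clears the uniform_decode flag (Desc below is hypothetical, not the vLLM type):

# Hypothetical stand-in for BatchDescriptor, only to illustrate the
# uniform -> non-uniform relaxation used by dispatch().
from typing import NamedTuple

class Desc(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "Desc":
        return self._replace(uniform_decode=False)

key = Desc(num_tokens=32, uniform_decode=True)
print(key.non_uniform)  # Desc(num_tokens=32, uniform_decode=False)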

Source code in vllm/v1/cudagraph_dispatcher.py
def dispatch(
    self, batch_descriptor: BatchDescriptor
) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
    """
    Given a batch descriptor, dispatch to a cudagraph mode.
    A new batch descriptor is returned as we might dispatch a uniform batch 
    to a graph that supports a more general batch (uniform to non-uniform).
    """
    # if not initialized, just skip dispatching.
    if not self.keys_initialized:
        logger.warning_once("cudagraph dispatching keys are not "
                            "initialized. No cudagraph will be used.")
        return CUDAGraphMode.NONE, None

    # check if key exists for full cudagraph
    if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
        return CUDAGraphMode.FULL, batch_descriptor

    # otherwise, check if non-uniform key exists
    non_uniform_key = batch_descriptor.non_uniform
    if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
        return CUDAGraphMode.FULL, non_uniform_key

    # also check if non-uniform key exists for more "general"
    # piecewise cudagraph
    if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
        return CUDAGraphMode.PIECEWISE, non_uniform_key

    # finally, just return no cudagraphs
    return CUDAGraphMode.NONE, None

initialize_cudagraph_keys

initialize_cudagraph_keys(
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int,
)
Source code in vllm/v1/cudagraph_dispatcher.py
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                              uniform_decode_query_len: int):
    # This should be called only after attention backend is initialized.

    # Note: we create all possible valid keys for cudagraphs, but do not
    # guarantee that every key will be used. For example, we always create
    # keys for piecewise cudagraphs under piecewise compilation, but if the
    # attention backend supports a unified routine, we may never trigger
    # capturing/replaying of the piecewise cudagraphs, depending on
    # CompilationConfig.cudagraph_mode. In addition, if lazy capturing is
    # allowed in a future PR, some keys may never be triggered.
    if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
        for bs in self.compilation_config.cudagraph_capture_sizes:
            self.add_cudagraph_key(
                cudagraph_mode.mixed_mode(),
                BatchDescriptor(num_tokens=bs, uniform_decode=False))

    # If the decode cudagraph mode is FULL and we don't already have
    # mixed-mode full cudagraphs, add them here.
    if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \
        and cudagraph_mode.separate_routine():
        max_num_tokens = uniform_decode_query_len * \
            self.vllm_config.scheduler_config.max_num_seqs
        cudagraph_capture_sizes_for_decode = [
            x for x in self.compilation_config.cudagraph_capture_sizes
            if x <= max_num_tokens and x >= uniform_decode_query_len
        ]
        for bs in cudagraph_capture_sizes_for_decode:
            self.add_cudagraph_key(
                CUDAGraphMode.FULL,
                BatchDescriptor(num_tokens=bs, uniform_decode=True))
    self.keys_initialized = True
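
For a concrete sense of the decode-only keys created above, the size filtering can be reproduced standalone. The numbers below are made up for illustration; uniform_decode_query_len would typically be 1 for plain decode, or larger with speculative decoding:

# Standalone sketch of the decode capture-size filtering performed in
# initialize_cudagraph_keys; all values here are illustrative assumptions.
uniform_decode_query_len = 2            # assumed: 1 + number of spec tokens
max_num_seqs = 8                        # assumed scheduler_config.max_num_seqs
cudagraph_capture_sizes = [1, 2, 4, 8, 16, 32]

max_num_tokens = uniform_decode_query_len * max_num_seqs   # 16
decode_sizes = [
    x for x in cudagraph_capture_sizes
    if uniform_decode_query_len <= x <= max_num_tokens
]
print(decode_sizes)  # [2, 4, 8, 16] -> FULL keys with uniform_decode=True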