Skip to content

vllm.compilation.cuda_piecewise_backend

logger module-attribute

logger = init_logger(__name__)

ConcreteSizeEntry dataclass

Source code in vllm/compilation/cuda_piecewise_backend.py
@dataclasses.dataclass
class ConcreteSizeEntry:
    """Bookkeeping record for one concrete runtime size.

    Holds the size itself, a flag telling whether the backend has already
    handled shape-specialized compilation for it, and the callable that is
    run whenever this size is seen.
    """

    # Concrete value of the symbolic dimension this entry is keyed by.
    runtime_shape: int
    # Whether shape-specialized compilation has been handled for this size.
    compiled: bool = dataclasses.field(default=False)
    # Callable executed for this size; None unless provided by the creator.
    runnable: Callable = dataclasses.field(default=None)  # type: ignore

compiled class-attribute instance-attribute

compiled: bool = False

runnable class-attribute instance-attribute

runnable: Callable = None

runtime_shape instance-attribute

runtime_shape: int

__init__

__init__(
    runtime_shape: int,
    compiled: bool = False,
    runnable: Callable = None,
) -> None

PiecewiseBackend

Source code in vllm/compilation/cuda_piecewise_backend.py
class PiecewiseBackend:
    """Run one piece of a split graph, dispatching on the runtime shape.

    Calls go to the general-shape compiled graph unless the runtime value of
    the first symbolic-shape argument is listed in
    ``compilation_config.compile_sizes``; such sizes get a shape-specialized
    graph compiled on first use, which is reused afterwards.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 piecewise_compile_index: int, total_piecewise_compiles: int,
                 sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Args:
            graph: the piece of the split FX graph handled by this backend.
            vllm_config: global config; its ``compilation_config`` supplies
                ``compile_sizes`` and the inductor compile config.
            piecewise_compile_index: index of this piece among all pieces.
            total_piecewise_compiles: total number of pieces.
            sym_shape_indices: positions in ``*args`` of symbolic-shape
                arguments; the first one is used to dispatch in ``__call__``.
            compiled_graph_for_general_shape: callable used for every shape
                that has no specialized compilation.
            vllm_backend: owner whose ``compiler_manager`` compiles graphs
                and saves compilation state to file.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.is_full_graph = total_piecewise_compiles == 1

        self.compile_sizes: set[int] = set(
            self.compilation_config.compile_sizes)

        # The first call always runs the general-shape graph (see __call__).
        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to compile
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()

        # We only keep compilation management inside this class directly.
        # Each entry starts out pointing at the general-shape graph.
        for shape in self.compile_sizes:
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                runnable=self.compiled_graph_for_general_shape,
            )

    def check_for_ending_compilation(self):
        """Finish compilation bookkeeping once no specialized sizes remain.

        Only the last piece acts: when ``to_be_compiled_sizes`` is empty it
        asks the compiler manager to save its state to file and ends
        torch.compile monitoring for ``self.vllm_config``.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Execute this piece, choosing a compiled graph by runtime shape.

        The first call always runs the general-shape graph. After that, the
        value at ``args[self.sym_shape_indices[0]]`` selects the graph:
        sizes in ``compile_sizes`` get a shape-specialized graph compiled on
        first use, and every other size runs the general-shape graph.

        Returns:
            Whatever the selected compiled graph returns.

        Raises:
            Any exception from ``compiler_manager.compile``. The size then
            stays pending and uncompiled, so the next call retries it.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]

        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        if not entry.compiled:
            # args are real arguments
            runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # Commit bookkeeping only after compile() succeeded. Marking the
            # entry first would leave it flagged as compiled (but still on
            # the general graph) forever if compilation raised.
            # NOTE(review): assumes compile() never re-enters this backend
            # for the same shape — confirm.
            entry.runnable = runnable
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)

            # Saves and ends monitoring only on the last piece once every
            # required shape has been compiled (checked inside).
            self.check_for_ending_compilation()

        return entry.runnable(*args)

compilation_config instance-attribute

compilation_config = compilation_config

compile_sizes instance-attribute

compile_sizes: set[int] = set(compile_sizes)

compiled_graph_for_general_shape instance-attribute

compiled_graph_for_general_shape = (
    compiled_graph_for_general_shape
)

concrete_size_entries instance-attribute

concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

first_run_finished instance-attribute

first_run_finished = False

graph instance-attribute

graph = graph

is_debugging_mode instance-attribute

is_debugging_mode = VLLM_LOGGING_LEVEL == 'DEBUG'

is_first_graph instance-attribute

is_first_graph = piecewise_compile_index == 0

is_full_graph instance-attribute

is_full_graph = total_piecewise_compiles == 1

is_last_graph instance-attribute

is_last_graph = (
    piecewise_compile_index == total_piecewise_compiles - 1
)

piecewise_compile_index instance-attribute

piecewise_compile_index = piecewise_compile_index

sym_shape_indices instance-attribute

sym_shape_indices = sym_shape_indices

to_be_compiled_sizes instance-attribute

to_be_compiled_sizes: set[int] = compile_sizes.copy()

total_piecewise_compiles instance-attribute

total_piecewise_compiles = total_piecewise_compiles

vllm_backend instance-attribute

vllm_backend = vllm_backend

vllm_config instance-attribute

vllm_config = vllm_config

__call__

__call__(*args) -> Any
Source code in vllm/compilation/cuda_piecewise_backend.py
def __call__(self, *args) -> Any:
    """Execute this piece, choosing a compiled graph by runtime shape.

    The first call always runs the general-shape graph. After that, the
    value at ``args[self.sym_shape_indices[0]]`` selects the graph: sizes in
    ``compile_sizes`` get a shape-specialized graph compiled on first use,
    and every other size runs the general-shape graph.

    Returns:
        Whatever the selected compiled graph returns.

    Raises:
        Any exception from ``compiler_manager.compile``. The size then stays
        pending and uncompiled, so the next call retries it.
    """
    if not self.first_run_finished:
        self.first_run_finished = True
        self.check_for_ending_compilation()
        return self.compiled_graph_for_general_shape(*args)

    runtime_shape = args[self.sym_shape_indices[0]]

    entry = self.concrete_size_entries.get(runtime_shape)
    if entry is None:
        # we don't need to do anything for this shape
        return self.compiled_graph_for_general_shape(*args)

    if not entry.compiled:
        # args are real arguments
        runnable = self.vllm_backend.compiler_manager.compile(
            self.graph,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=self.piecewise_compile_index,
            num_graphs=self.total_piecewise_compiles,
            runtime_shape=runtime_shape)

        # Commit bookkeeping only after compile() succeeded. Marking the
        # entry first would leave it flagged as compiled (but still on the
        # general graph) forever if compilation raised.
        # NOTE(review): assumes compile() never re-enters this backend for
        # the same shape — confirm.
        entry.runnable = runnable
        entry.compiled = True
        self.to_be_compiled_sizes.remove(runtime_shape)

        # Saves and ends monitoring only on the last piece once every
        # required shape has been compiled (checked inside).
        self.check_for_ending_compilation()

    return entry.runnable(*args)

__init__

__init__(
    graph: GraphModule,
    vllm_config: VllmConfig,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    compiled_graph_for_general_shape: Callable,
    vllm_backend: VllmBackend,
)

The backend for piecewise compilation. It mainly handles the compilation of static shapes and dispatching based on runtime shape.

We will compile `self.graph` once for the general shape, and then compile it for each of the shapes specified in `compilation_config.compile_sizes`.

Source code in vllm/compilation/cuda_piecewise_backend.py
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
             piecewise_compile_index: int, total_piecewise_compiles: int,
             sym_shape_indices: list[int],
             compiled_graph_for_general_shape: Callable,
             vllm_backend: VllmBackend):
    """
    Set up piecewise compilation for one piece of a split graph.

    The general-shape graph arrives already compiled as
    ``compiled_graph_for_general_shape``. Shape-specialized versions of
    ``graph`` are tracked per size listed in
    ``compilation_config.compile_sizes`` and start out pointing at the
    general-shape graph.
    """
    config = vllm_config.compilation_config

    # Constructor inputs, stored as given.
    self.graph = graph
    self.vllm_config = vllm_config
    self.compilation_config = config
    self.piecewise_compile_index = piecewise_compile_index
    self.total_piecewise_compiles = total_piecewise_compiles
    self.vllm_backend = vllm_backend
    self.sym_shape_indices = sym_shape_indices
    self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

    # Where this piece sits among all pieces of the split graph.
    last_index = total_piecewise_compiles - 1
    self.is_first_graph = piecewise_compile_index == 0
    self.is_last_graph = piecewise_compile_index == last_index
    self.is_full_graph = total_piecewise_compiles == 1

    self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
    self.first_run_finished = False

    self.compile_sizes: set[int] = set(config.compile_sizes)
    # Independent copy: sizes are removed from it as they get compiled,
    # while compile_sizes itself must stay intact.
    self.to_be_compiled_sizes: set[int] = set(self.compile_sizes)
    # One entry per requested size, initially running the general graph.
    self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {
        size: ConcreteSizeEntry(runtime_shape=size,
                                runnable=compiled_graph_for_general_shape)
        for size in self.compile_sizes
    }

check_for_ending_compilation

check_for_ending_compilation()
Source code in vllm/compilation/cuda_piecewise_backend.py
def check_for_ending_compilation(self):
    """Finalize compilation bookkeeping when nothing remains to compile.

    Does nothing unless this is the last piece and ``to_be_compiled_sizes``
    is empty; in that case it asks the compiler manager to save its state to
    file and ends torch.compile monitoring for ``self.vllm_config``.
    """
    # is_last_graph is checked first so to_be_compiled_sizes is only read
    # for the last piece.
    if not self.is_last_graph or self.to_be_compiled_sizes:
        return
    # Persist the inductor graph hash so the next run can reuse it.
    self.vllm_backend.compiler_manager.save_to_file()
    end_monitoring_torch_compile(self.vllm_config)