Skip to content

vllm.model_executor.warmup.deep_gemm_warmup

Warm up DeepGEMM kernels. DeepGEMM JIT-compiles its kernels; this warmup aims to compile, ahead of time, all the kernels that would be used during model execution.

FP8_GEMM_NT_WARMUP_CACHE module-attribute

# Weight sizes (w.size()) for which the fp8_gemm_nt warmup has already run.
# Checked on entry to the warmup so identical shapes are only warmed once.
FP8_GEMM_NT_WARMUP_CACHE: set[Size] = set()

GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE module-attribute

# Expert-weight sizes (w.size()) for which the grouped contiguous GEMM warmup
# has already run. Lets MoE layers with identical weight shapes skip the
# (repeated) warmup.
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[Size] = (
    set()
)

_deepgemm_fp8_gemm_nt_warmup

_deepgemm_fp8_gemm_nt_warmup(
    w: Tensor, ws: Tensor, max_tokens: int
)
Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
                                 max_tokens: int):
    """
    JIT-compile the DeepGEMM `fp8_gemm_nt` kernels for the weight `w`.

    Runs `fp8_gemm_nt` once for every token count M in
    [max_tokens, max_tokens - 1, ..., 1] so the kernels for each M are
    compiled before model execution. Weight sizes already recorded in
    FP8_GEMM_NT_WARMUP_CACHE are skipped; on completion the size of `w`
    is added to that cache.

    Args:
        w: 2D weight tensor; its size is unpacked as (n, k).
        ws: Scales for `w`, passed through to `fp8_gemm_nt`.
        max_tokens: Largest token count (M) to warm up.
    """
    if w.size() in FP8_GEMM_NT_WARMUP_CACHE:
        return

    n, k = w.size()
    block_m = deep_gemm_block_shape()[0]

    device = w.device
    # Uninitialized buffers are enough: the warmup only needs the kernel
    # calls to run, not meaningful results.
    a1q = torch.empty((max_tokens, k),
                      device=device,
                      dtype=torch.float8_e4m3fn)
    # NOTE(review): the K-dimension scale count uses the block shape's first
    # entry; this matches the K block only if the block shape is square.
    a1q_scales = torch.empty((max_tokens, k // block_m),
                             device=device,
                             dtype=torch.float32)
    out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)

    # The context manager guarantees the progress bar is closed, including
    # when a kernel call raises.
    with tqdm(total=max_tokens,
              desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") as pbar:
        # Visit every M from max_tokens down to 1.
        for num_tokens in range(max_tokens, 0, -1):
            fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws),
                        out[:num_tokens])
            pbar.update(1)

    FP8_GEMM_NT_WARMUP_CACHE.add(w.size())

_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup

_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
    w1: Tensor,
    w2: Tensor,
    w1_scale: Tensor,
    w2_scale: Tensor,
    num_topk: int,
)
Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
                                                    w2: torch.Tensor,
                                                    w1_scale: torch.Tensor,
                                                    w2_scale: torch.Tensor,
                                                    num_topk: int):
    """
    JIT-compile the DeepGEMM `m_grouped_fp8_gemm_nt_contiguous` kernels for
    a pair of MoE expert weights.

    Each of `w1` / `w2` whose size is not yet in
    GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE is warmed up. The grouped
    GEMM runs for M = MAX_M, MAX_M - block_m, ... while M > 0, and the
    weight size is then added to the cache.

    Args:
        w1: Expert weights; size unpacked as (num_experts, n, k).
        w2: Expert weights; size unpacked as (num_experts, n, k). Must have
            the same dim 0 (number of experts) as `w1`.
        w1_scale: Scales for `w1`, passed through to the kernel.
        w2_scale: Scales for `w2`, passed through to the kernel.
        num_topk: The layer's top-k value; passed to `compute_aligned_M`
            to size the largest M.

    Raises:
        AssertionError: If `w1` and `w2` differ in dim 0.
    """
    if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
            and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
        return

    assert w1.size(0) == w2.size(0), (
        "w1 and w2 must have the same number of experts")

    block_m = deep_gemm_block_shape()[0]
    num_experts = w1.size(0)
    device = w1.device

    # This is the maximum GroupedGemm M size that we expect to run
    # the grouped_gemm with.
    MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
                              num_topk,
                              num_experts,
                              block_m,
                              expert_tokens_meta=None)
    # Draw a uniformly random expert id per block of block_m rows, then
    # repeat it so that all rows within a block share one expert id.
    MAX_BLOCKS = MAX_M // block_m
    expert_ids_block = torch.randint(low=0,
                                     high=num_experts,
                                     size=(MAX_BLOCKS, ),
                                     device=device,
                                     dtype=torch.int32)
    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)

    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
        """Run the grouped GEMM for one weight at each block_m-step M."""
        _, n, k = w.size()
        # Uninitialized buffers are enough: only the kernel calls matter.
        a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn)
        a1q_scales = torch.empty((MAX_M, k // block_m),
                                 device=device,
                                 dtype=torch.float32)
        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)

        # The context manager guarantees the progress bar is closed, even
        # if a kernel call raises.
        # NOTE(review): total=MAX_BLOCKS matches the iteration count only
        # when MAX_M is a multiple of block_m.
        with tqdm(
                total=MAX_BLOCKS,
                desc=
                f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})"
        ) as pbar:
            for num_tokens in range(MAX_M, 0, -block_m):
                m_grouped_fp8_gemm_nt_contiguous(
                    (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale),
                    out[:num_tokens], expert_ids[:num_tokens])
                pbar.update(1)

    for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
        if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
            _warmup(w, ws)
            GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE.add(w.size())

_extract_data_from_fused_moe_module

_extract_data_from_fused_moe_module(
    m: Module,
) -> tuple[Tensor, Tensor, Tensor, Tensor, int]

Extract weights, weight scales and num_topk from FusedMoE module.

Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def _extract_data_from_fused_moe_module(
    m: torch.nn.Module
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """
    Pull the grouped-GEMM operands out of a FusedMoE layer.

    Returns:
        A 5-tuple ``(w13_weight, w13_weight_scale_inv, w2_weight,
        w2_weight_scale_inv, top_k)`` read from the module's attributes.

    Raises:
        AssertionError: If `m` is not a FusedMoE, or any weight/scale
            attribute is not a torch.Tensor.
    """
    assert isinstance(m, FusedMoE)
    w13_weight, w13_scale = m.w13_weight, m.w13_weight_scale_inv
    w2_weight, w2_scale = m.w2_weight, m.w2_weight_scale_inv
    top_k = m.top_k

    # Narrow each operand to a Tensor (also helps static type checkers).
    assert isinstance(w13_weight, torch.Tensor)
    assert isinstance(w13_scale, torch.Tensor)
    assert isinstance(w2_weight, torch.Tensor)
    assert isinstance(w2_scale, torch.Tensor)
    return w13_weight, w13_scale, w2_weight, w2_scale, top_k

_extract_data_from_linear_base_module

_extract_data_from_linear_base_module(
    m: Module,
) -> tuple[Tensor, Tensor, list[int]]

Extract weights, weight scales and quantization block sizes from the given LinearBase module.

Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def _extract_data_from_linear_base_module(
        m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Return ``(weight, weight_scale_inv, weight_block_size)`` for a
    LinearBase layer whose quant method is block-quantized Fp8LinearMethod.

    Raises:
        AssertionError: If the layer is not a block-quantized FP8 LinearBase,
            or if any of the returned values has an unexpected type / is None.
    """
    assert isinstance(m, LinearBase)
    quant_method = m.quant_method
    assert isinstance(quant_method, Fp8LinearMethod)
    assert quant_method.block_quant
    quant_config = quant_method.quant_config
    assert quant_config is not None

    weight, weight_scale = m.weight, m.weight_scale_inv
    block_size = quant_config.weight_block_size

    # Narrow types before handing the values back to callers.
    assert isinstance(weight, torch.Tensor)
    assert isinstance(weight_scale, torch.Tensor)
    assert block_size is not None
    return weight, weight_scale, block_size

_fp8_linear_may_use_deep_gemm

_fp8_linear_may_use_deep_gemm(module: Module) -> bool

Return True if the input module/layer could be processed with DeepGEMM.

Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Decide whether `module` is a linear layer DeepGEMM could execute.

    True only for a block-quantized FP8 LinearBase whose weight block sizes
    equal DeepGEMM's block shape and whose 2D weight has both dimensions
    divisible by the first entry of that block shape.
    """
    block_size = deep_gemm_block_shape()[0]

    is_block_fp8_linear = (isinstance(module, LinearBase)
                           and isinstance(module.quant_method, Fp8LinearMethod)
                           and module.quant_method.block_quant)
    if not is_block_fp8_linear:
        return False

    weight, _, weight_block_sizes = (
        _extract_data_from_linear_base_module(module))
    if weight_block_sizes != deep_gemm_block_shape():
        return False
    if weight.ndim != 2:
        return False

    rows, cols = weight.shape
    return rows % block_size == 0 and cols % block_size == 0

_fused_moe_grouped_gemm_may_use_deep_gemm

_fused_moe_grouped_gemm_may_use_deep_gemm(
    module: Module,
) -> bool
Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
    """
    Decide whether `module` is a FusedMoE layer whose grouped GEMMs could
    run on DeepGEMM.

    Requires an FP8 (float8_e4m3fn) MoE config whose block shape equals
    DeepGEMM's. A non-modular `fused_experts` is accepted as-is. A
    FusedMoEModularKernel is accepted only if its experts implementation is
    DeepGemmExperts or TritonOrDeepGemmExperts.
    """
    if not isinstance(module, FusedMoE):
        return False
    moe_config = module.moe_config
    if moe_config.quant_dtype != torch.float8_e4m3fn:
        return False
    if moe_config.block_shape != deep_gemm_block_shape():
        return False

    experts_impl = module.quant_method.fused_experts
    if isinstance(experts_impl, FusedMoEModularKernel):
        # Look inside the modular kernel for a DeepGemm-backed experts impl.
        return isinstance(experts_impl.fused_experts,
                          (DeepGemmExperts, TritonOrDeepGemmExperts))

    # fused_experts could invoke deep_gemm_moe_fp8
    return True

deep_gemm_warmup

deep_gemm_warmup(model: Module, max_tokens: int)
Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
    """
    Pre-compile the DeepGEMM kernels `model` may use.

    First the dense FP8 linear GEMMs (warmed for every M up to `max_tokens`),
    then the FusedMoE grouped contiguous GEMMs.
    """
    deepgemm_fp8_gemm_nt_warmup(model, max_tokens=max_tokens)
    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model=model)

deepgemm_fp8_gemm_nt_warmup

deepgemm_fp8_gemm_nt_warmup(model: Module, max_tokens: int)
Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
    """
    Warm up `fp8_gemm_nt` for every linear layer of `model` that
    `_fp8_linear_may_use_deep_gemm` accepts.
    """
    # Select all eligible layers up front, then warm each one.
    eligible_layers = list(
        filter(_fp8_linear_may_use_deep_gemm, model.modules()))

    for layer in eligible_layers:
        weight, weight_scale, _ = _extract_data_from_linear_base_module(layer)
        _deepgemm_fp8_gemm_nt_warmup(w=weight,
                                     ws=weight_scale,
                                     max_tokens=max_tokens)

deepgemm_grouped_fp8_gemm_nt_contiguous_warmup

deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
    model: Module,
)
Source code in vllm/model_executor/warmup/deep_gemm_warmup.py
def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
    """
    Warm up `m_grouped_fp8_gemm_nt_contiguous` for every FusedMoE layer of
    `model` that `_fused_moe_grouped_gemm_may_use_deep_gemm` accepts.
    """
    # Select all eligible MoE layers up front, then warm each one.
    moe_layers = list(
        filter(_fused_moe_grouped_gemm_may_use_deep_gemm, model.modules()))

    for layer in moe_layers:
        (w13_weight, w13_scale, w2_weight, w2_scale,
         top_k) = _extract_data_from_fused_moe_module(layer)
        _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1=w13_weight,
                                                        w2=w2_weight,
                                                        w1_scale=w13_scale,
                                                        w2_scale=w2_scale,
                                                        num_topk=top_k)