Skip to content

vllm.utils.deep_gemm

Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import only these wrappers.

DEFAULT_BLOCK_SIZE module-attribute

# Default [block_m, block_n] tile size used by `per_block_cast_to_fp8`.
# NOTE: this list object is shared as a default argument; never mutate it.
DEFAULT_BLOCK_SIZE = [128, 128]

__all__ module-attribute

# Public API of this module: vLLM code should import only these wrappers
# rather than calling `deep_gemm` directly.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "should_use_deepgemm_for_fp8_linear",
]

_fp8_gemm_nt_impl module-attribute

# Resolved by `_lazy_init()` (new name `fp8_gemm_nt`, legacy
# `gemm_fp8_fp8_bf16_nt`); stays None until resolved or if neither exists.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None

_grouped_impl module-attribute

# Resolved by `_lazy_init()` (new name `m_grouped_fp8_gemm_nt_contiguous`);
# stays None until resolved or if neither the new nor legacy symbol exists.
_grouped_impl: Callable[..., Any] | None = None

_grouped_masked_impl module-attribute

# Resolved by `_lazy_init()` (new name `fp8_m_grouped_gemm_nt_masked`);
# stays None until resolved or if neither the new nor legacy symbol exists.
_grouped_masked_impl: Callable[..., Any] | None = None

_align

_align(x: int, y: int) -> int
Source code in vllm/utils/deep_gemm.py
def _align(x: int, y: int) -> int:
    """Round ``x`` up to a multiple of ``y`` (``cdiv`` is ceiling division)."""
    num_blocks = cdiv(x, y)
    return y * num_blocks

_ceil_to_ue8m0

_ceil_to_ue8m0(x: Tensor)
Source code in vllm/utils/deep_gemm.py
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round every magnitude in ``x`` up to the next power of two.

    Power-of-two values are exactly what an exponent-only (UE8M0) scale can
    represent. Signs are discarded (``abs`` is taken first), and an input of
    zero maps to ``2 ** -inf == 0``.
    """
    exponent = torch.log2(x.abs()).ceil()
    return torch.pow(2.0, exponent)

_lazy_init

_lazy_init() -> None

Import deep_gemm and resolve symbols on first use.

Source code in vllm/utils/deep_gemm.py
def _lazy_init() -> None:
    """Import ``deep_gemm`` and bind its kernel entry points on first use.

    Fills the module-level ``_fp8_gemm_nt_impl``, ``_grouped_impl`` and
    ``_grouped_masked_impl`` slots via :func:`_resolve_symbol`, which prefers
    the current DeepGEMM names and falls back to the legacy ones. Returns
    early, leaving the slots untouched, when any slot is already populated
    or when ``has_deep_gemm()`` is false.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl

    slots = (_fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl)
    if any(impl is not None for impl in slots) or not has_deep_gemm():
        return

    # Point DeepGEMM's JIT cache at a subdirectory of vLLM's cache root, but
    # keep any value the user already set (an empty value counts as unset).
    jit_cache_var = 'DG_JIT_CACHE_DIR'
    if not os.environ.get(jit_cache_var):
        os.environ[jit_cache_var] = os.path.join(envs.VLLM_CACHE_ROOT,
                                                 "deep_gemm")

    deep_gemm_module = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = _resolve_symbol(deep_gemm_module, "fp8_gemm_nt",
                                        "gemm_fp8_fp8_bf16_nt")
    _grouped_impl = _resolve_symbol(
        deep_gemm_module, "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
    _grouped_masked_impl = _resolve_symbol(
        deep_gemm_module, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")

_missing

_missing(*_: Any, **__: Any) -> NoReturn

Placeholder for unavailable DeepGEMM backend.

Source code in vllm/utils/deep_gemm.py
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in used when no DeepGEMM kernel could be resolved.

    Accepts and ignores any arguments so it can take the place of a kernel
    call, then unconditionally fails.

    Raises:
        RuntimeError: always, asking the user to install ``deep_gemm``.
    """
    message = ("DeepGEMM backend is not available. Please install the "
               "`deep_gemm` package to enable FP8 kernels.")
    raise RuntimeError(message)

_resolve_symbol

_resolve_symbol(
    module, new: str, old: str
) -> Callable[..., Any] | None

Return the new symbol if it exists, otherwise the old one.

Source code in vllm/utils/deep_gemm.py
def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Return the *new* symbol if it exists, otherwise the *old* one."""
    if hasattr(module, new):
        return getattr(module, new)
    if hasattr(module, old):
        # TODO(wentao): deprecate old symbol in the future.
        logger.warning_once(
            "Found legacy DeepGEMM symbol `%s`. Please upgrade the `deep_gemm` "
            "package so that `%s` is available. Support for the legacy symbol "
            "will be removed in a future vLLM release.",
            old,
            new,
        )
        return getattr(module, old)
    return None

calc_diff

calc_diff(x: Tensor, y: Tensor)

Return a global difference metric for unit tests.

DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves this helper can be removed.

Source code in vllm/utils/deep_gemm.py
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``.  Once kernel accuracy improves this helper can be
    removed.

    ``sim = 2 * <x, y> / (|x|^2 + |y|^2)`` is computed in float64. The result
    is therefore 0 for identical inputs and 2 for ``y == -x``. Two all-zero
    inputs are treated as identical and give 0, rather than ``0 / 0 = NaN``.
    NaN/Inf values present in the inputs still propagate to the result.

    Returns:
        A 0-dim float64 tensor on the inputs' device.
    """
    x, y = x.double(), y.double()
    numerator = 2 * (x * y).sum()
    denominator = (x * x + y * y).sum()
    # A sum of squares is zero only when both tensors are all zeros (i.e.
    # identical), so report perfect similarity instead of dividing 0 by 0.
    # torch.where keeps this on-device (a Python `if` would force a sync),
    # and `== 0` is False for NaN, so corrupted inputs are not masked.
    sim = torch.where(denominator == 0, torch.ones_like(denominator),
                      numerator / denominator)
    return 1 - sim

fp8_gemm_nt

fp8_gemm_nt(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's FP8 NT GEMM kernel through the compatibility layer.

    All arguments are forwarded unchanged to the resolved kernel. A
    ``disable_ue8m0_cast`` keyword is added, set to the negation of
    :func:`is_blackwell_deep_gemm_e8m0_used`. Callers must therefore not pass
    that keyword themselves.

    Raises:
        RuntimeError: if no DeepGEMM implementation could be resolved.
    """
    _lazy_init()
    kernel = _fp8_gemm_nt_impl
    if kernel is None:
        return _missing(*args, **kwargs)
    use_e8m0 = is_blackwell_deep_gemm_e8m0_used()
    return kernel(*args, disable_ue8m0_cast=not use_e8m0, **kwargs)

fp8_m_grouped_gemm_nt_masked

fp8_m_grouped_gemm_nt_masked(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's masked M-grouped FP8 NT GEMM kernel.

    All arguments are forwarded unchanged to the resolved kernel. A
    ``disable_ue8m0_cast`` keyword is added, set to the negation of
    :func:`is_blackwell_deep_gemm_e8m0_used`. Callers must therefore not pass
    that keyword themselves.

    Raises:
        RuntimeError: if no DeepGEMM implementation could be resolved.
    """
    _lazy_init()
    kernel = _grouped_masked_impl
    if kernel is None:
        return _missing(*args, **kwargs)
    use_e8m0 = is_blackwell_deep_gemm_e8m0_used()
    return kernel(*args, disable_ue8m0_cast=not use_e8m0, **kwargs)

is_blackwell_deep_gemm_e8m0_used cached

is_blackwell_deep_gemm_e8m0_used() -> bool

Return True if vLLM is configured to use DeepGEMM E8M0 scales on a Blackwell-class GPU.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_blackwell_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if DeepGEMM E8M0 scaling should be used on this GPU.

    The conditions below are checked in order. The first one that fails is
    logged once at debug level, and ``False`` is returned:

    1. :func:`is_deep_gemm_supported` is true.
    2. ``envs.VLLM_USE_DEEP_GEMM_E8M0`` is enabled.
    3. The ``fp8_gemm_nt`` kernel could be resolved from ``deep_gemm``.
    4. The platform is CUDA and ``has_device_capability(100)`` holds
       (the Blackwell-class check).

    The result is memoized by ``functools.cache``.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
        return False

    if not envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
        return False

    # Resolve the kernels so we can confirm fp8_gemm_nt actually exists.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.debug_once(
            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    enabled = (current_platform.is_cuda()
               and current_platform.has_device_capability(100))
    if enabled:
        logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
    else:
        logger.debug_once(
            "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
    return enabled

is_deep_gemm_supported cached

is_deep_gemm_supported() -> bool

Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can be used on the current platform.

    This requires a CUDA device with capability 90 (Hopper) or 100
    (Blackwell), ``envs.VLLM_USE_DEEP_GEMM`` to be enabled, and the
    ``deep_gemm`` package to be available (``has_deep_gemm()``). The result
    is memoized by ``functools.cache``.
    """
    # The architecture probe runs first, before the env/package checks.
    arch_ok = current_platform.is_cuda() and any(
        current_platform.is_device_capability(cap) for cap in (90, 100))
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and arch_ok

m_grouped_fp8_gemm_nt_contiguous

m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's contiguous M-grouped FP8 NT GEMM kernel.

    All arguments are forwarded unchanged to the resolved kernel. A
    ``disable_ue8m0_cast`` keyword is added, set to the negation of
    :func:`is_blackwell_deep_gemm_e8m0_used`. Callers must therefore not pass
    that keyword themselves.

    Raises:
        RuntimeError: if no DeepGEMM implementation could be resolved.
    """
    _lazy_init()
    kernel = _grouped_impl
    if kernel is None:
        return _missing(*args, **kwargs)
    use_e8m0 = is_blackwell_deep_gemm_e8m0_used()
    return kernel(*args, disable_ue8m0_cast=not use_e8m0, **kwargs)

per_block_cast_to_fp8

per_block_cast_to_fp8(
    x: Tensor,
    block_size: list[int] = DEFAULT_BLOCK_SIZE,
    use_ue8m0: bool = False,
) -> tuple[Tensor, Tensor]
Source code in vllm/utils/deep_gemm.py
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``torch.float8_e4m3fn`` with per-tile scales.

    Args:
        x: Input of shape ``(m, n)``.
        block_size: ``[block_m, block_n]`` tile shape; one scale per tile.
        use_ue8m0: If true, round each scale up to a power of two.

    Returns:
        ``(quantized, scales)``, where ``quantized`` is a contiguous
        ``(m, n)`` FP8 tensor. ``scales`` is a float32 tensor of shape
        ``(ceil(m / block_m), ceil(n / block_n))``; dequantize a tile as
        ``quantized_tile * scale``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    tile_rows, tile_cols = block_size
    padded_rows = _align(rows, tile_rows)
    padded_cols = _align(cols, tile_cols)

    # Zero-pad so both dimensions divide evenly into tiles.
    padded = torch.zeros((padded_rows, padded_cols),
                         dtype=x.dtype,
                         device=x.device)
    padded[:rows, :cols] = x

    col_tiles = padded.size(1) // tile_cols
    tiles = padded.view(-1, tile_rows, col_tiles, tile_cols)

    # Per-tile max magnitude, clamped so all-zero tiles still get a nonzero
    # scale (keeps 1 / scale finite).
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    # 448.0 == torch.finfo(torch.float8_e4m3fn).max.
    scale = tile_amax / 448.0
    if use_ue8m0:
        scale = _ceil_to_ue8m0(scale)

    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
    quantized = quantized.view_as(padded)[:rows, :cols].contiguous()
    return quantized, scale.view(tiles.size(0), tiles.size(2))

should_use_deepgemm_for_fp8_linear

should_use_deepgemm_for_fp8_linear(
    output_dtype: dtype, weight: Tensor
)
Source code in vllm/utils/deep_gemm.py
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor):
    """Decide whether an FP8 linear layer should run through DeepGEMM.

    This holds when :func:`is_deep_gemm_supported` is true, the output dtype
    is ``torch.bfloat16``, and both ``weight.shape[0]`` and
    ``weight.shape[1]`` are multiples of 128. Checks short-circuit in that
    order, so ``weight.shape`` is only inspected when the earlier ones pass.
    """
    supported = is_deep_gemm_supported()
    if not supported:
        return supported
    if output_dtype != torch.bfloat16:
        return False
    return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0