class CudaPlatformBase(Platform):
_enum = PlatformEnum.CUDA
device_name: str = "cuda"
device_type: str = "cuda"
dispatch_key: str = "CUDA"
ray_device_key: str = "GPU"
dist_backend: str = "nccl"
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
@property
def supported_dtypes(self) -> list[torch.dtype]:
if self.has_device_capability(80):
# Ampere (SM 8.0) and later NVIDIA GPUs.
return [torch.bfloat16, torch.float16, torch.float32]
elif self.has_device_capability(60):
# Pascal, Volta, and Turing NVIDIA GPUs; BF16 is not supported.
return [torch.float16, torch.float32]
# Kepler and Maxwell NVIDIA GPUs support only FP32,
# though vLLM doesn't support these GPUs.
return [torch.float32]
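# Illustrative mapping (not exhaustive): an A100 (SM 8.0) yields
# [torch.bfloat16, torch.float16, torch.float32], while a T4 (SM 7.5)
# yields [torch.float16, torch.float32].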
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.cuda.set_device(device)
# With this trick we can force the device to be set eagerly;
# see https://github.com/pytorch/pytorch/issues/155668
# for why and when it is needed.
_ = torch.zeros(1, device=device)
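# Hypothetical usage sketch: the dummy allocation above brings up the
# CUDA context on the target device eagerly, so e.g.
#   CudaPlatformBase.set_device(torch.device("cuda:0"))
#   assert torch.cuda.current_device() == 0
# holds and the device is already initialized at this point.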
@classmethod
def get_device_capability(cls,
device_id: int = 0
) -> Optional[DeviceCapability]:
raise NotImplementedError
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
if enforce_eager and not envs.VLLM_USE_V1:
logger.warning(
"To see the benefits of async output processing, enable CUDA "
"graphs. Since enforce-eager is enabled, the async output "
"processor cannot be used.")
return False
return True
@classmethod
def is_fully_connected(cls, device_ids: list[int]) -> bool:
raise NotImplementedError
@classmethod
def log_warnings(cls):
pass
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config:
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Speculative decoding is not supported on vLLM V0.")
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
if model_config is not None and model_config.use_mla:
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# we default to the FlashMLA backend for non-Blackwell GPUs
# and to CutlassMLA on Blackwell. In each case we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
if envs.VLLM_ATTENTION_BACKEND is None:
# Default case
if cls.is_device_capability(100):
# Blackwell => Force CutlassMLA.
use_cutlass_mla = True
# TODO: This does not work, because the
# global_force_attn_backend_context_manager is not set.
# See vllm/attention/selector.py:_cached_get_attn_backend
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
else:
# Not Blackwell
use_flashmla = True
else:
# Forced case
use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
use_cutlass_mla = (
envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
from vllm.attention.ops.flashmla import is_flashmla_supported
if use_flashmla and is_flashmla_supported()[0] \
and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size != 128:
cache_config.block_size = 128
logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.")
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
logger.info(
"Data Parallel: disabling cudagraphs since DP "
"with DeepEP high-throughput kernels is not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if model_config is not None:
model_config.enforce_eager = True
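# Example of the guard above: VLLM_ALL2ALL_BACKEND=deepep_high_throughput
# with data_parallel_size > 1 forces cudagraph_mode to NONE and
# enforce_eager, whereas VLLM_ALL2ALL_BACKEND=deepep_low_latency keeps
# CUDA graphs usable.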
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
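# Note: the peak statistics are reset first, so the returned value is
# a byte count reflecting what is currently allocated on the device at
# call time.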
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)
@classmethod
def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
if cls.has_device_capability(80) and support_fa:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
return _Backend.FLASH_ATTN
logger.warning_once(
"The current `vllm-flash-attn` has a bug in the vision "
"module, so the xformers backend is used instead. You can "
"run `pip install flash-attn` to use the flash-attention "
"backend.")
# Fall back to xformers for Volta/Turing GPUs or when
# flash-attention is not available.
return _Backend.XFORMERS
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
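# Rough selection order implemented below (a summary, not a spec):
# MLA models pick Cutlass MLA, Triton MLA, or FlashMLA depending on
# device capability and block size; V1 non-MLA models honor an
# explicitly selected backend, then default to FlashInfer (Blackwell),
# FlashAttention (SM 8.0+), or FlexAttention; the V0 path falls back
# from FlashAttention to XFormers when its requirements are not met.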
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA or (
cls.is_device_capability(100) and selected_backend is None
and block_size == 128):
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"cutlass_mla.CutlassMLABackend")
else:
logger.warning(
"Cutlass MLA backend is only supported on the V1 engine.")
if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
if not is_flashmla_supported()[0]:
logger.warning(
"FlashMLA backend is not supported due to %s",
is_flashmla_supported()[1])
elif block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
if use_v1:
logger.info_once(
"Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
logger.info("Using FlashMLA backend.")
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend on V1 engine.")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS_VLLM_V1:
logger.info_once("Using XFormers backend on V1 engine.")
return XFORMERS_V1
from vllm.attention.selector import is_attn_backend_supported
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs.")
set_kv_cache_layout("HND")
return FLASHINFER_V1
if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if has_sink and not cls.is_device_capability(90):
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
if is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False):
logger.info_once("Using Flash Attention backend on "
"V1 engine.")
return FLASH_ATTN_V1
# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
assert not is_default_backend_supported
use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype
logger.info_once(
"Using FlexAttention backend for %s on V1 engine.",
", ".join(f"{k}={v}"
for k, v in use_flex_attention_reason.items()),
)
return FLEX_ATTENTION_V1
# Backends for V0 engine
if selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name} "
f"with use_v1={use_v1}, use_mla={use_mla}.")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
target_backend = _Backend.XFORMERS
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
target_backend = _Backend.XFORMERS
# FlashAttn is valid for the model; check whether the package
# is installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
target_backend = _Backend.XFORMERS
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.")
target_backend = _Backend.XFORMERS
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
target_backend = _Backend.XFORMERS
if target_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
@classmethod
def supports_fp8(cls) -> bool:
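# Compute capability 8.9 corresponds to Ada Lovelace; Hopper (9.0) and
# Blackwell (10.0) also satisfy this check, so all of them report FP8
# support.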
return cls.has_device_capability(89)
@classmethod
def supports_v1(cls, model_config: "ModelConfig") -> bool:
return True
@classmethod
def use_custom_allreduce(cls) -> bool:
return True
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
cls,
backend: str,
prefix_store: PrefixStore,
group_rank: int,
group_size: int,
timeout: timedelta,
) -> ProcessGroup:
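# Hedged usage sketch (names are illustrative): given a PrefixStore
# shared by all ranks (e.g. wrapping a TCPStore), each rank can build
# an NCCL group without touching the default process group:
#   pg = CudaPlatformBase.stateless_init_device_torch_dist_pg(
#       "nccl", prefix_store, rank, world_size,
#       timedelta(seconds=60))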
assert is_nccl_available()
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
model_config: "ModelConfig") -> bool:
fp8_attention = kv_cache_dtype.startswith("fp8")
attention_backend = envs.VLLM_ATTENTION_BACKEND
supported = False
if model_config is not None and model_config.use_mla:
# Default to CutlassMLA for Blackwell,
# FlashMLA otherwise.
if attention_backend is None:
if cls.is_device_capability(100):
attention_backend = "CUTLASS_MLA"
else:
attention_backend = "FLASHMLA"
# Only FlashMLA supports fp8
if attention_backend == "FLASHMLA":
supported = True
else:
supported = (not fp8_attention)
else:
# Default to FlashAttention
if attention_backend is None:
attention_backend = "FLASH_ATTN_VLLM_V1"
# All Blackwell backends support fp8
if cls.is_device_capability(100):
supported = True
elif attention_backend == "FLASH_ATTN_VLLM_V1":
if fp8_attention:
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
else:
supported = True
return supported
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
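# Illustrative example: requesting bfloat16 on a T4 (SM 7.5) takes the
# branch below and raises a ValueError that suggests --dtype=half.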
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs "
"with compute capability of at least 8.0. "
f"Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the "
"`dtype` flag in CLI, for example: --dtype=half.")