vllm.model_executor.layers.quantization.fp8

ACTIVATION_SCHEMES module-attribute

ACTIVATION_SCHEMES = ['static', 'dynamic']
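The two schemes differ in where the activation (input) scale comes from. A minimal sketch of the distinction, assuming float8_e4m3fn; this is illustrative only and not the vLLM kernels themselves:

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def dynamic_act_scale(x: torch.Tensor) -> torch.Tensor:
    # "dynamic": the scale is derived from the current activation tensor at
    # runtime (per tensor, or per token on the CUTLASS path).
    return x.abs().max().to(torch.float32) / fp8_max

# "static": the scale is a fixed value serialized in the checkpoint and loaded
# into the layer's input_scale parameter; no runtime amax pass is needed.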

logger module-attribute

logger = init_logger(__name__)

Fp8Config

Bases: QuantizationConfig

Config class for FP8.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized "
                    "checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must have 2 "
                    f"dimensions, but got {len(weight_block_size)} dimensions")
            if activation_scheme != "dynamic":
                raise ValueError("The block-wise quantization only supports "
                                 "dynamic activation scheme for now, but got "
                                 f"{activation_scheme} activation scheme.")
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        if self.ignored_layers is not None:
            self.ignored_layers = hf_to_vllm_mapper.apply_list(
                self.ignored_layers)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                                 None)
        if not ignored_layers:
            ignored_layers = cls.get_from_keys_or(config,
                                                  ["modules_to_not_convert"],
                                                  None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers,
                   weight_block_size=weight_block_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignored_layers,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self, layer)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        # If no matches, return None
        return None

activation_scheme instance-attribute

activation_scheme = activation_scheme

ignored_layers instance-attribute

ignored_layers = ignored_layers or []

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

weight_block_size instance-attribute

weight_block_size = weight_block_size

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: Optional[list[str]] = None,
    weight_block_size: Optional[list[int]] = None,
) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: Optional[list[str]] = None,
    weight_block_size: Optional[list[int]] = None,
) -> None:
    super().__init__()

    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    if activation_scheme not in ACTIVATION_SCHEMES:
        raise ValueError(
            f"Unsupported activation scheme {activation_scheme}")
    self.activation_scheme = activation_scheme
    self.ignored_layers = ignored_layers or []
    if weight_block_size is not None:
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        if len(weight_block_size) != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {len(weight_block_size)} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")
    self.weight_block_size = weight_block_size
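For illustration, a few constructions and the validation errors they trigger, based on the checks above (block size [128, 128] is the DeepSeek-V3-style layout):

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

# Block-wise FP8 weights with dynamic activation quantization.
cfg = Fp8Config(
    is_checkpoint_fp8_serialized=True,
    activation_scheme="dynamic",
    weight_block_size=[128, 128],
)

# Each of these raises ValueError:
# Fp8Config(activation_scheme="int8")           # not in ACTIVATION_SCHEMES
# Fp8Config(weight_block_size=[128, 128])       # block quant needs an fp8-serialized checkpoint
# Fp8Config(is_checkpoint_fp8_serialized=True,
#           activation_scheme="static",
#           weight_block_size=[128, 128])       # block quant needs the dynamic scheme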

apply_vllm_mapper

apply_vllm_mapper(hf_to_vllm_mapper: WeightsMapper)
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
    if self.ignored_layers is not None:
        self.ignored_layers = hf_to_vllm_mapper.apply_list(
            self.ignored_layers)

from_config classmethod

from_config(config: dict[str, Any]) -> Fp8Config
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
    quant_method = cls.get_from_keys(config, ["quant_method"])
    is_checkpoint_fp8_serialized = ("fp8" in quant_method)
    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
                                             None)
    if not ignored_layers:
        ignored_layers = cls.get_from_keys_or(config,
                                              ["modules_to_not_convert"],
                                              None)
    return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
               activation_scheme=activation_scheme,
               ignored_layers=ignored_layers,
               weight_block_size=weight_block_size)
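A minimal example of the kind of quantization config dict this method reads. The keys shown are the ones looked up above; real checkpoint configs may carry additional fields:

hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],
    "modules_to_not_convert": ["lm_head"],
}
cfg = Fp8Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_fp8_serialized        # "fp8" appears in quant_method
assert cfg.ignored_layers == ["lm_head"]       # falls back to modules_to_not_convert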

get_cache_scale

get_cache_scale(name: str) -> Optional[str]

Check whether the param name matches the format for k/v cache scales in compressed-tensors. If so, return its equivalent param name expected by vLLM.

Parameters:

name: the param name to check.

Returns:

The matching param name for the KV cache scale in vLLM, or None if there is no match.

Source code in vllm/model_executor/layers/quantization/fp8.py
def get_cache_scale(self, name: str) -> Optional[str]:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM

    :param name: param name
    :return: matching param name for KV cache scale in vLLM
    """
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    # If no matches, return None
    return None
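For example (the parameter names below are illustrative):

cfg = Fp8Config()
cfg.get_cache_scale("model.layers.0.self_attn.k_proj.output_scale")
# -> "model.layers.0.self_attn.attn.k_scale"
cfg.get_cache_scale("model.layers.0.self_attn.prob_output_scale")
# -> "model.layers.0.self_attn.attn.prob_scale"
cfg.get_cache_scale("model.layers.0.mlp.down_proj.weight")
# -> None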

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return []

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_min_capability(cls) -> int:
    return 80

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "fp8"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/fp8.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import

    if isinstance(layer, LinearBase):
        if is_layer_skipped(prefix=prefix,
                            ignored_layers=self.ignored_layers,
                            fused_mapping=self.packed_modules_mapping):
            return UnquantizedLinearMethod()
        return Fp8LinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return Fp8MoEMethod(self, layer)
    elif isinstance(layer, Attention):
        return Fp8KVCacheMethod(self)
    return None
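In summary, the dispatch behaves as follows (layer prefixes are illustrative):

cfg = Fp8Config(ignored_layers=["lm_head"])
# LinearBase at prefix "model.layers.0.mlp.down_proj"  -> Fp8LinearMethod(cfg)
# LinearBase at prefix "lm_head" (skipped layer)       -> UnquantizedLinearMethod()
# FusedMoE layer                                       -> Fp8MoEMethod(cfg, layer)
# Attention layer                                      -> Fp8KVCacheMethod(cfg)
# any other module                                     -> None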

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/fp8.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

Fp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)

__init__

__init__(quant_config: Fp8Config)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config):
    super().__init__(quant_config)

Fp8LinearMethod

Bases: LinearMethodBase

Linear method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale.

Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded.

Limitations:

1. Only supports per-tensor quantization, due to torch._scaled_mm support.
2. Only supports the float8_e4m3fn data type, due to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856).

Parameters:

Name          Type       Description               Default
quant_config  Fp8Config  The quantization config.  required
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.out_dtype = torch.get_default_dtype()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # AITER is only supported on ROCm and only for FP8_FNUZ
        # and at the moment are MI300 series
        self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())

        self.block_quant = self.quant_config.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            is_tp_split = (tp_size > 1 and
                           output_size // output_size_per_partition == tp_size)
            is_merged_gemm = len(output_partition_sizes) > 1
            if is_tp_split or is_merged_gemm:
                sizes_to_check = output_partition_sizes
                if not is_tp_split and is_merged_gemm:
                    # In case of merged matrices, we allow the last
                    # matrix to not be a multiple of block size
                    sizes_to_check = output_partition_sizes[:-1]
                for output_partition_size in sizes_to_check:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "weight_scale"})
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                scale[:] = torch.finfo(torch.float32).min
                set_weight_attrs(scale, {"scale_type": "input_scale"})
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: Module) -> None:
        size_k_first = True
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            size_k_first = False
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)

        # If checkpoint not serialized fp8, quantize the weights.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            # layer.input_scale is None indicates dynamic quant and scale is
            # computed from input.
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            weight = layer.weight
            weight_scale = layer.weight_scale

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            if not self.use_marlin:
                # Dequant -> Quant with max scale so we can run per tensor.
                if current_platform.is_fp8_fnuz():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer, size_k_first)
            # Activations not quantized for marlin.
            del layer.input_scale

        # On B200, if E8M0 for DeepGemm is used, we need to
        # requantize the weight and input to the specific scale
        # at the same time.
        if is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.weight.data,
                layer.weight_scale_inv.data if hasattr(
                    layer, "weight_scale_inv") else layer.weight_scale.data,
                block_sz,
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None

            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)

act_q_group_shape instance-attribute

act_q_group_shape = PER_TOKEN

act_q_static instance-attribute

act_q_static = activation_scheme == 'static'

block_quant instance-attribute

block_quant = weight_block_size is not None

cutlass_block_fp8_supported instance-attribute

cutlass_block_fp8_supported = cutlass_block_fp8_supported()

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=act_q_static,
    act_quant_group_shape=act_q_group_shape,
)

out_dtype instance-attribute

out_dtype = get_default_dtype()

quant_config instance-attribute

quant_config = quant_config

use_aiter_and_is_supported instance-attribute

use_aiter_and_is_supported = (
    is_rocm()
    and VLLM_ROCM_USE_AITER
    and VLLM_ROCM_USE_AITER_LINEAR
    and is_fp8_fnuz()
)

use_marlin instance-attribute

use_marlin = (
    not has_device_capability(89)
    or VLLM_TEST_FORCE_FP8_MARLIN
)

__init__

__init__(quant_config: Fp8Config)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False

    # AITER is only supported on ROCm and only for FP8_FNUZ
    # and at the moment are MI300 series
    self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                       and envs.VLLM_ROCM_USE_AITER
                                       and envs.VLLM_ROCM_USE_AITER_LINEAR
                                       and current_platform.is_fp8_fnuz())

    self.block_quant = self.quant_config.weight_block_size is not None
    self.act_q_static = self.quant_config.activation_scheme == "static"
    # Use per-token quantization for better perf if dynamic and cutlass
    if not self.act_q_static and cutlass_fp8_supported():
        self.act_q_group_shape = GroupShape.PER_TOKEN
    else:
        self.act_q_group_shape = GroupShape.PER_TENSOR

    self.fp8_linear = Fp8LinearOp(
        act_quant_static=self.act_q_static,
        act_quant_group_shape=self.act_q_group_shape)

_maybe_pad_weight

_maybe_pad_weight(weight: Tensor) -> Tensor
Source code in vllm/model_executor/layers/quantization/fp8.py
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
    # Pad the weight tensor. This is an optimization on ROCm platform, which
    # can benefit from tensors located far enough from one another in memory
    if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
            and weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0):
        num_pad = 256 // weight.element_size()
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
        torch.cuda.empty_cache()
    return weight
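A small CPU-only illustration of the padding trick, using int8 as a stand-in for the 1-byte fp8 dtype (the real path additionally requires ROCm, VLLM_ROCM_FP8_PADDING, and the stride/alignment conditions above):

import torch
import torch.nn.functional as F

w = torch.zeros(4, 1024, dtype=torch.int8)      # stand-in for an fp8 weight
num_pad = 256 // w.element_size()               # 256 extra columns for 1-byte elements
w_padded = F.pad(w, (0, num_pad), "constant", 0)[..., :-num_pad]
print(w.shape, w_padded.shape)                  # same logical shape: (4, 1024)
print(w.stride(), w_padded.stride())            # (1024, 1) vs (1280, 1): 256 extra bytes between row starts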

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    if self.use_marlin:
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)

    if self.block_quant:
        assert self.quant_config.weight_block_size is not None

        return torch.ops.vllm.apply_w8a8_block_fp8_linear(
            input=x,
            weight=layer.weight,
            block_size=self.quant_config.weight_block_size,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )

    return self.fp8_linear.apply(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
                                 out_dtype=self.out_dtype,
                                 input_scale=layer.input_scale,
                                 bias=bias)
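On the per-tensor path (no block quantization, no Marlin), the call goes through Fp8LinearOp. Conceptually it reduces to the sketch below; torch._scaled_mm's exact signature has changed across PyTorch releases, so treat this as pseudocode rather than the precise call vLLM issues:

import torch

def fp8_linear_sketch(x, weight_t_fp8, weight_scale, input_scale=None,
                      out_dtype=torch.bfloat16, bias=None):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    if input_scale is None:  # dynamic activation scheme: scale from this batch
        input_scale = x.abs().max().to(torch.float32) / fp8_max
    x_fp8 = (x / input_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # weight_t_fp8 is the fp8 weight stored transposed (see
    # process_weights_after_loading); both scales are per-tensor fp32 values.
    return torch._scaled_mm(x_fp8, weight_t_fp8,
                            scale_a=input_scale, scale_b=weight_scale,
                            bias=bias, out_dtype=out_dtype)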

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    maybe_create_device_identity()

    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    if self.block_quant:
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_config.weight_block_size is not None
        layer.weight_block_size = self.quant_config.weight_block_size
        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        # Required by row parallel
        if (tp_size > 1
                and input_size // input_size_per_partition == tp_size
                and input_size_per_partition % block_k != 0):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"weight quantization block_k = {block_k}.")
        # Required by column parallel or enabling merged weights
        is_tp_split = (tp_size > 1 and
                       output_size // output_size_per_partition == tp_size)
        is_merged_gemm = len(output_partition_sizes) > 1
        if is_tp_split or is_merged_gemm:
            sizes_to_check = output_partition_sizes
            if not is_tp_split and is_merged_gemm:
                # In case of merged matrices, we allow the last
                # matrix to not be a multiple of block size
                sizes_to_check = output_partition_sizes[:-1]
            for output_partition_size in sizes_to_check:
                if output_partition_size % block_n != 0:
                    raise ValueError(
                        f"Weight output_partition_size = "
                        f"{output_partition_size} is not divisible by "
                        f"weight quantization block_n = {block_n}.")

    # WEIGHT
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_fp8_serialized else
                    params_dtype)

    weight = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=weight_dtype),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    # If checkpoint is serialized fp8, load them.
    # Otherwise, wait until process_weights_after_loading.
    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        if not self.block_quant:
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes),
                                 dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
        else:
            assert self.quant_config.activation_scheme == "dynamic"
            scale = BlockQuantScaleParameter(
                data=torch.empty(
                    (output_size_per_partition + block_n - 1) // block_n,
                    (input_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.quant_config.activation_scheme == "static":
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            scale[:] = torch.finfo(torch.float32).min
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)
        else:
            layer.register_parameter("input_scale", None)
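Shape arithmetic for the block-wise weight scale allocated above, with illustrative sizes:

block_n, block_k = 128, 128
output_size_per_partition, input_size_per_partition = 4096, 14336
scale_rows = (output_size_per_partition + block_n - 1) // block_n   # 32
scale_cols = (input_size_per_partition + block_k - 1) // block_k    # 112
# weight_scale_inv is then a [32, 112] fp32 tensor: one scale per 128x128 weight block.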

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    size_k_first = True
    # TODO(rob): refactor block quant into separate class.
    if self.block_quant:
        assert self.quant_config.activation_scheme == "dynamic"
        size_k_first = False
        if current_platform.is_fp8_fnuz():
            weight, weight_scale_inv, _ = \
                normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv)
        else:
            weight = layer.weight.data
            weight_scale_inv = layer.weight_scale_inv.data

        weight = self._maybe_pad_weight(weight)

        # Torch.compile cannot use Parameter subclasses.
        layer.weight = Parameter(weight, requires_grad=False)
        layer.weight_scale_inv = Parameter(weight_scale_inv,
                                           requires_grad=False)

    # If checkpoint not serialized fp8, quantize the weights.
    elif not self.quant_config.is_checkpoint_fp8_serialized:
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                     scale=None)

        # Update the layer with the new values.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        # layer.input_scale is None indicates dynamic quant and scale is
        # computed from input.
        layer.input_scale = None

    # If checkpoint is fp8, handle that there are N scales for N
    # shards in a fused module
    else:
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)
        if self.quant_config.activation_scheme == "static":
            layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                   requires_grad=False)

        weight = layer.weight
        weight_scale = layer.weight_scale

        # If using w8a8, torch._scaled_mm needs per tensor, so
        # requantize the logical shards as a single weight.
        if not self.use_marlin:
            # Dequant -> Quant with max scale so we can run per tensor.
            if current_platform.is_fp8_fnuz():
                weight, weight_scale, input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=weight_scale,
                        input_scale=layer.input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)

            weight_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=weight_scale,
                logical_widths=layer.logical_widths,
            )

        weight = self._maybe_pad_weight(weight)
        # Update layer with new values.
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        if self.quant_config.activation_scheme == "static":
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)

    if self.use_marlin:
        prepare_fp8_layer_for_marlin(layer, size_k_first)
        # Activations not quantized for marlin.
        del layer.input_scale

    # On B200, if E8M0 for DeepGemm is used, we need to
    # requantize the weight and input to the specific scale
    # at the same time.
    if is_blackwell_deep_gemm_e8m0_used():
        assert layer.weight_block_size is not None
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(
            layer.weight.data,
            layer.weight_scale_inv.data if hasattr(
                layer, "weight_scale_inv") else layer.weight_scale.data,
            block_sz,
        )
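A hedged sketch of what requantize_with_max_scale does for a fused module (e.g. a merged qkv projection) whose shards were serialized with separate per-tensor scales: dequantize each logical shard with its own scale, then requantize everything with the single maximum scale so torch._scaled_mm can run with one per-tensor scale:

import torch

def requantize_with_max_scale_sketch(weight_fp8, per_shard_scales, logical_widths):
    max_scale = per_shard_scales.max()
    out = torch.empty_like(weight_fp8)
    start = 0
    for width, scale in zip(logical_widths, per_shard_scales):
        shard = weight_fp8[start:start + width].to(torch.float32) * scale    # dequantize
        out[start:start + width] = (shard / max_scale).to(weight_fp8.dtype)  # requantize
        start += width
    return max_scale, out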

Fp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale.

Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded.

Parameters:

Name          Type       Description               Default
quant_config  Fp8Config  The quantization config.  required
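For reference, the per-expert fp8 weight shapes allocated by create_weights below (sizes are illustrative; E = num_experts, H = hidden_size, I = intermediate_size_per_partition):

E, H, I = 8, 4096, 14336
w13_shape = (E, 2 * I, H)   # fused gate/up projection weights
w2_shape = (E, H, I)        # down projection weights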
Source code in vllm/model_executor/layers/quantization/fp8.py
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
            elif (is_deep_gemm_supported()):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
            else:
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        # Check for CutlassBlockScaledGroupedGemm support.
        self.allow_cutlass_block_scaled_grouped_gemm = False
        if not self.block_quant:
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.is_device_capability(100)):
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
            self.allow_cutlass_block_scaled_grouped_gemm = True
        else:
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")

    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize(moe)

        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            moe,
            layer=self.layer,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            layer.weight_block_size = self.quant_config.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            if (tp_size > 1
                    and intermediate_size_per_partition % block_k != 0):
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
             value} if self.block_quant else
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        # Lazy import to avoid importing triton too early.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            is_rocm_aiter_moe_enabled, shuffle_weights)

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
            elif self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on different half for flashinfer vs vllm
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                w13_weight_scale_inv = swap_w13_to_w31(
                    layer.w13_weight_scale_inv.data)
                w2_weight = layer.w2_weight.data
                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
                w2_weight = layer.w2_weight
                w2_weight_scale_inv = layer.w2_weight_scale_inv

            # torch.compile() cannot use Parameter subclasses.
            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                                   requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight.data, layer.w2_weight.data)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            # DeepGemm scales need to be transposed and aligned.  We try to do
            # it ahead of time for performance reasons.
            if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
                # Lazy import to avoid CUDA initialization problems.
                if _is_col_major(layer.w13_weight_scale_inv):
                    layer.w13_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
                if _is_col_major(layer.w2_weight_scale_inv):
                    layer.w2_weight_scale_inv = \
                        get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

        # If checkpoint is fp16, quantize in place.
        elif not self.quant_config.is_checkpoint_fp8_serialized:
            fp8_dtype = current_platform.fp8_dtype()
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.local_num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.local_num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            if self.rocm_aiter_moe_enabled:
                # reshaping weights is required for aiter moe kernel.
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)
        # If the checkpoint is fp8, handle the fact that the MoE kernels
        # require a single activation scale and a single weight scale for
        # w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    logger.warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            if current_platform.is_fp8_fnuz():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            if self.rocm_aiter_moe_enabled:
                shuffled_w13, shuffled_w2 = shuffle_weights(
                    layer.w13_weight, layer.w2_weight)

                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                      requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                     requires_grad=False)

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)

            if self.flashinfer_moe_backend is not None:
                # NOTE: weights have to be swapped since the activation is
                # applied on a different half for flashinfer vs vllm.
                assert not self.block_quant
                register_moe_scaling_factors(layer)
                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
                if self.flashinfer_moe_backend == \
                    FlashinferMoeBackend.TENSORRT_LLM:
                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
                layer.w13_weight.data = w13_weight.data

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

        if is_blackwell_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            # Re-quantize the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
            requant_weight_ue8m0_inplace(
                layer.w13_weight.data,
                layer.w13_weight_scale_inv.data,
                block_sz,
            )
            requant_weight_ue8m0_inplace(
                layer.w2_weight.data,
                layer.w2_weight_scale_inv.data,
                block_sz,
            )

            # Ensure column-major TMA alignment expected by DeepGEMM.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                    layer.w2_weight_scale_inv).contiguous()

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        from vllm.model_executor.layers.fused_moe import (
            BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
            "Marlin and ROCm AITER are not supported with all2all yet.")

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            max_num_tokens_per_rank = (
                prepare_finalize.max_num_tokens_per_rank())
            assert max_num_tokens_per_rank is not None
            logger.debug(
                "BatchedTritonOrDeepGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank,
                self.quant_config.weight_block_size, False)
            return BatchedTritonOrDeepGemmExperts(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                per_act_token_quant=False,
                allow_deep_gemm=self.allow_deep_gemm,
            )
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            experts = select_cutlass_fp8_gemm_impl(
                moe,
                self.layer,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        else:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, self.quant_config.weight_block_size,
                False)
            return TritonOrDeepGemmExperts(
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                allow_deep_gemm=self.allow_deep_gemm,
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.block_quant:
                assert (renormalize and use_grouped_topk
                        and custom_routing_function is None)

                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                    routing_logits=router_logits.to(torch.float32),
                    routing_bias=e_score_correction_bias,
                    x=x,
                    w13_weight=layer.w13_weight,
                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
                    w2_weight=layer.w2_weight,
                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    intermediate_size=layer.intermediate_size_per_partition,
                    expert_offset=layer.ep_rank * layer.local_num_experts,
                    local_num_experts=layer.local_num_experts,
                    block_shape=self.quant_config.weight_block_size,
                    routed_scaling=1.0,
                )
            else:
                assert (not renormalize
                        and custom_routing_function is not None)
                return apply_flashinfer_per_tensor_scale_fp8(
                    layer=layer,
                    hidden_states=x,
                    router_logits=router_logits,
                    routing_bias=e_score_correction_bias,
                    global_num_experts=global_num_experts,
                    top_k=top_k,
                    num_expert_group=num_expert_group,
                    topk_group=topk_group,
                    apply_router_weight_on_input=apply_router_weight_on_input)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                use_fp8_w8a8=True,
                apply_router_weight_on_input=apply_router_weight_on_input,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                expert_map=expert_map)
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)
        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not self.block_quant
            assert (not renormalize and custom_routing_function is not None)
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert scoring_func == 'sigmoid', (
                f"Expected 'sigmoid' scoring func but got {scoring_func}")
            if self.fused_experts is not None:
                return self.fused_experts(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
            else:
                return flashinfer_cutlass_moe_fp8(
                    x,
                    layer,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
        else:
            common_kwargs = dict(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=apply_router_weight_on_input,
                expert_map=expert_map,
                w1_scale=(layer.w13_weight_scale_inv
                          if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )

            if self.fused_experts is not None:
                return self.fused_experts(**common_kwargs)
            else:
                from vllm.model_executor.layers.fused_moe import fused_experts
                return fused_experts(
                    **common_kwargs,
                    use_fp8_w8a8=True,
                    block_shape=self.quant_config.weight_block_size,
                    allow_deep_gemm=self.allow_deep_gemm,
                    allow_cutlass_block_scaled_grouped_gemm=(
                        self.allow_cutlass_block_scaled_grouped_gemm),
                )
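
The FlashInfer branches above swap the two halves of the merged w13 tensor because FlashInfer applies the activation to the opposite half compared to vLLM's fused kernels. The snippet below is a minimal, standalone sketch of what such a swap amounts to, assuming w13 stacks the gate and up projections along the intermediate dimension; it is not vLLM's actual swap_w13_to_w31 helper.

# Illustrative sketch only, not vLLM code.
import torch

def swap_halves_along_dim1(w13: torch.Tensor) -> torch.Tensor:
    # w13: (num_experts, 2 * intermediate_size, hidden_size), stored as [w1; w3].
    # Returns a copy laid out as [w3; w1].
    half = w13.shape[1] // 2
    return torch.cat([w13[:, half:, :], w13[:, :half, :]], dim=1)

w13 = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
print(swap_halves_along_dim1(w13)[0, :, 0])  # rows 2, 3 now come before rows 0, 1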

allow_cutlass_block_scaled_grouped_gemm instance-attribute

allow_cutlass_block_scaled_grouped_gemm = False

allow_deep_gemm instance-attribute

allow_deep_gemm = False

block_quant instance-attribute

block_quant = weight_block_size is not None

flashinfer_moe_backend instance-attribute

flashinfer_moe_backend: Optional[FlashinferMoeBackend] = (
    None
)

fused_experts instance-attribute

fused_experts: Optional[FusedMoEModularKernel] = None

layer instance-attribute

layer = layer

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = (
    not has_device_capability(89)
    or VLLM_TEST_FORCE_FP8_MARLIN
)

__init__

__init__(quant_config: Fp8Config, layer: Module)
Source code in vllm/model_executor/layers/quantization/fp8.py
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
    super().__init__(layer.moe_config)
    self.layer = layer
    self.quant_config = quant_config
    self.block_quant = self.quant_config.weight_block_size is not None

    self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
    self.fused_experts: Optional[
        mk.FusedMoEModularKernel] = None  # type: ignore
    if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
        )
    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False

    # Check for DeepGemm support.
    self.allow_deep_gemm = False
    if envs.VLLM_USE_DEEP_GEMM:
        if not has_deep_gemm():
            logger.warning_once("Failed to import DeepGemm kernels.")
        elif not self.block_quant:
            logger.warning_once("Model is not block quantized. Not using "
                                "DeepGemm kernels")
        elif (is_deep_gemm_supported()):
            logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
            self.allow_deep_gemm = True
        else:
            logger.warning_once(
                "DeepGemm not supported on the current platform.")

    # Check for CutlassBlockScaledGroupedGemm support.
    self.allow_cutlass_block_scaled_grouped_gemm = False
    if not self.block_quant:
        logger.debug_once("Model is not block quantized. Not using "
                          "CutlassBlockScaledGroupedGemm kernels")
    elif (current_platform.is_cuda()
          and current_platform.is_device_capability(100)):
        logger.info_once(
            "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
        )
        self.allow_cutlass_block_scaled_grouped_gemm = True
    else:
        logger.warning_once(
            "CutlassBlockScaledGroupedGemm not supported on the current "
            "platform.")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/fp8.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if enable_eplb:
        assert expert_load_view is not None
        assert logical_to_physical_map is not None
        assert logical_replica_count is not None
        assert isinstance(layer, FusedMoE)

    if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        assert activation == 'silu', (
            f"Expected 'silu' activation but got {activation}")
        assert scoring_func == 'sigmoid', (
            f"Expected 'sigmoid' scoring func but got {scoring_func}")
        if self.block_quant:
            assert (renormalize and use_grouped_topk
                    and custom_routing_function is None)

            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                routing_logits=router_logits.to(torch.float32),
                routing_bias=e_score_correction_bias,
                x=x,
                w13_weight=layer.w13_weight,
                w13_weight_scale_inv=layer.w13_weight_scale_inv,
                w2_weight=layer.w2_weight,
                w2_weight_scale_inv=layer.w2_weight_scale_inv,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                intermediate_size=layer.intermediate_size_per_partition,
                expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                block_shape=self.quant_config.weight_block_size,
                routed_scaling=1.0,
            )
        else:
            assert (not renormalize
                    and custom_routing_function is not None)
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input)

    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
        enable_eplb=enable_eplb,
        expert_map=expert_map,
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )

    if self.rocm_aiter_moe_enabled:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
            rocm_aiter_fused_experts)
        return rocm_aiter_fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            use_fp8_w8a8=True,
            apply_router_weight_on_input=apply_router_weight_on_input,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            expert_map=expert_map)
    elif self.use_marlin:
        assert activation == "silu", (
            f"{activation} not supported for Marlin MoE.")
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=scalar_types.float8_e4m3fn.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)
    elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        assert not self.block_quant
        assert (not renormalize and custom_routing_function is not None)
        assert activation == 'silu', (
            f"Expected 'silu' activation but got {activation}")
        assert scoring_func == 'sigmoid', (
            f"Expected 'sigmoid' scoring func but got {scoring_func}")
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
    else:
        common_kwargs = dict(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
                      if self.block_quant else layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )

        if self.fused_experts is not None:
            return self.fused_experts(**common_kwargs)
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            return fused_experts(
                **common_kwargs,
                use_fp8_w8a8=True,
                block_shape=self.quant_config.weight_block_size,
                allow_deep_gemm=self.allow_deep_gemm,
                allow_cutlass_block_scaled_grouped_gemm=(
                    self.allow_cutlass_block_scaled_grouped_gemm),
            )
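
A recurring detail in apply is that block-quantized checkpoints carry inverse weight scales under w13_weight_scale_inv / w2_weight_scale_inv, while per-tensor checkpoints carry w13_weight_scale / w2_weight_scale. The helper below is an illustrative restatement of that conditional, not part of vLLM.

# Illustrative only: mirrors the w1_scale/w2_scale selection in apply() above.
def select_moe_weight_scales(layer, block_quant: bool):
    if block_quant:
        return layer.w13_weight_scale_inv, layer.w2_weight_scale_inv
    return layer.w13_weight_scale, layer.w2_weight_scale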

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fp8.py
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                   intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):

    layer.intermediate_size_per_partition = intermediate_size_per_partition
    layer.hidden_size = hidden_size
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    if self.quant_config.is_checkpoint_fp8_serialized:
        params_dtype = torch.float8_e4m3fn
    if self.block_quant:
        assert self.quant_config.weight_block_size is not None
        layer.weight_block_size = self.quant_config.weight_block_size
        tp_size = get_tensor_model_parallel_world_size()
        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        # NOTE: To ensure proper alignment of the block-wise quantization
        # scales, the output_size of the weights for both the gate and up
        # layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if intermediate_size_per_partition % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_n = {block_n}.")
        if (tp_size > 1
                and intermediate_size_per_partition % block_k != 0):
            # Required by row parallel
            raise ValueError(
                f"The input_size of down's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_k = {block_k}.")

    # WEIGHTS
    w13_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        2 * intermediate_size_per_partition,
        hidden_size,
        dtype=params_dtype),
                                    requires_grad=False)
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size,
        intermediate_size_per_partition,
        dtype=params_dtype),
                                   requires_grad=False)
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if not self.block_quant:
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts, 2, dtype=torch.float32),
                                              requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts, dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
    else:
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) //
                     block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        assert self.quant_config.activation_scheme == "dynamic"

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
         value} if self.block_quant else
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
    # If loading an fp8 checkpoint, pass the weight loaders.
    # If loading an fp16 checkpoint, do not (we will quantize in
    #   process_weights_after_loading()).
    if self.quant_config.is_checkpoint_fp8_serialized:
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    # INPUT_SCALES
    if self.quant_config.activation_scheme == "static":
        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        w13_input_scale = torch.nn.Parameter(torch.ones(
            num_experts, dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(torch.ones(
            num_experts, dtype=torch.float32),
                                            requires_grad=False)
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

    else:
        layer.w13_input_scale = None
        layer.w2_input_scale = None
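
For block-quantized checkpoints, the scale tensors allocated above are sized by ceil-dividing each weight dimension by the block size. The sketch below reproduces that arithmetic with made-up layer dimensions; the formula matches the allocations in create_weights.

# Illustrative only: block-wise scale shapes via ceil-division.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, hidden_size, intermediate_size = 8, 4096, 1408
block_n, block_k = 128, 128  # weight_block_size = [128, 128]

w13_scale_shape = (num_experts,
                   2 * ceil_div(intermediate_size, block_n),
                   ceil_div(hidden_size, block_k))
w2_scale_shape = (num_experts,
                  ceil_div(hidden_size, block_n),
                  ceil_div(intermediate_size, block_k))
print(w13_scale_shape)  # (8, 22, 32)
print(w2_scale_shape)   # (8, 32, 11)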

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]
Source code in vllm/model_executor/layers/quantization/fp8.py
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
        return super().maybe_make_prepare_finalize(moe)

    prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        moe,
        layer=self.layer,
    )
    logger.debug_once("%s", prepare_finalize.__class__.__name__)
    return prepare_finalize

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    # Lazy import to avoid importing triton too early.
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled, shuffle_weights)

    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

    # TODO (rob): refactor block quant into separate class.
    if self.block_quant:
        assert self.quant_config.activation_scheme == "dynamic"
        if current_platform.is_fp8_fnuz():
            w13_weight, w13_weight_scale_inv, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale_inv,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale_inv, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale_inv,
                    layer.w2_input_scale)
        elif self.flashinfer_moe_backend is not None:
            # NOTE: weights have to be swapped since the activation is
            # applied on a different half for flashinfer vs vllm.
            w13_weight = swap_w13_to_w31(layer.w13_weight.data)
            w13_weight_scale_inv = swap_w13_to_w31(
                layer.w13_weight_scale_inv.data)
            w2_weight = layer.w2_weight.data
            w2_weight_scale_inv = layer.w2_weight_scale_inv.data
        else:
            w13_weight = layer.w13_weight.data
            w13_weight_scale_inv = layer.w13_weight_scale_inv.data
            w2_weight = layer.w2_weight
            w2_weight_scale_inv = layer.w2_weight_scale_inv

        # torch.compile() cannot use Parameter subclasses.
        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
        layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
                                               requires_grad=False)
        layer.w2_weight = Parameter(w2_weight, requires_grad=False)
        layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                              requires_grad=False)
        if self.rocm_aiter_moe_enabled:
            # reshaping weights is required for aiter moe kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

        # DeepGemm scales need to be transposed and aligned.  We try to do
        # it ahead of time for performance reasons.
        if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
            # Lazy import to avoid CUDA initialization problems.
            if _is_col_major(layer.w13_weight_scale_inv):
                layer.w13_weight_scale_inv = \
                    get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
            if _is_col_major(layer.w2_weight_scale_inv):
                layer.w2_weight_scale_inv = \
                    get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()

    # If checkpoint is fp16, quantize in place.
    elif not self.quant_config.is_checkpoint_fp8_serialized:
        fp8_dtype = current_platform.fp8_dtype()
        w13_weight = torch.empty_like(layer.w13_weight.data,
                                      dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Re-initialize w13_scale because we directly quantize
        # merged w13 weights and generate a single scaling factor.
        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            layer.local_num_experts,
            dtype=torch.float32,
            device=w13_weight.device),
                                                    requires_grad=False)
        for expert in range(layer.local_num_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[
                expert] = ops.scaled_fp8_quant(
                    layer.w13_weight.data[expert, :, :])
            w2_weight[expert, :, :], layer.w2_weight_scale[
                expert] = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :])
        layer.w13_weight = torch.nn.Parameter(w13_weight,
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight,
                                             requires_grad=False)
        if self.rocm_aiter_moe_enabled:
            # reshaping weights is required for aiter moe kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight, layer.w2_weight)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)
    # If the checkpoint is fp8, handle the fact that the MoE kernels
    # require a single activation scale and a single weight scale for
    # w13 per expert.
    else:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)
        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start +
                                                shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight, layer.w2_weight)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

        if self.flashinfer_moe_backend is not None:
            # NOTE: weights have to be swapped since the activation is
            # applied on a different half for flashinfer vs vllm.
            assert not self.block_quant
            register_moe_scaling_factors(layer)
            w13_weight = swap_w13_to_w31(layer.w13_weight.data)
            if self.flashinfer_moe_backend == \
                FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
            layer.w13_weight.data = w13_weight.data

    if self.use_marlin:
        prepare_moe_fp8_layer_for_marlin(layer, False)
        # Activations not quantized for marlin.
        del layer.w13_input_scale
        del layer.w2_input_scale

    if is_blackwell_deep_gemm_e8m0_used():
        assert layer.weight_block_size is not None
        # Re-quantize the expert weights so their scales are UE8M0.
        block_sz = tuple(layer.weight_block_size)
        requant_weight_ue8m0_inplace(
            layer.w13_weight.data,
            layer.w13_weight_scale_inv.data,
            block_sz,
        )
        requant_weight_ue8m0_inplace(
            layer.w2_weight.data,
            layer.w2_weight_scale_inv.data,
            block_sz,
        )

        # Ensure column-major TMA alignment expected by DeepGEMM.
        if _is_col_major(layer.w13_weight_scale_inv):
            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                layer.w13_weight_scale_inv).contiguous()
        if _is_col_major(layer.w2_weight_scale_inv):
            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                layer.w2_weight_scale_inv).contiguous()
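
For per-tensor FP8 checkpoints, the loop above collapses the two per-shard w13 scales into a single per-expert scale by dequantizing each shard with its original scale and re-quantizing with the per-expert maximum. The standalone sketch below shows the same transformation using plain torch casts instead of vLLM's per_tensor_dequantize and ops.scaled_fp8_quant; it assumes an E4M3 weight tensor of shape (num_experts, 2 * shard_size, hidden) with scales of shape (num_experts, 2).

# Illustrative only: not vLLM's kernel path (requires PyTorch with float8_e4m3fn support).
import torch

def merge_w13_scales(w13_fp8: torch.Tensor, scales: torch.Tensor, shard_size: int):
    finfo = torch.finfo(torch.float8_e4m3fn)
    max_scales = scales.max(dim=1).values     # one shared scale per expert
    out = torch.empty_like(w13_fp8)
    for e in range(w13_fp8.shape[0]):
        for s in range(2):
            rows = slice(s * shard_size, (s + 1) * shard_size)
            dq = w13_fp8[e, rows].to(torch.float32) * scales[e, s]  # dequantize shard
            rq = (dq / max_scales[e]).clamp(finfo.min, finfo.max)   # requantize with max
            out[e, rows] = rq.to(torch.float8_e4m3fn)
    return out, max_scales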

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/fp8.py
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
    from vllm.model_executor.layers.fused_moe import (
        BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

    assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
        "Marlin and ROCm AITER are not supported with all2all yet.")

    if (prepare_finalize.activation_format ==
            FusedMoEActivationFormat.BatchedExperts):
        max_num_tokens_per_rank = (
            prepare_finalize.max_num_tokens_per_rank())
        assert max_num_tokens_per_rank is not None
        logger.debug(
            "BatchedTritonOrDeepGemmExperts(%s): "
            "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
            self.__class__.__name__, max_num_tokens_per_rank,
            self.quant_config.weight_block_size, False)
        return BatchedTritonOrDeepGemmExperts(
            max_num_tokens=max_num_tokens_per_rank,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            per_act_token_quant=False,
            allow_deep_gemm=self.allow_deep_gemm,
        )
    elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        experts = select_cutlass_fp8_gemm_impl(
            moe,
            self.layer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts
    else:
        logger.debug(
            "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, self.quant_config.weight_block_size,
            False)
        return TritonOrDeepGemmExperts(
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
        )

_is_col_major

_is_col_major(x: Tensor) -> bool
Source code in vllm/model_executor/layers/quantization/fp8.py
def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
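
_is_col_major checks the strides of a 3-D scale tensor: for shape (b, m, n), a per-batch column-major layout has strides (m * n, 1, m). The standalone snippet below illustrates when the check passes.

# Illustrative only.
import torch

x = torch.zeros(4, 8, 16)                             # row-major
col = x.transpose(1, 2).contiguous().transpose(1, 2)  # data stored column-major per batch
print(x.stride())    # (128, 16, 1) -> not col-major
print(col.stride())  # (128, 1, 8)  -> matches (m * n, 1, m) for (b, m, n) = (4, 8, 16)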