vllm.model_executor.layers.quantization.modelopt

KV_CACHE_QUANT_ALGOS module-attribute

KV_CACHE_QUANT_ALGOS = ['FP8']

QUANT_ALGOS module-attribute

QUANT_ALGOS = ['FP8', 'NVFP4']

logger module-attribute

logger = init_logger(__name__)

ModelOptFp8Config

Bases: QuantizationConfig

Config class for ModelOpt FP8.
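
For orientation, a minimal usage sketch of the two hf_quant_config.json shapes that from_config accepts; the excluded module name is illustrative, not taken from a real checkpoint.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

# ModelOpt format: details nested under a "quantization" key.
nested_cfg = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],  # illustrative
    }
}

# Compressed-tensors style format: flat keys plus "quant_method".
flat_cfg = {
    "quant_method": "modelopt",
    "quant_algo": "FP8",
    "kv_cache_quant_algo": "FP8",
    "exclude_modules": ["lm_head"],  # illustrative
}

cfg = ModelOptFp8Config.from_config(nested_cfg)
assert cfg.is_checkpoint_fp8_serialized
assert cfg.kv_cache_quant_method == "FP8"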

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.kv_cache_quant_method = kv_cache_quant_method
        self.exclude_modules = exclude_modules
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt config should be used based on
        quantization config."""

        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "FP8" in quant_algo:
                    return "modelopt"
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP8" in quant_algo:
                return "modelopt"

        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
        # Handle both ModelOpt format and compressed-tensors style format
        if "quantization" in config:
            # ModelOpt format: {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError(
                    "Expected 'quantization' to be a dictionary in config")
            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")
            kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
            exclude_modules = quant_config.get("exclude_modules")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")
            exclude_modules = config.get("exclude_modules")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration.")
        is_checkpoint_fp8_serialized = ("FP8" in quant_method)

        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
                   exclude_modules)

    def is_layer_excluded(self, prefix: str) -> bool:
        """
        Check if a layer should be excluded from quantization.

        This method handles both regular models and multimodal models that use
        the language_model prefix. For multimodal models, it checks if the
        module name (without the language_model prefix) is in the exclude list.
        """
        if self.exclude_modules is None:
            return False

        # Check if any excluded module matches the prefix
        for module in self.exclude_modules:
            if (module in prefix
                    or (prefix.startswith("language_model.")
                        and module in prefix.removeprefix("language_model."))):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self, layer)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

is_checkpoint_fp8_serialized instance-attribute

is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

kv_cache_quant_method instance-attribute

kv_cache_quant_method = kv_cache_quant_method

__init__

__init__(
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    kv_cache_quant_method: Optional[str] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None:
    super().__init__()
    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
    self.kv_cache_quant_method = kv_cache_quant_method
    self.exclude_modules = exclude_modules
    if is_checkpoint_fp8_serialized:
        logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                       " the format is experimental and could change.")

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptFp8Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
    # Handle both ModelOpt format and compressed-tensors style format
    if "quantization" in config:
        # ModelOpt format: {"quantization": {"quant_algo": "..."}}
        quant_config = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_config, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = quant_config.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
        kv_cache_quant_method = quant_config.get("kv_cache_quant_algo")
        exclude_modules = quant_config.get("exclude_modules")
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        quant_method = config.get("quant_algo", "")
        kv_cache_quant_method = config.get("kv_cache_quant_algo")
        exclude_modules = config.get("exclude_modules")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_fp8_serialized = ("FP8" in quant_method)

    return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
               exclude_modules)

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    return 89

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "modelopt"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import
    if isinstance(layer, LinearBase):
        if self.is_layer_excluded(prefix):
            return UnquantizedLinearMethod()
        return ModelOptFp8LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return ModelOptFp8MoEMethod(self, layer)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

is_layer_excluded

is_layer_excluded(prefix: str) -> bool

Check if a layer should be excluded from quantization.

This method handles both regular models and multimodal models that use the language_model prefix. For multimodal models, it checks if the module name (without the language_model prefix) is in the exclude list.
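
A small sketch of the matching rule; the module and layer names below are made up for illustration.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

cfg = ModelOptFp8Config(exclude_modules=["model.layers.0.mlp.gate_proj"])

# Substring match on the full prefix.
assert cfg.is_layer_excluded("model.layers.0.mlp.gate_proj")

# Multimodal case: the same module is also excluded when nested under the
# "language_model." prefix.
assert cfg.is_layer_excluded("language_model.model.layers.0.mlp.gate_proj")

# Unrelated layers are not excluded.
assert not cfg.is_layer_excluded("model.layers.1.self_attn.qkv_proj")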

Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str) -> bool:
    """
    Check if a layer should be excluded from quantization.

    This method handles both regular models and multimodal models that use
    the language_model prefix. For multimodal models, it checks if the
    module name (without the language_model prefix) is in the exclude list.
    """
    if self.exclude_modules is None:
        return False

    # Check if any excluded module matches the prefix
    for module in self.exclude_modules:
        if (module in prefix
                or (prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model."))):
            return True
    return False

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]

Detect if this ModelOpt config should be used based on quantization config.
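
A sketch of the detection on hand-written config dicts; user_quant is not consulted by the logic shown, so None is passed here.

from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

nested = {"quant_method": "modelopt", "quantization": {"quant_algo": "FP8"}}
flat = {"quant_method": "modelopt", "quant_algo": "FP8"}
other = {"quant_method": "compressed-tensors"}

assert ModelOptFp8Config.override_quantization_method(nested, None) == "modelopt"
assert ModelOptFp8Config.override_quantization_method(flat, None) == "modelopt"
assert ModelOptFp8Config.override_quantization_method(other, None) is None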

Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    """Detect if this ModelOpt config should be used based on
    quantization config."""

    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "").lower()

    # Only proceed if the method is explicitly "modelopt"
    if quant_method != "modelopt":
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            if "FP8" in quant_algo:
                return "modelopt"
    else:
        # Check for compressed-tensors style config with specific quant_algo
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP8" in quant_algo:
            return "modelopt"

    return None

ModelOptFp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.
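
As a minimal sketch, the FP8 KV-cache path is enabled by declaring kv_cache_quant_algo in the checkpoint config; get_quant_method then attaches this method to Attention layers. The dict below is illustrative.

from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptFp8Config, ModelOptFp8KVCacheMethod)

cfg = ModelOptFp8Config.from_config({
    "quant_algo": "FP8",
    "kv_cache_quant_algo": "FP8",
})
assert cfg.kv_cache_quant_method == "FP8"

kv_method = ModelOptFp8KVCacheMethod(cfg)  # returned for Attention layers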

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Union[ModelOptFp8Config,
                                           ModelOptNvFp4Config]):
        super().__init__(quant_config)

__init__

__init__(
    quant_config: Union[
        ModelOptFp8Config, ModelOptNvFp4Config
    ],
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: Union[ModelOptFp8Config,
                                       ModelOptNvFp4Config]):
    super().__init__(quant_config)

ModelOptFp8LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales.

Limitations:

1. Only supports per-tensor quantization, due to torch._scaled_mm support.
2. Only supports the float8_e4m3fn datatype.

Args: quant_config: The ModelOpt quantization config.
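
As a rough illustration of what a static per-tensor scale represents, here is a plain-PyTorch emulation of per-tensor FP8 (e4m3) weight quantization. It is not the fused Fp8LinearOp / torch._scaled_mm path that apply uses; the sizes are arbitrary.

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

out_features, in_features = 1024, 1024
w = torch.randn(out_features, in_features, dtype=torch.bfloat16)

# One static scale for the whole weight tensor.
weight_scale = w.abs().max().float() / fp8_max
w_fp8 = (w.float() / weight_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# Reference (dequantized) matmul; the real path runs an FP8 GEMM instead.
x = torch.randn(8, in_features, dtype=torch.bfloat16)
y_ref = x.float() @ (w_fp8.float() * weight_scale).t()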

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale. Future support might be added for dynamic
    scales.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn datatype
        Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        self.quant_config = quant_config
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        weight = layer.weight
        max_w_scale = layer.weight_scale.max()
        if not (layer.weight_scale == layer.weight_scale[0]).all():
            max_w_scale, weight = requantize_with_max_scale(
                layer.weight, layer.weight_scale, layer.logical_widths)
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(layer.input_scale.max(),
                                      requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     input_scale=layer.input_scale,
                                     bias=bias)

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(
    act_quant_static=True, act_quant_group_shape=PER_TENSOR
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptFp8Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptFp8Config) -> None:
    self.quant_config = quant_config
    self.fp8_linear = Fp8LinearOp(
        act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return self.fp8_linear.apply(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
                                 input_scale=layer.input_scale,
                                 bias=bias)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_fp8_serialized else
                    params_dtype)
    weight = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=weight_dtype),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        weight_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                               weight_loader=weight_loader)
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)
        # INPUT SCALE
        scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                        weight_loader=weight_loader)

        scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("input_scale", scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:
    weight = layer.weight
    max_w_scale = layer.weight_scale.max()
    if not (layer.weight_scale == layer.weight_scale[0]).all():
        max_w_scale, weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths)
    layer.weight = Parameter(weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
    layer.input_scale = Parameter(layer.input_scale.max(),
                                  requires_grad=False)

ModelOptFp8MoEMethod

Bases: FusedMoEMethodBase

MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale.

Args: quant_config: The ModelOpt quantization config.
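
For reference, a sketch of the parameter shapes that create_weights registers per MoE layer; the sizes are made up, and the layout follows the source below.

# Illustrative sizes only.
num_experts, hidden_size, intermediate_size = 8, 1024, 512

# w1 and w3 are stacked along dim 1, giving [w1; w3] per expert.
w13_weight_shape = (num_experts, 2 * intermediate_size, hidden_size)
w2_weight_shape = (num_experts, hidden_size, intermediate_size)

# Per-tensor scales: two per expert for w13 (w1 and w3, fused to one after
# loading), one per expert for w2, plus one input scale per projection.
w13_weight_scale_shape = (num_experts, 2)
w2_weight_scale_shape = (num_experts,)
w13_input_scale_shape = (num_experts,)
w2_input_scale_shape = (num_experts,)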

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.
    Supports loading FP8 checkpoints with static weight scale and
    activation scale.
    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(
        self,
        quant_config: ModelOptFp8Config,
        layer: torch.nn.Module,
    ) -> None:
        super().__init__(layer.moe_config)
        self.layer = layer
        self.quant_config = quant_config
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported)
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )

    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        if self.fused_experts is not None or \
            self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
            return super().maybe_make_prepare_finalize(moe)

        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            moe,
            layer=self.layer,
        )
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        experts = select_cutlass_fp8_gemm_impl(
            moe,
            self.layer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             2 * intermediate_size_per_partition,
                             hidden_size,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             hidden_size,
                             intermediate_size_per_partition,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    1.0,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.
        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data,
                                     requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize)

        # Handle scale parameters
        if hasattr(layer,
                   "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales
            # then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][start:start +
                                                        intermediate_size, :],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale

                        (
                            layer.w13_weight[expert_id][start:start +
                                                        intermediate_size, :],
                            _,
                        ) = scaled_fp8_quant(dq_weight,
                                             max_w13_scales[expert_id])

                        start += intermediate_size

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales,
                                                   requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                                   requires_grad=False)

        if hasattr(layer,
                   "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                              requires_grad=False)
        # Input scales must be equal for each expert in fp8 MoE layers.
        if hasattr(layer,
                   "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
                                              requires_grad=False)
        if hasattr(layer,
                   "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                             requires_grad=False)

        if self.flashinfer_moe_backend is not None:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
            register_moe_scaling_factors(layer)
            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
                rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                                  layer.w2_weight)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            assert not renormalize
            return apply_flashinfer_per_tensor_scale_fp8(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,
                routing_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
                top_k=top_k,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                apply_router_weight_on_input=apply_router_weight_on_input)

        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
            assert not renormalize
            assert activation == 'silu', (
                f"Expected 'silu' activation but got {activation}")
            if self.fused_experts is not None:
                return self.fused_experts(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
            else:
                return flashinfer_cutlass_moe_fp8(
                    x,
                    layer,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    apply_router_weight_on_input=apply_router_weight_on_input,
                )
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

cutlass_fp8_supported instance-attribute

cutlass_fp8_supported = cutlass_fp8_supported()

flashinfer_moe_backend instance-attribute

flashinfer_moe_backend: Optional[FlashinferMoeBackend] = (
    None
)

fused_experts instance-attribute

fused_experts: Optional[FusedMoEModularKernel] = None

layer instance-attribute

layer = layer

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(
    quant_config: ModelOptFp8Config, layer: Module
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    quant_config: ModelOptFp8Config,
    layer: torch.nn.Module,
) -> None:
    super().__init__(layer.moe_config)
    self.layer = layer
    self.quant_config = quant_config
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        cutlass_fp8_supported)
    self.cutlass_fp8_supported = cutlass_fp8_supported()
    self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
    self.fused_experts: Optional[
        mk.FusedMoEModularKernel] = None  # type: ignore
    if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
        )

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

    if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        assert activation == 'silu', (
            f"Expected 'silu' activation but got {activation}")
        assert not renormalize
        return apply_flashinfer_per_tensor_scale_fp8(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
            top_k=top_k,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            apply_router_weight_on_input=apply_router_weight_on_input)

    # Expert selection
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

    if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
        assert not renormalize
        assert activation == 'silu', (
            f"Expected 'silu' activation but got {activation}")
        if self.fused_experts is not None:
            return self.fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            return flashinfer_cutlass_moe_fp8(
                x,
                layer,
                topk_weights,
                topk_ids,
                inplace=False,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_experts)
    return fused_experts(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        use_fp8_w8a8=True,
        per_channel_quant=False,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):

    # Use FP8 dtype if checkpoint is serialized
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_fp8_serialized else
                    params_dtype)
    weight_loader = extra_weight_attrs.get("weight_loader")

    w13_weight = ModelWeightParameter(
        data=torch.empty(num_experts,
                         2 * intermediate_size_per_partition,
                         hidden_size,
                         dtype=weight_dtype),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w13_weight", w13_weight)

    w2_weight = ModelWeightParameter(
        data=torch.empty(num_experts,
                         hidden_size,
                         intermediate_size_per_partition,
                         dtype=weight_dtype),
        input_dim=2,
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("w2_weight", w2_weight)

    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = PerTensorScaleParameter(
            data=torch.full(
                (num_experts, 2),
                1.0,
                dtype=torch.float32,
            ),
            weight_loader=weight_loader,
        )
        w2_weight_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Set weight loader attributes for scales
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        w2_input_scale = PerTensorScaleParameter(
            data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    if self.fused_experts is not None or \
        self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
        return super().maybe_make_prepare_finalize(moe)

    prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
        moe,
        layer=self.layer,
    )
    logger.debug_once("%s", prepare_finalize.__class__.__name__)
    return prepare_finalize

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None

Process FP8 MoE weights after loading from serialized checkpoint. Only supports pre-quantized checkpoints with FP8 weights and scales.
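
A single-expert, plain-PyTorch sketch of the w13 scale fusion described above; the real implementation iterates over all experts and uses vLLM's per_tensor_dequantize and scaled_fp8_quant helpers shown in the source.

import torch

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max

intermediate_size, hidden_size = 4, 8
w1 = torch.randn(intermediate_size, hidden_size)
w3 = torch.randn(intermediate_size, hidden_size)
s1 = w1.abs().max() / fp8_max  # per-tensor scale for the w1 shard
s3 = w3.abs().max() / fp8_max  # per-tensor scale for the w3 shard
w13 = torch.cat([(w1 / s1).to(fp8), (w3 / s3).to(fp8)], dim=0)

# Fused per-expert scale: dequantize each shard with its own scale, then
# requantize both shards with the max of the two scales.
s_max = torch.max(s1, s3)
w13_requant = torch.cat([
    (w13[:intermediate_size].float() * s1 / s_max).to(fp8),
    (w13[intermediate_size:].float() * s3 / s_max).to(fp8),
], dim=0)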

Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Process FP8 MoE weights after loading from serialized checkpoint.
    Only supports pre-quantized checkpoints with FP8 weights and scales.
    """

    layer.w13_weight = Parameter(layer.w13_weight.data,
                                 requires_grad=False)
    layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

    from vllm._custom_ops import scaled_fp8_quant
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        per_tensor_dequantize)

    # Handle scale parameters
    if hasattr(layer,
               "w13_weight_scale") and layer.w13_weight_scale is not None:
        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max of the w1 and w3 scales
        # then dequant and requant each expert.
        if layer.w13_weight_scale.dim() == 2:

            # Get the maximum scale across w1 and w3 for each expert
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values

            # Requantize each expert's weights using the combined scale
            # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
            # where the first intermediate_size rows are w1, the next are w3
            intermediate_size = layer.w13_weight.shape[1] // 2
            for expert_id in range(layer.w13_weight.shape[0]):
                start = 0
                for shard_id in range(2):  # w1 and w3
                    # Dequantize using the original scale for this shard
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    intermediate_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    # Requantize using the combined max scale

                    (
                        layer.w13_weight[expert_id][start:start +
                                                    intermediate_size, :],
                        _,
                    ) = scaled_fp8_quant(dq_weight,
                                         max_w13_scales[expert_id])

                    start += intermediate_size

            # Update the scale parameter to be per-expert
            layer.w13_weight_scale = Parameter(max_w13_scales,
                                               requires_grad=False)
        else:
            layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
                                               requires_grad=False)

    if hasattr(layer,
               "w2_weight_scale") and layer.w2_weight_scale is not None:
        layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                          requires_grad=False)
    # Input scales must be equal for each expert in fp8 MoE layers.
    if hasattr(layer,
               "w13_input_scale") and layer.w13_input_scale is not None:
        layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
                                          requires_grad=False)
    if hasattr(layer,
               "w2_input_scale") and layer.w2_input_scale is not None:
        layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                         requires_grad=False)

    if self.flashinfer_moe_backend is not None:
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        register_moe_scaling_factors(layer)
        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                              layer.w2_weight)

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/modelopt.py
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    experts = select_cutlass_fp8_gemm_impl(
        moe,
        self.layer,
    )
    logger.debug_once("Using %s", experts.__class__.__name__)
    return experts

ModelOptNvFp4Config

Bases: QuantizationConfig

Config class for ModelOpt FP4.
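
A minimal sketch of the nested hf_quant_config.json form that from_config expects for NVFP4 checkpoints; the excluded module name is illustrative.

from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

nvfp4_cfg = {
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],  # illustrative
    }
}

cfg = ModelOptNvFp4Config.from_config(nvfp4_cfg)
assert cfg.is_checkpoint_nvfp4_serialized
assert cfg.group_size == 16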

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4Config(QuantizationConfig):
    """Config class for ModelOpt FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: list[str],
        group_size: int = 16,
    ) -> None:
        super().__init__()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected ModelOpt NVFP4 checkpoint. Please note that"
                " the format is experimental and could change in future.")

            self.group_size = group_size
            self.kv_cache_quant_algo = kv_cache_quant_algo
            self.exclude_modules = exclude_modules

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
        # Handle both traditional ModelOpt format and compressed-tensors
        # style format
        if "quantization" in config:
            # Traditional ModelOpt format:
            # {"quantization": {"quant_algo": "..."}}
            quant_config = cls.get_from_keys(config, ["quantization"])
            if not isinstance(quant_config, dict):
                raise ValueError(
                    "Expected 'quantization' to be a dictionary in config")

            quant_method = quant_config.get("quant_algo", "")
            if not quant_method:
                raise ValueError("Missing 'quant_algo' in quantization config")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
            if kv_cache_quant_algo_raw is None:
                # No KV cache quantization by default
                kv_cache_quant_algo = None
            elif isinstance(kv_cache_quant_algo_raw, str):
                kv_cache_quant_algo = kv_cache_quant_algo_raw
            else:
                raise ValueError(f"kv_cache_quant_algo must be a string, got "
                                 f"{type(kv_cache_quant_algo_raw)}")

            # Handle group_size with proper type validation
            group_size_raw = quant_config.get("group_size")
            if group_size_raw is None:
                group_size = 16  # Default value
            elif isinstance(group_size_raw, int):
                group_size = group_size_raw
            else:
                try:
                    group_size = int(group_size_raw)
                except (ValueError, TypeError):
                    raise ValueError(f"group_size must be an integer, got "
                                     f"{type(group_size_raw)}") from None

            exclude_modules = quant_config.get("exclude_modules", [])
            if not isinstance(exclude_modules, list):
                raise ValueError(f"exclude_modules must be a list, got "
                                 f"{type(exclude_modules)}")
        else:
            # Compressed-tensors style format:
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo", "")

            # Handle kv_cache_quant_algo with proper type validation
            kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
            if kv_cache_quant_algo_raw is None:
                # No KV cache quantization by default
                kv_cache_quant_algo = None
            elif isinstance(kv_cache_quant_algo_raw, str):
                kv_cache_quant_algo = kv_cache_quant_algo_raw
            else:
                raise ValueError(f"kv_cache_quant_algo must be a string, got "
                                 f"{type(kv_cache_quant_algo_raw)}")

            # Handle group_size with proper type validation
            group_size_raw = config.get("group_size")
            if group_size_raw is None:
                group_size = 16  # Default value
            elif isinstance(group_size_raw, int):
                group_size = group_size_raw
            else:
                try:
                    group_size = int(group_size_raw)
                except (ValueError, TypeError):
                    raise ValueError(f"group_size must be an integer, got "
                                     f"{type(group_size_raw)}") from None

            exclude_modules = config.get("exclude_modules", [])
            if not isinstance(exclude_modules, list):
                raise ValueError(f"exclude_modules must be a list, got "
                                 f"{type(exclude_modules)}")

        if quant_method not in QUANT_ALGOS:
            raise ValueError(
                f"ModelOpt currently only supports: {QUANT_ALGOS} "
                "quantizations in vLLM. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration.")
        is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

        # For FP4, these fields are required
        if is_checkpoint_nvfp4_serialized and "quantization" in config:
            # Check if required fields are present in the quantization config
            quant_config = config["quantization"]
            required_fields = [
                "group_size", "kv_cache_quant_algo", "exclude_modules"
            ]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"NVFP4 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}")

        return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                   exclude_modules, group_size)

    def is_layer_excluded(self, prefix: str,
                          exclude_modules: list[str]) -> bool:
        import regex as re
        for pattern in exclude_modules:
            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            if (is_layer_skipped(prefix, self.exclude_modules)
                    or self.is_layer_excluded(prefix, self.exclude_modules)):
                return UnquantizedLinearMethod()
            return ModelOptNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
        return None

exclude_modules instance-attribute

exclude_modules = exclude_modules

group_size instance-attribute

group_size = group_size

is_checkpoint_nvfp4_serialized instance-attribute

is_checkpoint_nvfp4_serialized = (
    is_checkpoint_nvfp4_serialized
)

kv_cache_quant_algo instance-attribute

kv_cache_quant_algo = kv_cache_quant_algo

__init__

__init__(
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: Optional[str],
    exclude_modules: list[str],
    group_size: int = 16,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    is_checkpoint_nvfp4_serialized: bool,
    kv_cache_quant_algo: Optional[str],
    exclude_modules: list[str],
    group_size: int = 16,
) -> None:
    super().__init__()
    self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
    if is_checkpoint_nvfp4_serialized:
        logger.warning(
            "Detected ModelOpt NVFP4 checkpoint. Please note that"
            " the format is experimental and could change in future.")

    self.group_size = group_size
    self.kv_cache_quant_algo = kv_cache_quant_algo
    self.exclude_modules = exclude_modules

from_config classmethod

from_config(config: dict[str, Any]) -> ModelOptNvFp4Config
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    # Handle both traditional ModelOpt format and compressed-tensors
    # style format
    if "quantization" in config:
        # Traditional ModelOpt format:
        # {"quantization": {"quant_algo": "..."}}
        quant_config = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_config, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")

        quant_method = quant_config.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")

        # Handle kv_cache_quant_algo with proper type validation
        kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            # No KV cache quantization by default
            kv_cache_quant_algo = None
        elif isinstance(kv_cache_quant_algo_raw, str):
            kv_cache_quant_algo = kv_cache_quant_algo_raw
        else:
            raise ValueError(f"kv_cache_quant_algo must be a string, got "
                             f"{type(kv_cache_quant_algo_raw)}")

        # Handle group_size with proper type validation
        group_size_raw = quant_config.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(f"group_size must be an integer, got "
                                 f"{type(group_size_raw)}") from None

        exclude_modules = quant_config.get("exclude_modules", [])
        if not isinstance(exclude_modules, list):
            raise ValueError(f"exclude_modules must be a list, got "
                             f"{type(exclude_modules)}")
    else:
        # Compressed-tensors style format:
        # {"quant_algo": "...", "quant_method": "modelopt"}
        quant_method = config.get("quant_algo", "")

        # Handle kv_cache_quant_algo with proper type validation
        kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo")
        if kv_cache_quant_algo_raw is None:
            # No KV cache quantization by default
            kv_cache_quant_algo = None
        elif isinstance(kv_cache_quant_algo_raw, str):
            kv_cache_quant_algo = kv_cache_quant_algo_raw
        else:
            raise ValueError(f"kv_cache_quant_algo must be a string, got "
                             f"{type(kv_cache_quant_algo_raw)}")

        # Handle group_size with proper type validation
        group_size_raw = config.get("group_size")
        if group_size_raw is None:
            group_size = 16  # Default value
        elif isinstance(group_size_raw, int):
            group_size = group_size_raw
        else:
            try:
                group_size = int(group_size_raw)
            except (ValueError, TypeError):
                raise ValueError(f"group_size must be an integer, got "
                                 f"{type(group_size_raw)}") from None

        exclude_modules = config.get("exclude_modules", [])
        if not isinstance(exclude_modules, list):
            raise ValueError(f"exclude_modules must be a list, got "
                             f"{type(exclude_modules)}")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

    # For FP4, these fields are required
    if is_checkpoint_nvfp4_serialized and "quantization" in config:
        # Check if required fields are present in the quantization config
        quant_config = config["quantization"]
        required_fields = [
            "group_size", "kv_cache_quant_algo", "exclude_modules"
        ]
        missing_fields = [
            field for field in required_fields if field not in quant_config
        ]
        if missing_fields:
            raise ValueError(
                f"NVFP4 quantization requires the following fields in "
                f"hf_quant_config.json: {missing_fields}")

    return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
               exclude_modules, group_size)
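
For orientation, a minimal sketch of the NVFP4 payload this classmethod accepts, assuming vLLM is importable; the concrete values (group size, KV-cache algo, excluded modules) are illustrative and vary per checkpoint export.

from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

# Shape of a typical hf_quant_config.json for an NVFP4 checkpoint
# (traditional ModelOpt layout with the nested "quantization" key).
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    },
}

cfg = ModelOptNvFp4Config.from_config(hf_quant_config)
assert cfg.is_checkpoint_nvfp4_serialized
assert cfg.group_size == 16 and cfg.exclude_modules == ["lm_head"]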

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_min_capability(cls) -> int:
    return 80

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "modelopt_fp4"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import
    if isinstance(layer, LinearBase):
        if (is_layer_skipped(prefix, self.exclude_modules)
                or self.is_layer_excluded(prefix, self.exclude_modules)):
            return UnquantizedLinearMethod()
        return ModelOptNvFp4LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

is_layer_excluded

is_layer_excluded(
    prefix: str, exclude_modules: list[str]
) -> bool
Source code in vllm/model_executor/layers/quantization/modelopt.py
def is_layer_excluded(self, prefix: str,
                      exclude_modules: list[str]) -> bool:
    import regex as re
    for pattern in exclude_modules:
        regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
        if re.fullmatch(regex_str, prefix):
            return True
    return False
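
The wildcard patterns in exclude_modules are rewritten into anchored regexes ('.' is escaped, '*' becomes '.*') and matched against the full layer prefix. A standalone sketch of that rewriting, using the standard re module for illustration (the implementation imports the third-party regex package):

import re

def matches_exclude(prefix: str, patterns: list[str]) -> bool:
    # Same rewriting as is_layer_excluded: escape dots, expand wildcards.
    for pattern in patterns:
        regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
        if re.fullmatch(regex_str, prefix):
            return True
    return False

patterns = ["lm_head", "model.layers.*.self_attn.*"]
assert matches_exclude("model.layers.3.self_attn.qkv_proj", patterns)
assert not matches_exclude("lm_head.extra", patterns)  # fullmatch, not prefix match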

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]

Detect if this ModelOpt FP4 config should be used based on quantization config.

Source code in vllm/model_executor/layers/quantization/modelopt.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    """Detect if this ModelOpt FP4 config should be used based on
    quantization config."""
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'
    quant_method = hf_quant_cfg.get("quant_method", "").lower()

    # Only proceed if the method is explicitly "modelopt"
    if quant_method != "modelopt":
        return None

    # Look for ModelOpt-specific config structure
    if "quantization" in hf_quant_cfg:
        quant_config = hf_quant_cfg["quantization"]
        if isinstance(quant_config, dict):
            quant_algo = quant_config.get("quant_algo", "")
            if "NVFP4" in quant_algo:
                return "modelopt_fp4"
    else:
        # Check for compressed-tensors style config with specific
        # quant_algo field
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
            return "modelopt_fp4"

    return None
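
Both recognized layouts resolve to the "modelopt_fp4" method. A small sketch, assuming vLLM is importable; the dictionaries are illustrative stand-ins for a parsed hf_quant_config.json:

from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4Config

nested_style = {"quant_method": "modelopt",
                "quantization": {"quant_algo": "NVFP4"}}
flat_style = {"quant_method": "modelopt", "quant_algo": "NVFP4"}

assert ModelOptNvFp4Config.override_quantization_method(
    nested_style, user_quant=None) == "modelopt_fp4"
assert ModelOptNvFp4Config.override_quantization_method(
    flat_style, user_quant=None) == "modelopt_fp4"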

ModelOptNvFp4FusedMoE

Bases: FusedMoEMethodBase

MoE Method for FP4 Quantization. Args: quant_config: NVFP4 Quant Config

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    """
    MoE Method for FP4 Quantization.
    Args:
        quant_config: NVFP4 Quant Config
    """

    def __init__(
        self,
        quant_config: ModelOptNvFp4Config,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> None:
        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
            detect_nvfp4_moe_support)
        super().__init__(moe)
        self.quant_config = quant_config
        self.layer = layer
        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.flashinfer_moe_backend = None

        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
                " for ModelOptNvFp4FusedMoE.")

    def maybe_make_prepare_finalize(
        self,
        moe: FusedMoEConfig,
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        if (self.allow_flashinfer and self.flashinfer_moe_backend
                == FlashinferMoeBackend.CUTLASS):
            prepare_finalize = (
                build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                    moe,
                    a1_gscale=self.layer.w13_input_scale_quant,
                ))
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize

        return super().maybe_make_prepare_finalize(moe)

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        experts = select_nvfp4_gemm_impl(
            moe,
            g1_alphas=self.layer.g1_alphas,
            g2_alphas=self.layer.g2_alphas,
            a1_gscale=self.layer.w13_input_scale_quant,
            a2_gscale=self.layer.w2_input_scale_quant,
            allow_flashinfer=self.allow_flashinfer,
        )
        logger.debug_once("Using %s", experts.__class__.__name__)
        return experts

    def uses_weight_scale_2_pattern(self) -> bool:
        """
        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
        """
        return True

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        layer.num_experts = num_experts
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        weight_dtype = torch.uint8
        weight_scale_dtype = torch.float8_e4m3fn
        weight_loader = extra_weight_attrs.get("weight_loader")
        # GEMM 1
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight", w13_weight)

        # GEMM 2
        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition // 2,
                dtype=weight_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight", w2_weight)

        w13_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                # 2 fp4 items are packed in the input dimension
                hidden_size // self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                # 2 fp4 items are packed in the input dimension
                intermediate_size_per_partition //
                self.quant_config.group_size,
                dtype=weight_scale_dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        w13_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

        w2_weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(num_experts, dtype=torch.float32),
            weight_loader=weight_loader)
        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        w13_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, 2, dtype=torch.float32),
                                                  weight_loader=weight_loader)
        layer.register_parameter("w13_input_scale", w13_input_scale)

        w2_input_scale = PerTensorScaleParameter(data=torch.empty(
            num_experts, dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def prepare_static_weight_layouts_for_trtllm_moe(
        self,
        gemm1_weights: torch.Tensor,
        gemm2_weights: torch.Tensor,
        gemm1_scales_linear_fp4_bytes: torch.Tensor,
        gemm2_scales_linear_fp4_bytes: torch.Tensor,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Prepare quantized weights for kernel (done offline with weights)."""
        from flashinfer import (reorder_rows_for_gated_act_gemm,
                                shuffle_matrix_a, shuffle_matrix_sf_a)
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

        # Convert quantized weights to proper formats
        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
                                         hidden_size //
                                         16)  # fp8 scaling factors

        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
            num_experts, hidden_size, intermediate_size // 2)  # packed fp4
        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                         intermediate_size //
                                         16)  # fp8 scaling factors

        # Reorder rows of W1 and scales for fused gated activation
        gemm1_weights_fp4_interleaved = []
        gemm1_scales_fp4_interleaved = []
        for i in range(num_experts):
            gemm1_weights_fp4_interleaved.append(
                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
            gemm1_scales_fp4_interleaved.append(
                reorder_rows_for_gated_act_gemm(
                    gemm1_scales_linear_fp4[i].clone()))

        # Stack weights and scales for all experts
        gemm1_weights_fp4_interleaved = torch.stack(
            gemm1_weights_fp4_interleaved).reshape(num_experts,
                                                   2 * intermediate_size,
                                                   hidden_size // 2)
        gemm1_scales_fp4_interleaved = torch.stack(
            gemm1_scales_fp4_interleaved).reshape(num_experts,
                                                  2 * intermediate_size,
                                                  hidden_size // 16)

        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
            gemm1_weights_fp4_shuffled.append(
                shuffle_matrix_a(
                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
                    epilogue_tile_m))
            gemm1_scales_fp4_shuffled.append(
                shuffle_matrix_sf_a(
                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
                    epilogue_tile_m))

            gemm2_weights_fp4_shuffled.append(
                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm2_scales_fp4_shuffled.append(
                shuffle_matrix_sf_a(
                    gemm2_scales_linear_fp4[i].view(torch.uint8),
                    epilogue_tile_m))

        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
        gemm1_scales_fp4_shuffled = (
            torch.stack(gemm1_scales_fp4_shuffled).view(
                torch.float8_e4m3fn).reshape(num_experts,
                                             2 * intermediate_size,
                                             hidden_size // 16))

        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
        gemm2_scales_fp4_shuffled = (
            torch.stack(gemm2_scales_fp4_shuffled).view(
                torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                             intermediate_size // 16))
        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # GEMM 1 processing
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2)

        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                           requires_grad=False)

        # Common processing for w13_weight_scale_2
        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                              layer.w13_weight_scale_2[:, 1]):
            logger.warning_once(
                "w1_weight_scale_2 must match w3_weight_scale_2. "
                "Accuracy may be affected.")

        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                             requires_grad=False)

        # Common processing for input scales and alphas
        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
            torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w13_input_scale_quant = Parameter(
            (1 / w13_input_scale).to(torch.float32), requires_grad=False)

        # GEMM 2 processing
        layer.g2_alphas = Parameter(
            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False)

        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

        # TensorRT-LLM specific processing
        if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # Prepare static weights for TRT-LLM kernel
            (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
             gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
                 layer.w13_weight,
                 layer.w2_weight,
                 layer.w13_weight_scale,
                 layer.w2_weight_scale,
                 layer.w2_weight.size(-2),  # hidden_size
                 layer.w13_weight.size(-2) // 2,  # intermediate_size
                 layer.w13_weight.size(0),  # num_experts
             )

            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False)
            layer.gemm2_weights_fp4_shuffled = Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False)
            layer.gemm1_scales_fp4_shuffled = Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False)
            layer.gemm2_scales_fp4_shuffled = Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False)

            # Additional parameter needed for TRT-LLM
            layer.g1_scale_c = Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(
                    torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        else:
            # Non-TRT-LLM processing (Cutlass or non-flashinfer)
            assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
                "Expected weight_scale.dim(1) to be divisible by 16")
            assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
                "Weight Blockscale must be represented as FP8-E4M3")
            w13_blockscale_swizzled = swizzle_blockscale(
                layer.w13_weight_scale)
            layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
                                               requires_grad=False)

            assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
                "Expected weight_scale.dim(1) to be divisible by 16")
            assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
                "Weight Blockscale must be represented as FP8-E4M3")
            w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
            layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
                                              requires_grad=False)
            layer.w2_weight = Parameter(layer.w2_weight.data,
                                        requires_grad=False)

        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
            del layer.g1_alphas
            del layer.g2_alphas
            del layer.w13_input_scale_quant
            del layer.w2_input_scale_quant

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ):
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
        assert activation == "silu", "Only SiLU activation is supported."

        if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            import flashinfer

            from vllm.model_executor.models.llama4 import Llama4MoE

            a1_gscale = layer.w13_input_scale_quant
            (hidden_states_fp4,
             hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
                 x,
                 a1_gscale,
                 is_sf_swizzled_layout=False,
             )
            use_llama4_routing = \
                custom_routing_function is Llama4MoE.custom_routing_function
            routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
            if use_llama4_routing:
                routing_method_type = flashinfer.RoutingMethodType.Llama4
            out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
                routing_logits=router_logits
                if use_llama4_routing else router_logits.to(torch.float32),
                routing_bias=e_score_correction_bias,
                hidden_states=hidden_states_fp4,
                hidden_states_scale=hidden_states_scale_linear_fp4.view(
                    torch.float8_e4m3fn).flatten(),
                gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
                gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
                gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                    torch.float8_e4m3fn),
                gemm2_bias=None,
                output1_scale_scalar=layer.g1_scale_c.data,
                output1_scale_gate_scalar=layer.g1_alphas.data,
                output2_scale_scalar=layer.g2_alphas.data,
                num_experts=global_num_experts,
                top_k=top_k,
                n_group=num_expert_group
                if num_expert_group is not None else 0,
                topk_group=topk_group if topk_group is not None else 0,
                intermediate_size=layer.intermediate_size_per_partition,
                local_expert_offset=layer.ep_rank * layer.local_num_experts,
                local_num_experts=layer.local_num_experts,
                routed_scaling_factor=None,
                tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                                     layer.local_num_experts),
                routing_method_type=routing_method_type,
                do_finalize=True,
            )[0]
            return out

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.use_marlin:
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=layer.w13_weight_scale_2,
                global_scale2=layer.w2_weight_scale_2,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        if self.fused_experts is not None:
            assert self.allow_flashinfer and \
               self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS

            assert is_valid_flashinfer_cutlass_fused_moe(
                x, layer.w13_weight, layer.w2_weight), (
                    "Flashinfer CUTLASS Fused MoE not applicable!")

            out = self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        elif (self.allow_flashinfer
              and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4)

            out = flashinfer_cutlass_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                inplace=False,  # TODO(shuw): fix later, now output is high prec
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )
        else:
            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
            # only (no EP).
            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
                cutlass_moe_fp4)
            out = cutlass_moe_fp4(
                a=x,
                w1_fp4=layer.w13_weight,
                w2_fp4=layer.w2_weight,
                w1_blockscale=layer.w13_weight_scale,
                w2_blockscale=layer.w2_weight_scale,
                g1_alphas=layer.g1_alphas,
                g2_alphas=layer.g2_alphas,
                a1_gscale=layer.w13_input_scale_quant,
                a2_gscale=layer.w2_input_scale_quant,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=x.shape[0],
                n=layer.w2_weight.shape[2] * 2,
                k=x.shape[1],
                e=layer.w13_weight.shape[0],
                expert_map=expert_map,
                apply_router_weight_on_input=apply_router_weight_on_input)

        return out

allow_flashinfer instance-attribute

allow_flashinfer = allow_flashinfer

cutlass_nvfp4_supported instance-attribute

cutlass_nvfp4_supported = cutlass_supported

flashinfer_moe_backend instance-attribute

flashinfer_moe_backend = None

layer instance-attribute

layer = layer

quant_config instance-attribute

quant_config = quant_config

use_marlin instance-attribute

use_marlin = use_marlin

__init__

__init__(
    quant_config: ModelOptNvFp4Config,
    moe: FusedMoEConfig,
    layer: Module,
) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(
    self,
    quant_config: ModelOptNvFp4Config,
    moe: FusedMoEConfig,
    layer: torch.nn.Module,
) -> None:
    from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
        detect_nvfp4_moe_support)
    super().__init__(moe)
    self.quant_config = quant_config
    self.layer = layer
    _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
    self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
    self.allow_flashinfer = _nvfp4.allow_flashinfer
    self.use_marlin = _nvfp4.use_marlin
    self.flashinfer_moe_backend = None

    if self.allow_flashinfer:
        self.flashinfer_moe_backend = get_flashinfer_moe_backend()
        logger.info_once(
            f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            " for ModelOptNvFp4FusedMoE.")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
    assert activation == "silu", "Only SiLU activation is supported."

    if self.allow_flashinfer and \
        self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        import flashinfer

        from vllm.model_executor.models.llama4 import Llama4MoE

        a1_gscale = layer.w13_input_scale_quant
        (hidden_states_fp4,
         hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
             x,
             a1_gscale,
             is_sf_swizzled_layout=False,
         )
        use_llama4_routing = \
            custom_routing_function is Llama4MoE.custom_routing_function
        routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
        if use_llama4_routing:
            routing_method_type = flashinfer.RoutingMethodType.Llama4
        out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
            routing_logits=router_logits
            if use_llama4_routing else router_logits.to(torch.float32),
            routing_bias=e_score_correction_bias,
            hidden_states=hidden_states_fp4,
            hidden_states_scale=hidden_states_scale_linear_fp4.view(
                torch.float8_e4m3fn).flatten(),
            gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
            gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
            gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
                torch.float8_e4m3fn),
            gemm2_bias=None,
            output1_scale_scalar=layer.g1_scale_c.data,
            output1_scale_gate_scalar=layer.g1_alphas.data,
            output2_scale_scalar=layer.g2_alphas.data,
            num_experts=global_num_experts,
            top_k=top_k,
            n_group=num_expert_group
            if num_expert_group is not None else 0,
            topk_group=topk_group if topk_group is not None else 0,
            intermediate_size=layer.intermediate_size_per_partition,
            local_expert_offset=layer.ep_rank * layer.local_num_experts,
            local_num_experts=layer.local_num_experts,
            routed_scaling_factor=None,
            tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                                 layer.local_num_experts),
            routing_method_type=routing_method_type,
            do_finalize=True,
        )[0]
        return out

    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype)

    if self.use_marlin:
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            None,
            None,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=layer.w13_weight_scale_2,
            global_scale2=layer.w2_weight_scale_2,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map)

    if self.fused_experts is not None:
        assert self.allow_flashinfer and \
           self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS

        assert is_valid_flashinfer_cutlass_fused_moe(
            x, layer.w13_weight, layer.w2_weight), (
                "Flashinfer CUTLASS Fused MoE not applicable!")

        out = self.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif (self.allow_flashinfer
          and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
            flashinfer_cutlass_moe_fp4)

        out = flashinfer_cutlass_moe_fp4(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
            inplace=False,  # TODO(shuw): fix later, now output is high prec
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    else:
        # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
        # only (no EP).
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            cutlass_moe_fp4)
        out = cutlass_moe_fp4(
            a=x,
            w1_fp4=layer.w13_weight,
            w2_fp4=layer.w2_weight,
            w1_blockscale=layer.w13_weight_scale,
            w2_blockscale=layer.w2_weight_scale,
            g1_alphas=layer.g1_alphas,
            g2_alphas=layer.g2_alphas,
            a1_gscale=layer.w13_input_scale_quant,
            a2_gscale=layer.w2_input_scale_quant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=x.shape[0],
            n=layer.w2_weight.shape[2] * 2,
            k=x.shape[1],
            e=layer.w13_weight.shape[0],
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)

    return out
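
The backend selection above can be read as a simple precedence order. A condensed, pure-Python restatement for orientation (not vLLM API; the flag names and return strings are descriptive only):

def nvfp4_moe_kernel(allow_flashinfer: bool, backend: str,
                     use_marlin: bool, has_modular_kernel: bool) -> str:
    # Mirrors the branch order in apply().
    if allow_flashinfer and backend == "TENSORRT_LLM":
        return "flashinfer trtllm_fp4_block_scale_moe"
    if use_marlin:
        return "torch.ops.vllm.fused_marlin_moe"
    if has_modular_kernel:
        return "self.fused_experts (FlashInfer CUTLASS modular kernel)"
    if allow_flashinfer and backend == "CUTLASS":
        return "flashinfer_cutlass_moe_fp4"
    return "cutlass_moe_fp4 (TP-only fallback, no EP)"

assert nvfp4_moe_kernel(False, "", False, False) == "cutlass_moe_fp4 (TP-only fallback, no EP)"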

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")

    layer.num_experts = num_experts
    layer.params_dtype = params_dtype
    layer.quant_config = self.quant_config
    weight_dtype = torch.uint8
    weight_scale_dtype = torch.float8_e4m3fn
    weight_loader = extra_weight_attrs.get("weight_loader")
    # GEMM 1
    w13_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            # 2 fp4 items are packed in the input dimension
            hidden_size // 2,
            dtype=weight_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w13_weight", w13_weight)

    # GEMM 2
    w2_weight = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            # 2 fp4 items are packed in the input dimension
            intermediate_size_per_partition // 2,
            dtype=weight_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w2_weight", w2_weight)

    w13_weight_scale = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            # 2 fp4 items are packed in the input dimension
            hidden_size // self.quant_config.group_size,
            dtype=weight_scale_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w13_weight_scale", w13_weight_scale)

    w2_weight_scale = ModelWeightParameter(
        data=torch.empty(
            num_experts,
            hidden_size,
            # 2 fp4 items are packed in the input dimension
            intermediate_size_per_partition //
            self.quant_config.group_size,
            dtype=weight_scale_dtype),
        input_dim=1,
        output_dim=2,
        weight_loader=weight_loader)
    layer.register_parameter("w2_weight_scale", w2_weight_scale)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

    w13_weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(num_experts, 2, dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

    w2_weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(num_experts, dtype=torch.float32),
        weight_loader=weight_loader)
    layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

    w13_input_scale = PerTensorScaleParameter(data=torch.empty(
        num_experts, 2, dtype=torch.float32),
                                              weight_loader=weight_loader)
    layer.register_parameter("w13_input_scale", w13_input_scale)

    w2_input_scale = PerTensorScaleParameter(data=torch.empty(
        num_experts, dtype=torch.float32),
                                             weight_loader=weight_loader)
    layer.register_parameter("w2_input_scale", w2_input_scale)

maybe_make_prepare_finalize

maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]
Source code in vllm/model_executor/layers/quantization/modelopt.py
def maybe_make_prepare_finalize(
    self,
    moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
    if (self.allow_flashinfer and self.flashinfer_moe_backend
            == FlashinferMoeBackend.CUTLASS):
        prepare_finalize = (
            build_flashinfer_fp4_cutlass_moe_prepare_finalize(
                moe,
                a1_gscale=self.layer.w13_input_scale_quant,
            ))
        logger.debug_once("%s", prepare_finalize.__class__.__name__)
        return prepare_finalize

    return super().maybe_make_prepare_finalize(moe)

prepare_static_weight_layouts_for_trtllm_moe

prepare_static_weight_layouts_for_trtllm_moe(
    gemm1_weights: Tensor,
    gemm2_weights: Tensor,
    gemm1_scales_linear_fp4_bytes: Tensor,
    gemm2_scales_linear_fp4_bytes: Tensor,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Prepare quantized weights for kernel (done offline with weights).

Source code in vllm/model_executor/layers/quantization/modelopt.py
def prepare_static_weight_layouts_for_trtllm_moe(
    self,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm1_scales_linear_fp4_bytes: torch.Tensor,
    gemm2_scales_linear_fp4_bytes: torch.Tensor,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare quantized weights for kernel (done offline with weights)."""
    from flashinfer import (reorder_rows_for_gated_act_gemm,
                            shuffle_matrix_a, shuffle_matrix_sf_a)
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

    # Convert quantized weights to proper formats
    gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, 2 * intermediate_size, hidden_size // 2)  # packed fp4
    gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size,
                                     hidden_size //
                                     16)  # fp8 scaling factors

    gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 2)  # packed fp4
    gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
        torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                     intermediate_size //
                                     16)  # fp8 scaling factors

    # Reorder rows of W1 and scales for fused gated activation
    gemm1_weights_fp4_interleaved = []
    gemm1_scales_fp4_interleaved = []
    for i in range(num_experts):
        gemm1_weights_fp4_interleaved.append(
            reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
        gemm1_scales_fp4_interleaved.append(
            reorder_rows_for_gated_act_gemm(
                gemm1_scales_linear_fp4[i].clone()))

    # Stack weights and scales for all experts
    gemm1_weights_fp4_interleaved = torch.stack(
        gemm1_weights_fp4_interleaved).reshape(num_experts,
                                               2 * intermediate_size,
                                               hidden_size // 2)
    gemm1_scales_fp4_interleaved = torch.stack(
        gemm1_scales_fp4_interleaved).reshape(num_experts,
                                              2 * intermediate_size,
                                              hidden_size // 16)

    # Shuffle weights and scaling factors for transposed mma output
    gemm1_weights_fp4_shuffled = []
    gemm1_scales_fp4_shuffled = []
    gemm2_weights_fp4_shuffled = []
    gemm2_scales_fp4_shuffled = []
    for i in range(num_experts):
        gemm1_weights_fp4_shuffled.append(
            shuffle_matrix_a(
                gemm1_weights_fp4_interleaved[i].view(torch.uint8),
                epilogue_tile_m))
        gemm1_scales_fp4_shuffled.append(
            shuffle_matrix_sf_a(
                gemm1_scales_fp4_interleaved[i].view(torch.uint8),
                epilogue_tile_m))

        gemm2_weights_fp4_shuffled.append(
            shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
                             epilogue_tile_m))
        gemm2_scales_fp4_shuffled.append(
            shuffle_matrix_sf_a(
                gemm2_scales_linear_fp4[i].view(torch.uint8),
                epilogue_tile_m))

    # Stack weights for all experts
    gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
    gemm1_scales_fp4_shuffled = (
        torch.stack(gemm1_scales_fp4_shuffled).view(
            torch.float8_e4m3fn).reshape(num_experts,
                                         2 * intermediate_size,
                                         hidden_size // 16))

    gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
    gemm2_scales_fp4_shuffled = (
        torch.stack(gemm2_scales_fp4_shuffled).view(
            torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                         intermediate_size // 16))
    return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
            gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
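
When this helper is invoked from process_weights_after_loading, the size arguments are recovered from the packed tensors themselves. A small sketch with illustrative sizes showing that recovery:

import torch

num_experts, hidden_size, intermediate = 2, 128, 256
w13_weight = torch.empty(num_experts, 2 * intermediate, hidden_size // 2, dtype=torch.uint8)
w2_weight = torch.empty(num_experts, hidden_size, intermediate // 2, dtype=torch.uint8)

assert w2_weight.size(-2) == hidden_size            # hidden_size argument
assert w13_weight.size(-2) // 2 == intermediate     # intermediate_size argument
assert w13_weight.size(0) == num_experts            # num_experts argument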

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # GEMM 1 processing
    gemm1_weight = layer.w13_weight.data
    gemm1_weight_scale = layer.w13_weight_scale.data

    if self.allow_flashinfer:
        gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
            gemm1_weight, gemm1_weight_scale, dim=-2)

    layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
    layer.w13_weight_scale = Parameter(gemm1_weight_scale,
                                       requires_grad=False)

    # Common processing for w13_weight_scale_2
    if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                          layer.w13_weight_scale_2[:, 1]):
        logger.warning_once(
            "w1_weight_scale_2 must match w3_weight_scale_2. "
            "Accuracy may be affected.")

    w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
    layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
                                         requires_grad=False)

    # Common processing for input scales and alphas
    w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
        torch.float32)
    layer.g1_alphas = Parameter(
        (w13_input_scale * w13_weight_scale_2).to(torch.float32),
        requires_grad=False)

    # This is for quantization, so we need to invert it.
    layer.w13_input_scale_quant = Parameter(
        (1 / w13_input_scale).to(torch.float32), requires_grad=False)

    # GEMM 2 processing
    layer.g2_alphas = Parameter(
        (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
        requires_grad=False)

    # This is for quantization, so we need to invert it.
    layer.w2_input_scale_quant = Parameter(
        (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)

    # TensorRT-LLM specific processing
    if self.allow_flashinfer and \
        self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
        # Prepare static weights for TRT-LLM kernel
        (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
         gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
         ) = self.prepare_static_weight_layouts_for_trtllm_moe(
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale,
             layer.w2_weight_scale,
             layer.w2_weight.size(-2),  # hidden_size
             layer.w13_weight.size(-2) // 2,  # intermediate_size
             layer.w13_weight.size(0),  # num_experts
         )

        layer.gemm1_weights_fp4_shuffled = Parameter(
            gemm1_weights_fp4_shuffled, requires_grad=False)
        layer.gemm2_weights_fp4_shuffled = Parameter(
            gemm2_weights_fp4_shuffled, requires_grad=False)
        layer.gemm1_scales_fp4_shuffled = Parameter(
            gemm1_scales_fp4_shuffled, requires_grad=False)
        layer.gemm2_scales_fp4_shuffled = Parameter(
            gemm2_scales_fp4_shuffled, requires_grad=False)

        # Additional parameter needed for TRT-LLM
        layer.g1_scale_c = Parameter(
            (layer.w2_input_scale_quant * layer.g1_alphas).to(
                torch.float32),
            requires_grad=False,
        )

        # Clean up weights that won't be used by TRT-LLM
        del layer.w2_weight
        del layer.w2_weight_scale
        del layer.w13_weight
        del layer.w13_weight_scale
    else:
        # Non-TRT-LLM processing (Cutlass or non-flashinfer)
        assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w13_blockscale_swizzled = swizzle_blockscale(
            layer.w13_weight_scale)
        layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
                                           requires_grad=False)

        assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
            "Expected weight_scale.dim(1) to be divisible by 16")
        assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Blockscale must be represented as FP8-E4M3")
        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
        layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
                                          requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data,
                                    requires_grad=False)

    if self.use_marlin:
        prepare_moe_fp4_layer_for_marlin(layer)
        del layer.g1_alphas
        del layer.g2_alphas
        del layer.w13_input_scale_quant
        del layer.w2_input_scale_quant
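
A toy numeric sketch of the per-expert scale folding in the listing above; the values are made up for illustration. The global input scale is reduced with max over the fused w1/w3 pair, multiplied into the weight's global scale to form g1_alphas, and inverted to form the quantization scale.

# Illustration only: mirrors the scale folding done above.
import torch

w13_input_scale = torch.tensor([[0.50, 0.40],    # (num_experts=2, 2): fused w1/w3 input scales
                                [0.25, 0.25]])
w13_weight_scale_2 = torch.tensor([[2.0, 2.0],   # per-expert global weight scales
                                   [4.0, 4.0]])

in_scale = w13_input_scale.max(dim=1).values          # tensor([0.5000, 0.2500])
g1_alphas = in_scale * w13_weight_scale_2[:, 0]       # tensor([1., 1.])
w13_input_scale_quant = 1 / in_scale                  # tensor([2., 4.])
print(g1_alphas, w13_input_scale_quant)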

select_gemm_impl

select_gemm_impl(
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute
Source code in vllm/model_executor/layers/quantization/modelopt.py
def select_gemm_impl(
    self,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
    experts = select_nvfp4_gemm_impl(
        moe,
        g1_alphas=self.layer.g1_alphas,
        g2_alphas=self.layer.g2_alphas,
        a1_gscale=self.layer.w13_input_scale_quant,
        a2_gscale=self.layer.w2_input_scale_quant,
        allow_flashinfer=self.allow_flashinfer,
    )
    logger.debug_once("Using %s", experts.__class__.__name__)
    return experts

uses_weight_scale_2_pattern

uses_weight_scale_2_pattern() -> bool

FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.

Source code in vllm/model_executor/layers/quantization/modelopt.py
def uses_weight_scale_2_pattern(self) -> bool:
    """
    FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
    """
    return True

ModelOptNvFp4LinearMethod

Bases: LinearMethodBase

Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure:

input_scale: torch.float32, scalar
weight: NVFP4 (represented as uint8 bytes), shape [1, X, Y/2]
weight_scale: FP8-E4M3, shape [X, Y], i.e. the per-block scale
weight_scale_2: torch.float32, scalar

Args: quant_config: The ModelOpt quantization config.
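
For a concrete feel of this layout, the sketch below works out the registered tensor shapes for a hypothetical 4096 x 4096 linear layer with the NVFP4 group size of 16; the numbers are illustrative only.

# Illustration only: per-partition shapes registered by create_weights below.
out_features, in_features, group_size = 4096, 4096, 16

weight_shape       = (out_features, in_features // 2)           # uint8, two FP4 values per byte
weight_scale_shape = (out_features, in_features // group_size)  # FP8-E4M3 per-block scales
# weight_scale_2 and input_scale are per-tensor float32 values
# (one per output partition, reduced with max() after loading).
print(weight_shape, weight_scale_shape)   # (4096, 2048) (4096, 256)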

Source code in vllm/model_executor/layers/quantization/modelopt.py
class ModelOptNvFp4LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    input_scale: torch.float32, scalar ,
    weight: NVFP4(represented as byte) Shape: [1, X, y/2]
    weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
    weight_scale_2: torch.float32, scalar,
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
        self.quant_config = quant_config

        if envs.VLLM_USE_TRTLLM_FP4_GEMM:
            assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
            self.backend = "flashinfer-trtllm"
        elif has_flashinfer():
            self.backend = "flashinfer-cutlass"
        elif cutlass_fp4_supported():
            self.backend = "cutlass"
        elif is_fp4_marlin_supported():
            self.backend = "marlin"
        else:
            raise ValueError("Current platform does not support NVFP4"
                             " quantization. Please use Blackwell and"
                             " above.")

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        if (input_size_per_partition % 16 != 0):
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")
        # For NVFP4-serialized checkpoints, the per-block weight scale is
        # still represented as FP8-E4M3.
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)
        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 items are packed in the input dimension
                layer.output_size_per_partition,
                layer.input_size_per_partition // 2,
                dtype=torch.uint8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # Input Weight Scale
        input_scale = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                              weight_loader=weight_loader)
        layer.register_parameter("input_scale", input_scale)

        # Global Weight Scale
        weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
            len(output_partition_sizes), dtype=torch.float32),
                                                 weight_loader=weight_loader)
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # Per Block Weight Scale
        weight_scale = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=weight_dtype,
        ),
                                            input_dim=1,
                                            output_dim=0,
                                            weight_loader=weight_loader)

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: Module) -> None:

        # global scales:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)

        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        # Swizzle the weight blockscale.
        # contracting dimension is input dimension
        # block_size = 16;
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
            "Weight Block scale must be represented as FP8-E4M3")

        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8),
                                      epilogue_tile_m)
            weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
                torch.uint8), epilogue_tile_m).reshape(
                    weight_scale.shape).view(torch.float8_e4m3fn))

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight = Parameter(weight, requires_grad=False)
        else:
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            layer.weight_scale = Parameter(swizzled_weight_scale,
                                           requires_grad=False)
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

            if self.backend == "marlin":
                prepare_fp4_layer_for_marlin(layer)
                del layer.alpha
                del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.backend == "marlin":
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_scale_2=layer.weight_scale_2,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        s_quant = 1 / layer.input_scale
        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

        # validate dtypes of quantized input, input block scale,
        # weight and weight_blockscale
        assert (x_fp4.dtype == torch.uint8)
        assert (layer.weight.dtype == torch.uint8)
        assert (x_blockscale.dtype == torch.float8_e4m3fn)
        assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
        assert (layer.alpha.dtype == torch.float32)

        mm_args = (
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend == "flashinfer-trtllm":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
        elif self.backend == "flashinfer-cutlass":
            out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
        else:
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)

backend instance-attribute

backend = 'flashinfer-trtllm'

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: ModelOptNvFp4Config) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
    self.quant_config = quant_config

    if envs.VLLM_USE_TRTLLM_FP4_GEMM:
        assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer"
        self.backend = "flashinfer-trtllm"
    elif has_flashinfer():
        self.backend = "flashinfer-cutlass"
    elif cutlass_fp4_supported():
        self.backend = "cutlass"
    elif is_fp4_marlin_supported():
        self.backend = "marlin"
    else:
        raise ValueError("Current platform does not support NVFP4"
                         " quantization. Please use Blackwell and"
                         " above.")

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/modelopt.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if self.backend == "marlin":
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)

    output_dtype = x.dtype
    output_shape = [x.shape[0], layer.weight.shape[0]]

    # quantize BF16 or FP16 to (FP4 and interleaved block scale)
    s_quant = 1 / layer.input_scale
    x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

    # validate dtypes of quantized input, input block scale,
    # weight and weight_blockscale
    assert (x_fp4.dtype == torch.uint8)
    assert (layer.weight.dtype == torch.uint8)
    assert (x_blockscale.dtype == torch.float8_e4m3fn)
    assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
    assert (layer.alpha.dtype == torch.float32)

    mm_args = (
        x_fp4,
        layer.weight,
        x_blockscale,
        layer.weight_scale,
        layer.alpha,
        output_dtype,
    )
    if self.backend == "flashinfer-trtllm":
        out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
    elif self.backend == "flashinfer-cutlass":
        out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass")
    else:
        out = cutlass_scaled_fp4_mm(*mm_args)

    if bias is not None:
        out = out + bias
    return out.view(*output_shape)
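
The reason apply() hands layer.alpha to the GEMM is the usual dequantization identity: activations are quantized with 1 / input_scale and the checkpoint's weights were quantized against weight_scale_2, so the product of the quantized operands must be rescaled by both global factors exactly once (the per-block scales are applied separately by the kernel). A scalar toy example, illustration only:

# Illustration only: why alpha = input_scale * weight_scale_2 restores scale.
input_scale, weight_scale_2 = 0.5, 2.0
x, w = 3.0, 4.0                         # stand-ins for an activation / weight value
x_q = x / input_scale                   # roughly what scaled_fp4_quant does, up to FP4 rounding
w_q = w / weight_scale_2                # roughly how the FP4 weight was produced offline
alpha = input_scale * weight_scale_2    # matches layer.alpha above
assert abs(alpha * (x_q * w_q) - x * w) < 1e-9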

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/modelopt.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")
    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition

    if (input_size_per_partition % 16 != 0):
        raise ValueError("Unsupported model when in features size is "
                         "not multiple of 16")
    # For NVFP4-serialized checkpoints, the per-block weight scale is
    # still represented as FP8-E4M3.
    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_nvfp4_serialized
                    else params_dtype)
    # Weight
    weight = ModelWeightParameter(
        data=torch.empty(
            # 2 fp4 items are packed in the input dimension
            layer.output_size_per_partition,
            layer.input_size_per_partition // 2,
            dtype=torch.uint8),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    # Input Weight Scale
    input_scale = PerTensorScaleParameter(data=torch.empty(
        len(output_partition_sizes), dtype=torch.float32),
                                          weight_loader=weight_loader)
    layer.register_parameter("input_scale", input_scale)

    # Global Weight Scale
    weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
        len(output_partition_sizes), dtype=torch.float32),
                                             weight_loader=weight_loader)
    layer.register_parameter("weight_scale_2", weight_scale_2)

    # Per Block Weight Scale
    weight_scale = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition // self.quant_config.group_size,
        dtype=weight_dtype,
    ),
                                        input_dim=1,
                                        output_dim=0,
                                        weight_loader=weight_loader)

    layer.register_parameter("weight_scale", weight_scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/modelopt.py
def process_weights_after_loading(self, layer: Module) -> None:

    # global scales:
    input_scale_2 = layer.input_scale.max().to(torch.float32)
    layer.input_scale = Parameter(input_scale_2, requires_grad=False)

    weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
    layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

    layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                            requires_grad=False)

    # Swizzle the weight blockscale.
    # contracting dimension is input dimension
    # block_size = 16;
    assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
        "Weight Block scale must be represented as FP8-E4M3")

    if self.backend == "flashinfer-trtllm":
        # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
        # FlashInfer provides nvfp4_quantize to quantize + shuffle the
        # layout but we use our own quantization so we have to call
        # shuffles ourselves.
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        weight = layer.weight.data
        weight_scale = layer.weight_scale.data

        epilogue_tile_m = 128
        weight = shuffle_matrix_a(weight.view(torch.uint8),
                                  epilogue_tile_m)
        weight_scale = (shuffle_matrix_sf_a(weight_scale.view(
            torch.uint8), epilogue_tile_m).reshape(
                weight_scale.shape).view(torch.float8_e4m3fn))

        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        layer.weight = Parameter(weight, requires_grad=False)
    else:
        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
        layer.weight_scale = Parameter(swizzled_weight_scale,
                                       requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        if self.backend == "marlin":
            prepare_fp4_layer_for_marlin(layer)
            del layer.alpha
            del layer.input_scale

_get_tile_tokens_dim

_get_tile_tokens_dim(
    num_tokens: int, top_k: int, num_experts: int
) -> int
Source code in vllm/model_executor/layers/quantization/modelopt.py
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming perfect expert distribution first.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # And pad the number to the next power of 2.
    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
    return tile_tokens_dim
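
A worked example of the tile sizing above, assuming next_power_of_2 rounds up to the nearest power of two:

# Illustration only: tile sizing for a hypothetical batch.
num_tokens, top_k, num_experts = 100, 2, 8
num_tokens_per_expert = (num_tokens * top_k) // num_experts   # 25 tokens per expert
# next_power_of_2(25) == 32, already inside the supported [8, 64] range
tile_tokens_dim = min(max(32, 8), 64)                         # -> 32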