vllm.model_executor.layers.quantization.petit

logger module-attribute

logger = init_logger(__name__)

PetitFp8KVCacheMethod

Bases: BaseKVCacheMethod

Supports loading kv-cache scaling factors from FP8 checkpoints.

Source code in vllm/model_executor/layers/quantization/petit.py
class PetitFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        super().__init__(quant_config)

__init__

__init__(quant_config: PetitNvFp4Config)
Source code in vllm/model_executor/layers/quantization/petit.py
def __init__(self, quant_config: PetitNvFp4Config):
    super().__init__(quant_config)

PetitNvFp4Config

Bases: QuantizationConfig

Config class for Petit FP4.

Source code in vllm/model_executor/layers/quantization/petit.py
class PetitNvFp4Config(QuantizationConfig):
    """Config class for Petit FP4."""

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[list[str]] = None,
    ) -> None:
        self._check_hardware_support()
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning("Detected nvfp4 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        self.group_size = group_size
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules

    def _check_hardware_support(self) -> None:
        """
        Verifies that the current hardware is supported by the Petit backend.
        This backend is specifically designed for AMD GPUs and is not
        supported on the CUDA platform.
        """
        # This check ensures the code is NOT running on an NVIDIA GPU.
        if current_platform.is_cuda():
            raise ValueError(
                "The 'petit' quantization backend is designed for AMD GPUs "
                "and is not supported on the CUDA platform. For NVIDIA GPUs, "
                "please use a different quantization method such as FP8, AWQ, "
                "or GPTQ.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "petit_nvfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Petit supports the gfx90a and gfx942 GPUs
        return 90

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
        qc = cls.get_from_keys(config, ["quantization"])

        quant_method_raw = qc.get("quant_algo")
        if not isinstance(quant_method_raw, str) or not quant_method_raw:
            raise ValueError(
                "Missing or invalid 'quant_algo' in quantization config.")
        quant_method = quant_method_raw.upper()

        group_size_raw = qc.get("group_size")
        if not isinstance(group_size_raw, int):
            raise ValueError(
                "Missing or invalid 'group_size' (int) in hf_quant_config.json."
            )
        group_size = group_size_raw

        verify_petit_nvfp4_supported(quant_method, group_size)

        kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
        if not isinstance(kv_cache_quant_algo_raw, str):
            raise ValueError(
                "'kv_cache_quant_algo' must be a string if provided.")
        kv_cache_quant_algo = kv_cache_quant_algo_raw

        exclude_raw = qc.get("exclude_modules", [])
        if exclude_raw is None:
            exclude_modules: list[str] = []
        elif isinstance(exclude_raw, list) and all(
                isinstance(x, str) for x in exclude_raw):
            exclude_modules = exclude_raw
        else:
            raise ValueError(
                "'exclude_modules' must be a list[str] (or omitted).")

        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        return cls(
            is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo=kv_cache_quant_algo,
            group_size=group_size,
            exclude_modules=exclude_modules,
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        if not current_platform.is_rocm():
            return None

        qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
            return cls.get_name()  # "petit_nvfp4"
        return None

    @classmethod
    def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
        qc = quant_config.get("quantization", quant_config)
        algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
        return algo == "NVFP4"

    def is_layer_excluded(self, prefix: str,
                          exclude_modules: list[str]) -> bool:
        for pattern in exclude_modules:
            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
            if re.fullmatch(regex_str, prefix):
                return True
        return False

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        exclude = self.require_exclude_modules()

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
                    prefix, exclude):
                return UnquantizedLinearMethod()
            return PetitNvFp4LinearMethod(self)
        elif isinstance(layer, Attention):
            return PetitFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> list[str]:
        return []

    def require_group_size(self) -> int:
        if self.group_size is None:
            logger.warning("group_size not set; defaulting to 16 for NVFP4.")
            return 16
        return self.group_size

    def require_kv_cache_quant_algo(self) -> str:
        return self.kv_cache_quant_algo or "auto"

    def require_exclude_modules(self) -> list[str]:
        return list(self.exclude_modules or [])

exclude_modules instance-attribute

exclude_modules = exclude_modules

group_size instance-attribute

group_size = group_size

is_checkpoint_nvfp4_serialized instance-attribute

is_checkpoint_nvfp4_serialized = (
    is_checkpoint_nvfp4_serialized
)

kv_cache_quant_algo instance-attribute

kv_cache_quant_algo = kv_cache_quant_algo

__init__

__init__(
    is_checkpoint_nvfp4_serialized: bool = False,
    kv_cache_quant_algo: Optional[str] = None,
    group_size: Optional[int] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None
Source code in vllm/model_executor/layers/quantization/petit.py
def __init__(
    self,
    is_checkpoint_nvfp4_serialized: bool = False,
    kv_cache_quant_algo: Optional[str] = None,
    group_size: Optional[int] = None,
    exclude_modules: Optional[list[str]] = None,
) -> None:
    self._check_hardware_support()
    self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
    if is_checkpoint_nvfp4_serialized:
        logger.warning("Detected nvfp4 checkpoint. Please note that the "
                       "format is experimental and subject to change.")
    self.group_size = group_size
    self.kv_cache_quant_algo = kv_cache_quant_algo
    self.exclude_modules = exclude_modules

_check_hardware_support

_check_hardware_support() -> None

Verifies that the current hardware is supported by the Petit backend. This backend is specifically designed for AMD GPUs and is not supported on the CUDA platform.

Source code in vllm/model_executor/layers/quantization/petit.py
def _check_hardware_support(self) -> None:
    """
    Verifies that the current hardware is supported by the Petit backend.
    This backend is specifically designed for AMD GPUs and is not
    supported on the CUDA platform.
    """
    # This check ensures the code is NOT running on an NVIDIA GPU.
    if current_platform.is_cuda():
        raise ValueError(
            "The 'petit' quantization backend is designed for AMD GPUs "
            "and is not supported on the CUDA platform. For NVIDIA GPUs, "
            "please use a different quantization method such as FP8, AWQ, "
            "or GPTQ.")

from_config classmethod

from_config(config: dict[str, Any]) -> PetitNvFp4Config
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
    qc = cls.get_from_keys(config, ["quantization"])

    quant_method_raw = qc.get("quant_algo")
    if not isinstance(quant_method_raw, str) or not quant_method_raw:
        raise ValueError(
            "Missing or invalid 'quant_algo' in quantization config.")
    quant_method = quant_method_raw.upper()

    group_size_raw = qc.get("group_size")
    if not isinstance(group_size_raw, int):
        raise ValueError(
            "Missing or invalid 'group_size' (int) in hf_quant_config.json."
        )
    group_size = group_size_raw

    verify_petit_nvfp4_supported(quant_method, group_size)

    kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
    if not isinstance(kv_cache_quant_algo_raw, str):
        raise ValueError(
            "'kv_cache_quant_algo' must be a string if provided.")
    kv_cache_quant_algo = kv_cache_quant_algo_raw

    exclude_raw = qc.get("exclude_modules", [])
    if exclude_raw is None:
        exclude_modules: list[str] = []
    elif isinstance(exclude_raw, list) and all(
            isinstance(x, str) for x in exclude_raw):
        exclude_modules = exclude_raw
    else:
        raise ValueError(
            "'exclude_modules' must be a list[str] (or omitted).")

    is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

    return cls(
        is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
        kv_cache_quant_algo=kv_cache_quant_algo,
        group_size=group_size,
        exclude_modules=exclude_modules,
    )
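
A minimal sketch of the configuration shape that from_config expects: a dict with a top-level "quantization" section, as read from hf_quant_config.json. The concrete values below are illustrative only, and the snippet assumes a supported (non-CUDA) platform so the constructor's hardware check passes and that NVFP4 with group size 16 is accepted by verify_petit_nvfp4_supported.

from vllm.model_executor.layers.quantization.petit import PetitNvFp4Config

# Illustrative checkpoint metadata; real checkpoints may carry more keys.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "group_size": 16,
        "kv_cache_quant_algo": "auto",
        "exclude_modules": ["lm_head"],
    }
}

config = PetitNvFp4Config.from_config(hf_quant_config)
assert config.is_checkpoint_nvfp4_serialized
assert config.group_size == 16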

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["hf_quant_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def get_min_capability(cls) -> int:
    # Petit supports the gfx90a and gfx942 GPUs
    return 90

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "petit_nvfp4"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/petit.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    from vllm.attention.layer import Attention  # Avoid circular import

    exclude = self.require_exclude_modules()

    if isinstance(layer, LinearBase):
        if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
                prefix, exclude):
            return UnquantizedLinearMethod()
        return PetitNvFp4LinearMethod(self)
    elif isinstance(layer, Attention):
        return PetitFp8KVCacheMethod(self)
    return None

get_scaled_act_names

get_scaled_act_names() -> list[str]
Source code in vllm/model_executor/layers/quantization/petit.py
def get_scaled_act_names(self) -> list[str]:
    return []

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

is_layer_excluded

is_layer_excluded(
    prefix: str, exclude_modules: list[str]
) -> bool
Source code in vllm/model_executor/layers/quantization/petit.py
def is_layer_excluded(self, prefix: str,
                      exclude_modules: list[str]) -> bool:
    for pattern in exclude_modules:
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
    return False
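
The wildcard translation above means an exclusion pattern must match the full module prefix, not a substring. A small self-contained illustration of the same logic, with hypothetical module prefixes:

import re

def matches(pattern: str, prefix: str) -> bool:
    # Same translation as is_layer_excluded: escape literal dots,
    # turn '*' into '.*', and require a full match on the prefix.
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    return re.fullmatch(regex_str, prefix) is not None

assert matches("model.layers.*.mlp.gate_proj", "model.layers.7.mlp.gate_proj")
assert not matches("model.layers.*.mlp", "model.layers.7.mlp.gate_proj")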

is_petit_nvfp4_compatible classmethod

is_petit_nvfp4_compatible(
    quant_config: dict[str, Any],
) -> bool
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
    qc = quant_config.get("quantization", quant_config)
    algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
    return algo == "NVFP4"

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]
Source code in vllm/model_executor/layers/quantization/petit.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    if not current_platform.is_rocm():
        return None

    qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
    algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
    if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
        return cls.get_name()  # "petit_nvfp4"
    return None
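
A hedged illustration of the override: on a ROCm platform, a checkpoint whose quantization config advertises NVFP4 (including ModelOpt-serialized variants) is redirected to the Petit backend; on any other platform the method returns None and no override occurs. The config values below are made up for the example.

hf_quant_cfg = {"quantization": {"quant_algo": "NVFP4", "group_size": 16}}

method = PetitNvFp4Config.override_quantization_method(hf_quant_cfg, user_quant=None)
# method == "petit_nvfp4" on ROCm, None elsewhere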

require_exclude_modules

require_exclude_modules() -> list[str]
Source code in vllm/model_executor/layers/quantization/petit.py
def require_exclude_modules(self) -> list[str]:
    return list(self.exclude_modules or [])

require_group_size

require_group_size() -> int
Source code in vllm/model_executor/layers/quantization/petit.py
def require_group_size(self) -> int:
    if self.group_size is None:
        logger.warning("group_size not set; defaulting to 16 for NVFP4.")
        return 16
    return self.group_size

require_kv_cache_quant_algo

require_kv_cache_quant_algo() -> str
Source code in vllm/model_executor/layers/quantization/petit.py
def require_kv_cache_quant_algo(self) -> str:
    return self.kv_cache_quant_algo or "auto"

PetitNvFp4LinearMethod

Bases: LinearMethodBase

Linear method for NVFP4. Supports loading NVFP4 checkpoints with the following structure:

| Tensor Name    | datatype      | shape       |
|----------------|---------------|-------------|
| input_scale    | torch.float32 | scalar      |
| weight         | NVFP4(SE2M1)  | [1, X, y/2] |
| weight_scale   | FP8-E4M3      | [X, Y]      |
| weight_scale_2 | torch.float32 | scalar      |

The weights are quantized per block of 16 elements.

Args:
    quant_config: The ModelOpt quantization config.
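
For orientation, the packing implied by the table can be spelled out in terms of tensor shapes. A rough sketch, assuming a partition with out_features N, in_features K, and the default NVFP4 group size of 16; the dimensions are chosen arbitrarily.

N, K, group_size = 4096, 4096, 16

# Two FP4 (E2M1) values are packed into each uint8 along the input
# dimension, so the stored weight holds K // 2 bytes per output row.
weight_shape = (N, K // 2)                 # dtype: torch.uint8

# One FP8-E4M3 scale per group of 16 input elements.
weight_scale_shape = (N, K // group_size)  # dtype: torch.float8_e4m3fn

# input_scale and weight_scale_2 are collapsed to per-tensor float32
# scalars after loading (see process_weights_after_loading below).
assert weight_scale_shape == (4096, 256)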

Source code in vllm/model_executor/layers/quantization/petit.py
class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.
    Supports loading NVFP4 checkpoints with the following structure:

    |Tensor Name           | datatype      |  shape      |
    |----------------------------------------------------|
    |input_scale           | torch.float32 | scalar      |
    |weight                | NVFP4(SE2M1)  | [1, X, y/2] |
    |weight_scale          | FP8-E4M3      | [X, Y]      |
    |weight_scale_2        | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.
    Args: quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: PetitNvFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError("NVFP4 quantization was selected, "
                             " dynamic quantization is not supported.")

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        if input_size_per_partition % 16 != 0:
            raise ValueError("Unsupported model when in features size is "
                             "not multiple of 16")

        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_nvfp4_serialized
                        else params_dtype)

        weight = ModelWeightParameter(
            data=torch.empty(
                # 2 fp4 data is packed in one uint8 in the input dimension
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        group_size = self.quant_config.require_group_size()
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                requires_grad=False)

        prepare_nvfp4_layer_for_petit(layer)
        del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_petit_nvfp4_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: PetitNvFp4Config)
Source code in vllm/model_executor/layers/quantization/petit.py
def __init__(self, quant_config: PetitNvFp4Config):
    self.quant_config = quant_config

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/petit.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return apply_petit_nvfp4_linear(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        weight_scale_2=layer.weight_scale_2,
        size_n=layer.output_size_per_partition,
        size_k=layer.input_size_per_partition,
        bias=bias,
    )

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/petit.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    del input_size, output_size
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")

    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")

    layer.logical_widths = output_partition_sizes

    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    if input_size_per_partition % 16 != 0:
        raise ValueError("Unsupported model when in features size is "
                         "not multiple of 16")

    weight_dtype = (torch.float8_e4m3fn
                    if self.quant_config.is_checkpoint_nvfp4_serialized
                    else params_dtype)

    weight = ModelWeightParameter(
        data=torch.empty(
            # 2 fp4 data is packed in one uint8 in the input dimension
            output_size_per_partition,
            input_size_per_partition // 2,
            dtype=torch.uint8,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    input_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )

    layer.register_parameter("input_scale", input_scale)

    weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_scale_2", weight_scale_2)

    group_size = self.quant_config.require_group_size()
    weight_scale = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // group_size,
            dtype=weight_dtype,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )

    layer.register_parameter("weight_scale", weight_scale)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/petit.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    input_scale_2 = layer.input_scale.max().to(torch.float32)
    weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
    layer.input_scale = Parameter(input_scale_2, requires_grad=False)
    layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
    layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                            requires_grad=False)

    prepare_nvfp4_layer_for_petit(layer)
    del layer.input_scale
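
A small sketch of the scale folding performed above, with made-up per-partition scale values; it only mirrors the max-reduction and the alpha precomputation, not the Petit-specific repacking done by prepare_nvfp4_layer_for_petit.

import torch

# Per-partition scales as loaded from the checkpoint (illustrative values).
input_scale = torch.tensor([0.020, 0.030, 0.025])
weight_scale_2 = torch.tensor([0.50, 0.60, 0.55])

# Keep a single per-tensor value for each (the maximum across partitions)
# and precompute their product as alpha.
input_scale_2 = input_scale.max().to(torch.float32)
ws2 = weight_scale_2.max().to(torch.float32)
alpha = input_scale_2 * ws2   # == 0.03 * 0.6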