vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu

CPUScaledMMLinearKernel

Bases: ScaledMMLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_cpu():
            return False, "CPUScaledMM requires running on CPU."

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = getattr(layer, self.w_q_name)
        dtype = weight.dtype
        N, K = weight.size()
        if (current_platform.get_cpu_architecture() == CpuArchEnum.X86
                and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric
                and check_cpu_sgl_kernel(N, K, dtype)):
            self.linear_method = self._apply_weights_sgl
            self.process_weights_for_sgl(layer)
        else:
            self.linear_method = self._apply_weights_onednn
            self.process_weights_for_onednn(layer)

    def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
        # WEIGHT
        # Transpose to [K, N] for convenience
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer, self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False))

        # WEIGHT SCALE
        # oneDNN kernels support only per-tensor and per-channel weight scales.
        # If we have a fused module (QKV, MLP) with per-tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
        replace_parameter(
            layer, self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False))

        # INPUT SCALE
        if self.config.is_static_input_scheme:
            input_scale = getattr(layer, self.i_s_name)

            if self.config.input_symmetric:
                replace_parameter(
                    layer, self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False))
                setattr(layer, self.i_zp_name, None)
            else:
                input_zero_point = getattr(layer, self.i_zp_name)

                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max -
                                                   int8_traits.min)
                replace_parameter(
                    layer, self.i_s_name,
                    torch.nn.Parameter(scale, requires_grad=False))

                azp = (int8_traits.min -
                       range_min / scale).round().to(dtype=torch.int32)
                replace_parameter(layer, self.i_zp_name,
                                  torch.nn.Parameter(azp, requires_grad=False))

        else:
            setattr(layer, self.i_s_name, None)
            setattr(layer, self.i_zp_name, None)

        # Unlike the CUTLASS kernels, oneDNN kernels need the AZP adjustment
        # term only when input quantization is dynamic or asymmetric, and the
        # weight scale s_b is folded into the term:
        # s_a * s_b * [(A - zp_a)B] + bias =
        # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
        # s_a * GEMM_output - s_a * zp_a * adj + bias
        if not (self.config.input_symmetric
                and self.config.is_static_input_scheme):
            weight = getattr(layer, self.w_q_name)
            weight_scale = getattr(layer, self.w_s_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
            azp_adj = azp_adj * weight_scale.squeeze()
            setattr(layer, self.azp_adj_name,
                    torch.nn.Parameter(azp_adj, requires_grad=False))
        else:
            setattr(layer, self.azp_adj_name, None)

        weight = getattr(layer, self.w_q_name)
        self.dnnl_handler = ops.create_onednn_scaled_mm(
            weight,
            getattr(layer, self.w_s_name),
            torch.get_default_dtype(),
            getattr(layer, self.i_s_name) is None,
            not self.config.input_symmetric,
            32,
        )
        # weight is prepacked and maintained by the dnnl_handler,
        # release the original weight
        setattr(layer, self.w_q_name, None)
        del weight

    def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
        # WEIGHT
        weight = getattr(layer, self.w_q_name)
        packed_weight = torch.ops._C.convert_weight_packed(weight)
        replace_parameter(
            layer, self.w_q_name,
            torch.nn.Parameter(packed_weight, requires_grad=False))

        if layer.bias is not None:
            bias = layer.bias
            layer.register_parameter(
                "bias_fp32",
                torch.nn.Parameter(bias.float().data, requires_grad=False))

        # WEIGHT SCALE
        # CPU SGL kernels only support per-channel.
        # For per-tensor quant, convert to the per-channel case.
        weight_scale = getattr(layer, self.w_s_name)
        if not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
        replace_parameter(
            layer, self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False))

        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.linear_method(
            layer,
            x,
            bias,
        )

    def _apply_weights_onednn(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)

        # ops.onednn_scaled_int8_quant supports both dynamic and static quant:
        # * dynamic: i_s is None and x_s is computed from x.
        # * static: i_s is a scalar and x_s is i_s.
        x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
            x, i_s, i_zp, self.config.input_symmetric)

        m = x.size(0)
        n = self.dnnl_handler.n
        out = torch.empty((m, n), dtype=x.dtype)
        ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj,
                             bias)

        return out

    def _apply_weights_sgl(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        w_q, w_s, _, _, _ = self._get_weight_params(layer)
        return torch.ops._C.int8_scaled_mm_with_quant(
            x,
            w_q,
            w_s,
            layer.bias_fp32 if bias is not None else None,
            x.dtype,
            True,
        )
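
The AZP adjustment identity quoted in the process_weights_for_onednn comments can be checked numerically. Below is a minimal sketch (illustrative, not vLLM code) assuming a per-tensor input zero point zp_a and weight scale s_b, with the same [K, N] weight layout as above:

import torch

torch.manual_seed(0)
M, K, N = 4, 16, 8
s_a, s_b, zp_a = 0.02, 0.01, 3

A = torch.randint(-128, 128, (M, K), dtype=torch.int32)  # quantized activations
B = torch.randint(-128, 128, (K, N), dtype=torch.int32)  # quantized weight, [K, N]

# Reference form: s_a * s_b * [(A - zp_a)B]
ref = s_a * s_b * ((A - zp_a).float() @ B.float())

# oneDNN form: fold s_b into the adjustment term, mirroring
# azp_adj = weight.sum(dim=0) * weight_scale in the method above.
adj = s_b * B.float().sum(dim=0)              # [N]
gemm_output = s_b * (A.float() @ B.float())   # s_b * AB
out = s_a * gemm_output - s_a * zp_a * adj

assert torch.allclose(ref, out, atol=1e-4)  # equal up to float32 round-off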

_apply_weights_onednn

_apply_weights_onednn(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor

_apply_weights_sgl

_apply_weights_sgl(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
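
When a bias is supplied, this path forwards the fp32 copy registered by process_weights_for_sgl (layer.bias_fp32) rather than the bias argument itself; the argument only gates whether a bias is applied at all. The op name suggests that quantization of x is fused into the kernel, which would explain why no input scale or zero point is passed.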

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
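
This simply forwards to whichever bound method process_weights_after_loading selected as self.linear_method (_apply_weights_sgl or _apply_weights_onednn).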

can_implement classmethod

can_implement(
    c: ScaledMMLinearLayerConfig,
) -> tuple[bool, Optional[str]]
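
A minimal sketch of consuming the (ok, reason) contract; config stands in for a ScaledMMLinearLayerConfig instance and is hypothetical, not taken from the source:

ok, reason = CPUScaledMMLinearKernel.can_implement(config)
if not ok:
    raise ValueError(f"CPUScaledMM unsupported: {reason}")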

get_min_capability classmethod

get_min_capability() -> int
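
Returns a constant 75. The value presumably follows the GPU compute-capability convention of the shared ScaledMMLinearKernel interface; since can_implement already restricts this kernel to CPU, the capability gate is effectively moot here.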

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
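
The SGL fast path is selected only when every condition holds: the host is x86, the VLLM_CPU_SGL_KERNEL environment variable is enabled, input quantization is symmetric, and check_cpu_sgl_kernel accepts the weight shape (N, K) and dtype. All other cases fall back to the oneDNN path.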

process_weights_for_onednn

process_weights_for_onednn(layer: Module) -> None
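
For the static asymmetric case, this method collapses the possibly-multiple (scale, zero_point) pairs of a fused module into a single pair spanning the union of their representable ranges. A standalone sketch of that reconstruction, with made-up example values:

import torch

int8_traits = torch.iinfo(torch.int8)
input_scale = torch.tensor([0.05, 0.08, 0.06])
input_zero_point = torch.tensor([-10, 4, 0], dtype=torch.int32)

# Real-valued range each (scale, zp) pair can represent.
range_max = (input_scale * (int8_traits.max - input_zero_point)).max()
range_min = (input_scale * (int8_traits.min - input_zero_point)).min()

# One (scale, azp) pair covering the union of those ranges.
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).round().to(torch.int32)
# -> scale ≈ 0.08, azp == 4: the widest partition's range dominates.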

process_weights_for_sgl

process_weights_for_sgl(layer: Module) -> None
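
convert_to_channelwise (also used in the oneDNN path) broadcasts each partition's per-tensor scale across that partition's output channels. A minimal re-implementation of the idea, for illustration only (the actual vLLM helper may differ in details):

import torch

def expand_to_channelwise(scales: torch.Tensor,
                          logical_widths: list[int]) -> torch.Tensor:
    # One scale per fused partition -> one scale per output channel.
    return torch.cat([
        s.expand(w) for s, w in zip(scales.flatten(), logical_widths)
    ]).unsqueeze(-1)

# e.g. a fused QKV projection with partition widths 8, 4, 4:
expand_to_channelwise(torch.tensor([0.1, 0.2, 0.3]), [8, 4, 4]).shape
# -> torch.Size([16, 1])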