from typing import Any, Callable, Optional

import torch
import torch.nn.functional as F

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.platforms import current_platform

logger = init_logger(__name__)


class QuarkW4A4MXFP4(QuarkScheme):

    def __init__(self, weight_quant_spec: dict[str, Any],
                 input_quant_spec: dict[str, Any]):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        self.static_input_scales = not input_quant_spec.get("is_dynamic")

        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkW4A4MXFP4 with static input scales is currently not "
                "implemented. Please open an issue.")

        # Both branches currently fall back to emulation: either the platform
        # lacks native MXFP4 support, or it has it but the kernels are not
        # yet integrated in vLLM.
        if not current_platform.supports_mx():
            self.emulate = True
            logger.warning_once(
                "The current platform does not support native MXFP4 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision.")
        else:
            self.emulate = True
            logger.warning_once(
                "The current platform supports native MXFP4 "
                "computation, but kernels are not yet integrated in vLLM. "
                "Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision.")

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum GPU compute capability 7.0 (Volta-class).
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Re-wrap the loaded tensors as plain frozen Parameters, dropping the
        # weight-loading metadata carried by the vLLM parameter subclasses.
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT: packed FP4 values, two 4-bit elements per uint8 byte,
        # halving the input dimension.
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE: one shared uint8 (E8M0) scale per block of
        # OCP_MX_BLOCK_SIZE elements along the input dimension.
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.emulate:
            # Emulated path: dequantize the MXFP4 weight to the activation
            # dtype, fake-quantize (quantize-dequantize) the activation, and
            # run the matmul in high precision.
            dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
            x = quant_dequant_mxfp4(x)
            return F.linear(x, dq_w, bias)
        else:
            raise NotImplementedError()
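

# Minimal sketch (not vLLM's implementation) of what the emulated activation
# QDQ above computes, assuming E2M1 elements, power-of-two shared block
# scales, and a tensor whose element count is divisible by the 32-element
# block size. The real quant_dequant_mxfp4 additionally handles packing,
# rounding modes, and edge cases; the helper name below is hypothetical.
def _mxfp4_qdq_sketch(x: torch.Tensor) -> torch.Tensor:
    # Non-negative magnitudes representable in FP4 (E2M1).
    grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
                        dtype=x.dtype, device=x.device)
    blocks = x.reshape(-1, 32)
    # Shared per-block scale 2^(floor(log2(amax)) - 2), so the block maximum
    # lands near the top of the E2M1 range (max magnitude 6).
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    amax = amax.clamp(min=torch.finfo(x.dtype).tiny)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    # Snap each scaled magnitude to the nearest grid point (tie handling
    # simplified relative to the OCP MX spec) and restore the signs.
    scaled = blocks / scale
    idx = torch.argmin((scaled.abs().unsqueeze(-1) - grid).abs(), dim=-1)
    q = grid[idx] * torch.sign(scaled)
    return (q * scale).reshape(x.shape)

# Usage (illustrative): x_qdq = _mxfp4_qdq_sketch(torch.randn(8, 64))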