class CutlassW4A8LinearKernel(MPLinearKernel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Dynamic per-token FP8 activation quantization
self.quant_fp8 = QuantFP8(static=False,
group_shape=GroupShape.PER_TOKEN)

    @classmethod
def get_min_capability(cls) -> int:
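        # SM90 (Hopper) is the minimum supported architecture.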
return 90

    @classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cuda():
return False, "CUTLASS only supported on CUDA"
if not current_platform.is_device_capability(90):
return False, "CUTLASS W4A8 requires compute capability of 90 "\
"(Hopper)"
if c.act_type != torch.float8_e4m3fn:
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"
if c.has_g_idx:
return False, "Act reordering not supported by CUTLASS W4A8"
if c.zero_points:
return False, "Zero points not supported by CUTLASS W4A8"
if c.weight_type != scalar_types.int4:
return False, f"Quant type ({c.weight_type}) not supported by "\
"CUTLASS W4A8, only supported int4"
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128:
return False, "Only group_size 128 is supported"
in_features, out_features = c.partition_weight_shape
if in_features % 128 or out_features % 128:
return False, "K and N must be divisible by 128, got "\
f"{c.partition_weight_shape}"
return True, None

    # NOTE: assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
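            # Re-lay the packed int4 data out in column-major order
            # (transpose -> contiguous -> transpose back) before the CUTLASS
            # encode/reorder step.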
x.data = ops.cutlass_encode_and_reorder_int4b(
x.data.t().contiguous().t())
return x

        def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
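            # Cast group scales to FP8 and repack them for the CUTLASS kernel.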
x.data = x.data.contiguous().to(torch.float8_e4m3fn)
x.data = ops.cutlass_pack_scale_fp8(x.data)
return x

        # Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
# TODO(czhu): support loading channel scales
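        # For now, use all-ones (identity) per-output-channel scales.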
self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
dtype=torch.float32,
device='cuda')

    def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert bias is None, "bias not supported by CUTLASS W4A8"
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
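        # Flatten activations to 2D (tokens, hidden size); the output keeps
        # the leading dims but with out_features as the last dim.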
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
x_2d, act_scales = self.quant_fp8(x_2d)
output = ops.cutlass_w4a8_mm(a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=self.w_ch_s)
return output.reshape(out_shape)
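

# Illustrative sketch (not part of the kernel above): roughly the math that
# ops.cutlass_w4a8_mm is asked to perform, written with plain torch ops and
# ignoring the packed int4/FP8 storage formats. The shapes and the placement
# of the scales are assumptions inferred from the arguments passed in
# apply_weights, not a drop-in equivalent of the fused CUTLASS kernel.
def _w4a8_reference_mm(a_fp8: torch.Tensor, a_token_scales: torch.Tensor,
                       w_int4: torch.Tensor, group_scales: torch.Tensor,
                       channel_scales: torch.Tensor,
                       group_size: int) -> torch.Tensor:
    # Dequantize activations: (M, K) values times (M, 1) per-token scales.
    a = a_fp8.float() * a_token_scales.float()
    # Dequantize weights: (K, N) int4 values times per-group scales
    # (K // group_size, N) broadcast along K, then per-channel scales (N,).
    w = w_int4.float() * group_scales.float().repeat_interleave(
        group_size, dim=0)
    w = w * channel_scales.float()
    return a @ w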