Skip to content

vllm.model_executor.layers.quantization.mxfp4

logger module-attribute

# Module-level logger, named after this module via ``init_logger``.
logger = init_logger(__name__)

Mxfp4Config

Bases: QuantizationConfig

Source code in vllm/model_executor/layers/quantization/mxfp4.py
class Mxfp4Config(QuantizationConfig):
    """Quantization config for MXFP4 checkpoints.

    Only ``FusedMoE`` layers are quantized (through ``Mxfp4MoEMethod``).
    Linear layers are served unquantized when ``is_layer_skipped`` matches
    them against ``ignored_layers`` and are rejected otherwise; attention
    layers are always rejected.
    """

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        # Layer names checked by ``is_layer_skipped`` in get_quant_method;
        # a falsy value (e.g. None) exempts no linear layer.
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        """Build a config instance; ``config`` itself is not inspected."""
        instance = cls()
        return instance

    @classmethod
    def get_min_capability(cls) -> int:
        """Return 80, the minimum device capability (presumably SM 8.0)."""
        min_capability = 80
        return min_capability

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Activations are only supported in bfloat16."""
        supported = [torch.bfloat16]
        return supported

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """No config filenames are searched for this method."""
        return list()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: Fully-qualified module name, used for skip matching.

        Returns:
            ``UnquantizedLinearMethod`` for skipped linear layers,
            ``Mxfp4MoEMethod`` for fused-MoE layers, ``None`` for anything
            else.

        Raises:
            NotImplementedError: for non-skipped linear layers and for
                attention layers.
        """
        # Imported lazily to avoid a circular import.
        from vllm.attention.layer import Attention

        if isinstance(layer, LinearBase):
            exempt = self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping)
            if not exempt:
                raise NotImplementedError(
                    "Mxfp4 linear layer is not implemented")
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)

        if isinstance(layer, Attention):
            raise NotImplementedError(
                "Mxfp4 attention layer is not implemented")

        return None

ignored_layers instance-attribute

ignored_layers = ignored_layers

__init__

__init__(ignored_layers: Optional[list[str]] = None)
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def __init__(self, ignored_layers: Optional[list[str]] = None):
    """Create an MXFP4 config.

    Args:
        ignored_layers: Layer names that ``get_quant_method`` matches with
            ``is_layer_skipped`` to serve linear layers unquantized. A falsy
            value (the default ``None``) exempts no linear layer.
    """
    super().__init__()
    self.ignored_layers = ignored_layers

from_config classmethod

from_config(config)
Source code in vllm/model_executor/layers/quantization/mxfp4.py
@classmethod
def from_config(cls, config):
    """Build a config instance; ``config`` itself is not inspected."""
    # Nothing is read from ``config``, so every instance starts with the
    # default (no ignored layers).
    instance = cls()
    return instance

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/mxfp4.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    """No config filenames are searched for this method (empty list)."""
    return list()

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/mxfp4.py
@classmethod
def get_min_capability(cls) -> int:
    """Return 80, the minimum device capability (presumably SM 8.0)."""
    min_capability = 80
    return min_capability

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/mxfp4.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    """Return the registry name of this quantization method."""
    method_name = "mxfp4"
    return method_name

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Pick the quantize method for ``layer``.

    Args:
        layer: The module being constructed.
        prefix: Fully-qualified module name, used for skip matching.

    Returns:
        ``UnquantizedLinearMethod`` for skipped linear layers,
        ``Mxfp4MoEMethod`` for fused-MoE layers, ``None`` otherwise.

    Raises:
        NotImplementedError: for non-skipped linear layers and for
            attention layers.
    """
    # Imported lazily to avoid a circular import.
    from vllm.attention.layer import Attention

    if isinstance(layer, LinearBase):
        exempt = self.ignored_layers and is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping)
        if not exempt:
            raise NotImplementedError("Mxfp4 linear layer is not implemented")
        return UnquantizedLinearMethod()

    if isinstance(layer, FusedMoE):
        return Mxfp4MoEMethod(layer.moe_config)

    if isinstance(layer, Attention):
        raise NotImplementedError(
            "Mxfp4 attention layer is not implemented")

    return None

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/mxfp4.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Activations are only supported in bfloat16."""
    supported = [torch.bfloat16]
    return supported

Mxfp4MoEMethod

Bases: FusedMoEMethodBase

Source code in vllm/model_executor/layers/quantization/mxfp4.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
class Mxfp4MoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for MXFP4-quantized expert weights.

    One of three kernel backends is used:

    * Marlin, when ``self.use_marlin`` (see ``_should_use_marlin``);
    * FlashInfer TRT-LLM-gen kernels, when ``should_use_flashinfer_mxfp4()``;
    * otherwise the ``triton_kernels`` matmul path.

    Expert weights are stored as uint8 holding two FP4 (e2m1) values per
    byte, with one uint8 scale per 32-element block; biases are allocated
    as bfloat16.
    """

    def __init__(self, moe: FusedMoEConfig):
        """Store the MoE config and choose the kernel backend.

        Args:
            moe: Fused-MoE layer configuration (``use_ep`` is read later).
        """
        super().__init__(moe)
        self.topk_indices_dtype = None
        self.moe = moe
        # The backend choice is fixed at construction time.
        self.use_marlin = self._should_use_marlin()
        # Forwarded to FlashInfer as ``tune_max_num_tokens`` in ``apply``.
        self.max_capture_size = get_current_vllm_config(
        ).compilation_config.max_capture_size

        if current_platform.is_device_capability(100) and not has_flashinfer():
            logger.warning_once(
                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                "is not available. This may result in degraded performance. "
                "Please `pip install vllm[flashinfer]` for best results.")

    def _should_use_marlin(self):
        """Decide whether to run the MoE through the Marlin kernel.

        A set ``VLLM_MXFP4_USE_MARLIN`` env override always wins. Otherwise
        Marlin is chosen only on CUDA devices that are not capability 10.0,
        and only when the device is below capability 9.0, ``triton_kernels``
        is missing, or torch is older than 2.8.0.
        """
        if envs.VLLM_MXFP4_USE_MARLIN is not None:
            return envs.VLLM_MXFP4_USE_MARLIN
        if current_platform.is_cuda() and \
                not current_platform.is_device_capability(100):
            if not current_platform.has_device_capability(90):
                # marlin kernel has better performance on ampere
                return True
            if not has_triton_kernels():
                return True
            if not is_torch_equal_or_newer("2.8.0"):
                return True
        return False

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate MXFP4 expert weights, block scales and biases on ``layer``.

        Registers ``w13_weight``/``w2_weight`` (uint8, two FP4 values per
        byte along the reduction dim), ``w13_weight_scale``/
        ``w2_weight_scale`` (uint8, one per 32-element block along the
        reduction dim) and bfloat16 ``w13_bias``/``w2_bias``. The
        intermediate size is rounded up to a backend-specific multiple; for
        Marlin and FlashInfer the hidden size is rounded up to 256 as well.
        The padded sizes are kept in ``self.intermediate_size`` and
        ``self.hidden_size``.

        Args:
            layer: Module to register the parameters on.
            num_experts: Number of experts held by this layer.
            hidden_size: Model hidden size (before padding).
            intermediate_size_per_partition: Per-partition expert
                intermediate size (before padding).
            params_dtype: Recorded on ``layer`` for the Marlin backend only.
            **extra_weight_attrs: Attached to every registered parameter via
                ``set_weight_attrs``.
        """
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        # FIXME (zyongye): ship after torch and safetensors support mxfp4
        # is_torch_mxfp4_available = (
        #     hasattr(torch, "float4_e2m1fn_x2") and
        #     hasattr(torch, "float8_e8m0fnu"))
        # if is_torch_mxfp4_available:
        #     weight_dtype = torch.float4_e2m1fn_x2
        #     scale_dtype = torch.float8_e8m0fnu

        mxfp4_block = 32

        intermediate_size_per_partition_after_pad = \
            intermediate_size_per_partition
        if self.use_marlin:
            # The moe marlin kernel requires that for each linear
            # n % 256 == 0 and k % 128 == 0.
            # In gate_up_proj:
            #    n = 2 * intermediate_size_per_partition_after_pad
            #    k = hidden_size
            # In down_proj
            #    n = hidden_size
            #    k = intermediate_size_per_partition_after_pad
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128)
            hidden_size = round_up(hidden_size, 256)

            layer.params_dtype = params_dtype
            layer.num_experts = num_experts
            layer.hidden_size = hidden_size
            layer.intermediate_size_per_partition = \
                intermediate_size_per_partition_after_pad
        elif should_use_flashinfer_mxfp4():
            # Pad the intermediate size to a multiple of 2 * mxfp4_block so it
            # can hold a non-uniformly sharded tensor and be swizzled; the
            # remaining padding (up to a multiple of 256) is for performance.
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 256)
            hidden_size = round_up(hidden_size, 256)
        elif current_platform.is_rocm():
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 128)
        else:
            intermediate_size_per_partition_after_pad = round_up(
                intermediate_size_per_partition, 64)

        self.intermediate_size = intermediate_size_per_partition_after_pad
        self.hidden_size = hidden_size
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w13_bias = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition_after_pad,
                dtype=torch.bfloat16,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_bias", w13_bias)
        set_weight_attrs(w13_bias, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // 2,
                dtype=weight_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition_after_pad // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w2_bias = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                dtype=torch.bfloat16,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_bias", w2_bias)
        set_weight_attrs(w2_bias, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        """Convert loaded weights into the layout the chosen backend expects.

        * Marlin: delegates to ``prepare_moe_fp4_layer_for_marlin``.
        * FlashInfer: adds per-expert ``gemm1_alpha``/``gemm1_beta``/
          ``gemm1_clamp_limit`` constants, validates tensor shapes, swaps
          adjacent rows of the fused gate_up tensors, shuffles weights,
          scales and biases per expert, and views scales as
          ``float8_e4m3fn``.
        * Triton: casts biases to fp32, swizzles weights and scales with
          ``_swizzle_mxfp4``, keeps the results on ``self`` and drops
          ``layer.w13_weight``/``layer.w2_weight`` to free memory.
        """
        if self.use_marlin:
            prepare_moe_fp4_layer_for_marlin(layer)
        elif should_use_flashinfer_mxfp4():
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
            layer.gemm1_alpha = Parameter(torch.tensor(
                [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                          requires_grad=False)
            layer.gemm1_beta = Parameter(torch.tensor(
                [1.0] * self.num_experts, dtype=torch.float32).cuda(),
                                         requires_grad=False)
            layer.gemm1_clamp_limit = Parameter(torch.tensor(
                [7.0] * self.num_experts, dtype=torch.float32).cuda(),
                                                requires_grad=False)
            sf_block_size = 32  # mxfp4 block size

            assert (layer.w13_weight.dim() == 3
                    and layer.w13_weight.shape[0] == self.num_experts
                    and layer.w13_weight.shape[1] == self.intermediate_size * 2
                    and layer.w13_weight.shape[2] == self.hidden_size // 2)
            assert (layer.w13_weight_scale.dim() == 3
                    and layer.w13_weight_scale.shape[0] == self.num_experts
                    and layer.w13_weight_scale.shape[1]
                    == self.intermediate_size * 2
                    and layer.w13_weight_scale.shape[2]
                    == self.hidden_size // sf_block_size)
            assert (layer.w2_weight.dim() == 3
                    and layer.w2_weight.shape[0] == self.num_experts
                    and layer.w2_weight.shape[1] == self.hidden_size and
                    layer.w2_weight.shape[2] == self.intermediate_size // 2)
            # Check the expert dim too, like every sibling assertion does.
            assert (layer.w2_weight_scale.dim() == 3
                    and layer.w2_weight_scale.shape[0] == self.num_experts
                    and layer.w2_weight_scale.shape[1] == self.hidden_size
                    and layer.w2_weight_scale.shape[2]
                    == self.intermediate_size // sf_block_size)
            assert (layer.w13_bias.dim() == 2
                    and layer.w13_bias.shape[0] == self.num_experts
                    and layer.w13_bias.shape[1] == self.intermediate_size * 2)
            assert (layer.w2_bias.dim() == 2
                    and layer.w2_bias.shape[0] == self.num_experts
                    and layer.w2_bias.shape[1] == self.hidden_size)

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_bias.data.to(torch.float32)
            w2_bias = layer.w2_bias.data.to(torch.float32)

            # Swap w1 and w3 as the definition of
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Expose adjacent pairs along ``axis`` as a new size-2 dim.
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Swap each pair, then restore the original shape.
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                return x.reshape(shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Do not interleave as the checkpoint is already interleaved

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
                gemm1_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
                                     epilogue_tile_m))
                gemm1_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                                        epilogue_tile_m))
                gemm1_bias_shuffled.append(
                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
                                     epilogue_tile_m))

                gemm2_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
                                     epilogue_tile_m))
                gemm2_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                                        epilogue_tile_m))
                gemm2_bias_shuffled.append(
                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
                                     epilogue_tile_m))

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
            w13_weight_scale = torch.stack(
                gemm1_scales_mxfp4_shuffled).reshape(
                    self.num_experts, 2 * self.intermediate_size,
                    self.hidden_size // sf_block_size).view(
                        torch.float8_e4m3fn)

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
            w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
                self.num_experts, self.hidden_size, self.intermediate_size //
                sf_block_size).view(torch.float8_e4m3fn)

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = Parameter(w13_weight_scale,
                                               requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = Parameter(w2_weight_scale,
                                              requires_grad=False)
            layer.w13_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False)
            layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                self.num_experts, -1),
                                      requires_grad=False)
        else:
            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

            w13_bias = layer.w13_bias.to(torch.float32)
            w2_bias = layer.w2_bias.to(torch.float32)

            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
            layer.w2_bias = Parameter(w2_bias, requires_grad=False)

            # FIXME warp need to be adjusted based on batch size
            # only apply to  batched mode
            if self.moe.use_ep:
                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
            else:
                num_warps = 8

            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
                layer.w13_weight, layer.w13_weight_scale, num_warps)
            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
                layer.w2_weight, layer.w2_weight_scale, num_warps)

            self.w13_precision_config = PrecisionConfig(
                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
            self.w2_precision_config = PrecisionConfig(
                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))

            self.w13_weight_triton_tensor = w13_weight
            self.w2_weight_triton_tensor = w2_weight

            # need to delete the original weights to save memory on single GPU
            del layer.w13_weight
            del layer.w2_weight
            layer.w13_weight = None
            layer.w2_weight = None
            torch.cuda.empty_cache()

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
        """Return the per-CTA token tile size used by the FlashInfer kernel.

        The value is a power of two clamped into [8, 64].
        """
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for the imbalance of the experts.
        # factor equals to the
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
        # - 1.0 means perfect expert distribution.
        # - > 1.0 means some experts have more
        #     tokens than the perfect distribution.
        # - < 1.0 does not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert
        # assuming perfect distribution.
        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile
        # as it's the range supported by the kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MoE forward pass for ``x`` on the selected backend.

        The Marlin path selects experts with ``FusedMoE.select_experts`` and
        honors the routing options. The FlashInfer and Triton paths first
        assert that ``_can_support_mxfp4`` accepts the configuration, and
        then route from ``router_logits`` inside the kernel call.

        Returns:
            The MoE output tensor.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """

        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.use_marlin:
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias)

            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_bias,
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=None,
                global_scale2=None,
                quant_type_id=scalar_types.float4_e2m1f.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                activation=activation,
                expert_map=expert_map)

        assert _can_support_mxfp4(
            use_grouped_topk, topk_group, num_expert_group, expert_map,
            custom_routing_function, e_score_correction_bias,
            apply_router_weight_on_input, scoring_func, activation,
            expert_load_view, logical_to_physical_map,
            logical_replica_count), (
                "MXFP4 are not supported with this configuration.")

        if should_use_flashinfer_mxfp4():
            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
            assert not self.moe.use_ep, (
                "EP is not supported for flashinfer mxfp4 moe backend yet.")
            if _should_use_flashinfer_mxfp4_bf16():
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                    *x.shape[:-1], -1)
            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                self.num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                0,  # local_expert_offset
                self.num_experts,  # local num experts
                None,
                self._get_tile_tokens_dim(x, top_k),
                1 if renormalize else 0,  # routing_method_type, renormalize
                True,  # do finalize
                tune_max_num_tokens=self.max_capture_size,
            )[0]
            return trtllm_gen_output
        else:
            return triton_kernel_moe_forward(
                hidden_states=x,
                w1=self.w13_weight_triton_tensor,
                w2=self.w2_weight_triton_tensor,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
                w1_bias=layer.w13_bias,
                w2_bias=layer.w2_bias,
                w1_precision=self.w13_precision_config,
                w2_precision=self.w2_precision_config,
                apply_router_weight_on_input=apply_router_weight_on_input,
            )

max_capture_size instance-attribute

max_capture_size = max_capture_size

moe instance-attribute

moe = moe

topk_indices_dtype instance-attribute

topk_indices_dtype = None

use_marlin instance-attribute

use_marlin = _should_use_marlin()

__init__

__init__(moe: FusedMoEConfig)
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def __init__(self, moe: FusedMoEConfig):
    """Store the MoE config and choose the MXFP4 kernel backend.

    Args:
        moe: Fused-MoE layer configuration.
    """
    super().__init__(moe)
    self.moe = moe
    self.topk_indices_dtype = None
    self.use_marlin = self._should_use_marlin()
    # Forwarded to FlashInfer as ``tune_max_num_tokens`` at apply time.
    compilation_config = get_current_vllm_config().compilation_config
    self.max_capture_size = compilation_config.max_capture_size

    blackwell_without_flashinfer = (
        current_platform.is_device_capability(100) and not has_flashinfer())
    if blackwell_without_flashinfer:
        logger.warning_once(
            "MXFP4 MoE is enabled on Blackwell but FlashInfer "
            "is not available. This may result in degraded performance. "
            "Please `pip install vllm[flashinfer]` for best results.")

_get_tile_tokens_dim

_get_tile_tokens_dim(x: Tensor, top_k: int)
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
    """Return the per-CTA token tile size used by the FlashInfer MoE kernel.

    The tokens routed to each expert are estimated assuming a uniform
    spread. That estimate is inflated by an imbalance allowance, rounded up
    to a power of two, and clamped into the kernel's supported range [8, 64].
    """
    # Ratio of the busiest expert's load to a perfectly balanced load:
    # 1.0 is perfect balance, >1.0 means some experts get extra tokens,
    # and <1.0 would be meaningless.
    imbalance_factor = 1.3
    smallest_tile, largest_tile = 8, 64

    routed_tokens = x.shape[0] * top_k
    balanced_per_expert = routed_tokens // self.num_experts
    padded_per_expert = int(balanced_per_expert * imbalance_factor)
    tile = next_power_of_2(padded_per_expert)
    return max(smallest_tile, min(tile, largest_tile))

_should_use_marlin

_should_use_marlin()
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def _should_use_marlin(self):
    """Decide whether the MoE should run through the Marlin kernel.

    A set ``VLLM_MXFP4_USE_MARLIN`` env override always wins. Otherwise
    Marlin is only considered on CUDA devices that are not capability 10.0.
    It is then picked when the device is below capability 9.0,
    ``triton_kernels`` is unavailable, or torch is older than 2.8.0.
    """
    if envs.VLLM_MXFP4_USE_MARLIN is not None:
        return envs.VLLM_MXFP4_USE_MARLIN

    if not (current_platform.is_cuda()
            and not current_platform.is_device_capability(100)):
        return False

    # marlin kernel has better performance on ampere
    return (not current_platform.has_device_capability(90)
            or not has_triton_kernels()
            or not is_torch_equal_or_newer("2.8.0"))

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the MXFP4 fused-MoE forward pass on ``x``.

    One of three backends is used:

    * Marlin (``self.use_marlin``): experts are picked here with
      ``FusedMoE.select_experts`` and the GEMMs run in ``fused_marlin_moe``.
    * FlashInfer TRT-LLM (``should_use_flashinfer_mxfp4()``): routing runs
      inside ``trtllm_fp4_block_scale_moe``. Activations are passed as bf16
      or first quantized to MXFP8, depending on
      ``_should_use_flashinfer_mxfp4_bf16()``.
    * Triton (otherwise): ``triton_kernel_moe_forward`` with the swizzled
      weight tensors and precision configs that
      ``process_weights_after_loading`` stored on ``self``.

    Args:
        layer: The MoE layer holding the (post-processed) MXFP4 weights,
            scales and biases.
        x: Input hidden states. NOTE(review): presumably
            (num_tokens, hidden_size); confirm with the caller.
        router_logits: Router scores used to pick the ``top_k`` experts.
        top_k: Number of experts per token.
        renormalize: Whether the top-k routing weights are renormalized.
        use_grouped_topk, topk_group, num_expert_group,
        custom_routing_function, scoring_func, e_score_correction_bias:
            Routing options. The Marlin path forwards them to
            ``FusedMoE.select_experts``. The other paths require the subset
            accepted by ``_can_support_mxfp4``.
        global_num_experts: Total expert count (forwarded to Marlin/Triton).
        expert_map: Optional global-to-local expert mapping (Marlin/Triton).
        apply_router_weight_on_input: Forwarded to Marlin/Triton.
        activation: Activation name (forwarded to Marlin only).
        enable_eplb: Must be False; EPLB is not supported.
        expert_load_view, logical_to_physical_map, logical_replica_count:
            EPLB state, only checked by ``_can_support_mxfp4`` here.

    Returns:
        The MoE output tensor produced by the selected backend.

    Raises:
        NotImplementedError: If ``enable_eplb`` is True.
        AssertionError: If a non-Marlin backend gets a configuration that
            ``_can_support_mxfp4`` rejects, or if the FlashInfer backend
            is used with expert parallelism.
    """

    if enable_eplb:
        raise NotImplementedError("EPLB is not supported for mxfp4")

    if self.use_marlin:
        # Marlin does not route internally: compute top-k weights/ids first.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_bias,
            layer.w2_bias,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_scale1=None,
            global_scale2=None,
            quant_type_id=scalar_types.float4_e2m1f.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            expert_map=expert_map)

    # The FlashInfer and Triton paths only support a restricted routing
    # configuration (checked by _can_support_mxfp4).
    assert _can_support_mxfp4(
        use_grouped_topk, topk_group, num_expert_group, expert_map,
        custom_routing_function, e_score_correction_bias,
        apply_router_weight_on_input, scoring_func, activation,
        expert_load_view, logical_to_physical_map,
        logical_replica_count), (
            "MXFP4 are not supported with this configuration.")

    if should_use_flashinfer_mxfp4():
        from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
        assert not self.moe.use_ep, (
            "EP is not supported for flashinfer mxfp4 moe backend yet.")
        if _should_use_flashinfer_mxfp4_bf16():
            # bf16 activations are passed through unquantized, no scales.
            assert x.dtype == torch.bfloat16
            x_quant = x
            x_scale = None
        else:
            # Quantize activations to MXFP8. The scales are reinterpreted as
            # float8_e4m3fn and reshaped to keep x's leading dimensions.
            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                *x.shape[:-1], -1)
        # All arguments below are positional; their order must match the
        # flashinfer signature exactly.
        trtllm_gen_output = trtllm_fp4_block_scale_moe(
            router_logits.to(torch.bfloat16),  # routing logits, as bf16
            None,  # routing_bias
            x_quant,
            x_scale,
            layer.w13_weight,  # uint8 (e2m1 x 2), shuffled after loading
            layer.w13_weight_scale,  # block scales, float8_e4m3fn view
            layer.w13_bias,  # fp32 per expert per channel
            layer.gemm1_alpha,  # fp32 per expert
            layer.gemm1_beta,  # fp32 per expert
            layer.gemm1_clamp_limit,  # fp32 per expert
            layer.w2_weight,  # uint8 (e2m1 x 2), shuffled after loading
            layer.w2_weight_scale,  # block scales, float8_e4m3fn view
            layer.w2_bias,  # fp32 per expert per channel
            None,  # output1_scale_scalar
            None,  # output1_scale_gate_scalar
            None,  # output2_scale_scalar
            self.num_experts,
            top_k,
            None,  # n_group
            None,  # topk_group
            self.intermediate_size,  # padded to multiple of 256
            0,  # local_expert_offset
            self.num_experts,  # local num experts
            None,  # presumably routed_scaling_factor; confirm vs flashinfer
            self._get_tile_tokens_dim(x, top_k),
            1 if renormalize else 0,  # routing_method_type, renormalize
            True,  # do finalize
            tune_max_num_tokens=self.max_capture_size,
        )[0]  # element 0 is taken as the MoE output
        return trtllm_gen_output
    else:
        # Triton path: weights/scales were swizzled and stored on self by
        # process_weights_after_loading.
        return triton_kernel_moe_forward(
            hidden_states=x,
            w1=self.w13_weight_triton_tensor,
            w2=self.w2_weight_triton_tensor,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            w1_precision=self.w13_precision_config,
            w2_precision=self.w2_precision_config,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Register zero-initialised MXFP4 expert weights, scales and biases.

    Weights pack two FP4 values per ``uint8`` byte, so their last dimension
    is half the logical size. Scales keep one ``uint8`` per 32-element block
    along the same dimension. Biases are bf16. The intermediate size (and,
    for Marlin/FlashInfer, the hidden size) is padded up to the alignment
    the selected backend needs. The padded values are recorded on ``self``
    (and, for Marlin, on ``layer``).

    Args:
        layer: Module the parameters are registered on.
        num_experts: Number of experts held by this layer.
        hidden_size: Model hidden size before padding.
        intermediate_size_per_partition: Per-partition expert intermediate
            size before padding.
        params_dtype: Recorded on ``layer`` for the Marlin path only.
        **extra_weight_attrs: Attributes attached to every parameter via
            ``set_weight_attrs``.
    """
    self.num_experts = num_experts
    packed_dtype = torch.uint8
    block_scale_dtype = torch.uint8
    # FIXME (zyongye): switch to torch.float4_e2m1fn_x2 weights and
    # torch.float8_e8m0fnu scales once torch and safetensors support mxfp4.

    block = 32  # elements sharing one MXFP4 scale

    if self.use_marlin:
        # Marlin MoE needs n % 256 == 0 and k % 128 == 0 for each GEMM:
        #   gate_up_proj: n = 2 * padded_intermediate, k = hidden_size
        #   down_proj:    n = hidden_size,             k = padded_intermediate
        padded_intermediate = round_up(intermediate_size_per_partition, 128)
        hidden_size = round_up(hidden_size, 256)

        layer.params_dtype = params_dtype
        layer.num_experts = num_experts
        layer.hidden_size = hidden_size
        layer.intermediate_size_per_partition = padded_intermediate
    elif should_use_flashinfer_mxfp4():
        # Pad both dims to multiples of 256. This leaves room for unevenly
        # sharded tensors and for the kernel's swizzling, and the extra
        # padding also helps performance.
        padded_intermediate = round_up(intermediate_size_per_partition, 256)
        hidden_size = round_up(hidden_size, 256)
    else:
        # Remaining backends: ROCm wants 128-alignment, otherwise 64.
        alignment = 128 if current_platform.is_rocm() else 64
        padded_intermediate = round_up(intermediate_size_per_partition,
                                       alignment)

    self.intermediate_size = padded_intermediate
    self.hidden_size = hidden_size

    def _register(name, shape, dtype):
        # Allocate a frozen zero tensor, attach it to the layer and tag it.
        param = torch.nn.Parameter(torch.zeros(*shape, dtype=dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    fused_rows = 2 * padded_intermediate

    # Fused gate_up_proj (column parallel).
    _register("w13_weight", (num_experts, fused_rows, hidden_size // 2),
              packed_dtype)
    _register("w13_weight_scale",
              (num_experts, fused_rows, hidden_size // block),
              block_scale_dtype)
    _register("w13_bias", (num_experts, fused_rows), torch.bfloat16)

    # down_proj (row parallel).
    _register("w2_weight",
              (num_experts, hidden_size, padded_intermediate // 2),
              packed_dtype)
    _register("w2_weight_scale",
              (num_experts, hidden_size, padded_intermediate // block),
              block_scale_dtype)
    _register("w2_bias", (num_experts, hidden_size), torch.bfloat16)

process_weights_after_loading

process_weights_after_loading(layer)
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def process_weights_after_loading(self, layer):
    """Convert loaded MXFP4 tensors into the layout the chosen backend uses.

    * Marlin: delegates to ``prepare_moe_fp4_layer_for_marlin``.
    * FlashInfer TRT-LLM: steps performed:
      - Add per-expert fp32 ``gemm1_alpha`` / ``gemm1_beta`` /
        ``gemm1_clamp_limit`` parameters.
      - Validate all tensor shapes against the padded sizes from
        ``create_weights``.
      - Swap each adjacent pair of gate/up rows.
      - Shuffle every expert's weights, scales and fp32 biases with
        ``epilogue_tile_m = 128``.
      - Replace the layer parameters with the shuffled tensors. The scales
        are viewed as ``float8_e4m3fn``.
    * Triton: steps performed:
      - Cast biases to fp32.
      - Swizzle weights and scales with ``_swizzle_mxfp4``.
      - Keep the results and their ``PrecisionConfig`` on ``self``.
      - Drop ``layer.w13_weight`` / ``layer.w2_weight`` to free memory.

    Args:
        layer: MoE layer whose parameters were registered by
            ``create_weights`` and filled by the weight loader.

    Raises:
        AssertionError: On the FlashInfer path, if a loaded tensor's shape
            differs from the expected padded shape.
    """
    if self.use_marlin:
        prepare_moe_fp4_layer_for_marlin(layer)
    elif should_use_flashinfer_mxfp4():
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        num_experts = self.num_experts
        inter = self.intermediate_size
        hidden = self.hidden_size
        sf_block_size = 32  # mxfp4 block size

        def per_expert_constant(value):
            # One fp32 value per expert, allocated directly on the current
            # CUDA device (same placement as the former list -> CPU tensor
            # -> ``.cuda()`` copy, without the host round trip).
            return Parameter(torch.full((num_experts, ),
                                        value,
                                        dtype=torch.float32,
                                        device="cuda"),
                             requires_grad=False)

        layer.gemm1_alpha = per_expert_constant(1.702)
        layer.gemm1_beta = per_expert_constant(1.0)
        layer.gemm1_clamp_limit = per_expert_constant(7.0)

        # Every tensor, including w2_weight_scale's expert dimension (which
        # was previously unchecked), must match the padded layout.
        expected_shapes = {
            "w13_weight": (num_experts, 2 * inter, hidden // 2),
            "w13_weight_scale":
            (num_experts, 2 * inter, hidden // sf_block_size),
            "w2_weight": (num_experts, hidden, inter // 2),
            "w2_weight_scale": (num_experts, hidden, inter // sf_block_size),
            "w13_bias": (num_experts, 2 * inter),
            "w2_bias": (num_experts, hidden),
        }
        for name, expected in expected_shapes.items():
            actual = tuple(getattr(layer, name).shape)
            assert actual == expected, (
                f"{name} has shape {actual}, expected {expected}")

        w13_weight_scale = layer.w13_weight_scale.data
        w2_weight_scale = layer.w2_weight_scale.data
        w13_weight = layer.w13_weight.data
        w2_weight = layer.w2_weight.data
        w13_bias = layer.w13_bias.data.to(torch.float32)
        w2_bias = layer.w2_bias.data.to(torch.float32)

        # Swap w1 and w3 as the definition of swiglu is different in
        # trtllm-gen. The checkpoint is already interleaved, so no extra
        # interleaving is done: each adjacent pair of rows is swapped.
        def swap_every_two_rows(x, axis=-1):
            """Swap each adjacent pair of slices of ``x`` along ``axis``."""
            shape = x.shape
            if axis < 0:
                axis += len(shape)
            # Expose pairs as an extra size-2 dim, flip it, restore shape.
            paired_shape = list(shape)
            paired_shape[axis] = shape[axis] // 2
            paired_shape.insert(axis + 1, 2)
            return x.reshape(*paired_shape).flip(axis + 1).reshape(*shape)

        w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
        w13_weight = swap_every_two_rows(w13_weight, -2)
        w13_bias = swap_every_two_rows(w13_bias, -1)

        # Shuffle weights and scaling factors for transposed mma output.
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
        gemm1_weights_mxfp4_shuffled = []
        gemm1_scales_mxfp4_shuffled = []
        gemm2_weights_mxfp4_shuffled = []
        gemm2_scales_mxfp4_shuffled = []
        gemm1_bias_shuffled = []
        gemm2_bias_shuffled = []
        for i in range(num_experts):
            gemm1_weights_mxfp4_shuffled.append(
                shuffle_matrix_a(w13_weight[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm1_scales_mxfp4_shuffled.append(
                shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                                    epilogue_tile_m))
            # Biases are shuffled as (rows, 1) matrices.
            gemm1_bias_shuffled.append(
                shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
                                 epilogue_tile_m))

            gemm2_weights_mxfp4_shuffled.append(
                shuffle_matrix_a(w2_weight[i].view(torch.uint8),
                                 epilogue_tile_m))
            gemm2_scales_mxfp4_shuffled.append(
                shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                                    epilogue_tile_m))
            gemm2_bias_shuffled.append(
                shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
                                 epilogue_tile_m))

        w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
        w13_weight_scale = torch.stack(gemm1_scales_mxfp4_shuffled).reshape(
            num_experts, 2 * inter,
            hidden // sf_block_size).view(torch.float8_e4m3fn)

        w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
        w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
            num_experts, hidden,
            inter // sf_block_size).view(torch.float8_e4m3fn)

        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(w13_weight_scale,
                                           requires_grad=False)
        layer.w2_weight = Parameter(w2_weight, requires_grad=False)
        layer.w2_weight_scale = Parameter(w2_weight_scale,
                                          requires_grad=False)
        layer.w13_bias = Parameter(
            torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1),
            requires_grad=False)
        layer.w2_bias = Parameter(
            torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1),
            requires_grad=False)
    else:
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        w13_bias = layer.w13_bias.to(torch.float32)
        w2_bias = layer.w2_bias.to(torch.float32)

        layer.w13_bias = Parameter(w13_bias, requires_grad=False)
        layer.w2_bias = Parameter(w2_bias, requires_grad=False)

        # FIXME warp need to be adjusted based on batch size
        # only apply to batched mode
        if self.moe.use_ep:
            num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
        else:
            num_warps = 8

        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps)
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps)

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))

        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

        # need to delete the original weights to save memory on single GPU
        del layer.w13_weight
        del layer.w2_weight
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()

_should_use_flashinfer_mxfp4_bf16

_should_use_flashinfer_mxfp4_bf16()

Determine if FlashInfer MXFP4 BF16 should be used.

Source code in vllm/model_executor/layers/quantization/mxfp4.py
def _should_use_flashinfer_mxfp4_bf16():
    """Decide whether FlashInfer MXFP4 MoE should take bf16 activations.

    An explicit ``VLLM_USE_FLASHINFER_MOE_MXFP4_BF16`` setting is returned
    as-is. Otherwise this defaults to on for compute capability 10.0
    (Blackwell) when FlashInfer is available, unless
    ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8`` has been set to any value.
    """
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16

    enable_by_default = (
        current_platform.is_device_capability(100) and has_flashinfer()
        and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"))
    if not enable_by_default:
        return False

    logger.info_once(
        "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
        "For faster performance, consider setting "
        "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
        "though this may impact accuracy.")
    return True

_should_use_flashinfer_mxfp4_mxfp8

_should_use_flashinfer_mxfp4_mxfp8()

Determine if FlashInfer MXFP4 MXFP8 should be used.

Source code in vllm/model_executor/layers/quantization/mxfp4.py
def _should_use_flashinfer_mxfp4_mxfp8():
    """Report whether the FlashInfer MXFP4 x MXFP8 MoE path is requested.

    This path is purely opt-in: the value of
    ``VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8`` is returned unchanged.
    """
    requested = envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    return requested

should_use_flashinfer_mxfp4

should_use_flashinfer_mxfp4()
Source code in vllm/model_executor/layers/quantization/mxfp4.py
def should_use_flashinfer_mxfp4():
    """Return whether any FlashInfer MXFP4 MoE variant (MXFP8 or bf16
    activations) is enabled.

    The MXFP8 check runs first; the bf16 check is only evaluated when the
    MXFP8 result is falsy.
    """
    mxfp8_enabled = _should_use_flashinfer_mxfp4_mxfp8()
    if mxfp8_enabled:
        return mxfp8_enabled
    return _should_use_flashinfer_mxfp4_bf16()