vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe

logger module-attribute

logger = init_logger(__name__)

FlashInferExperts

Bases: FusedMoEPermuteExpertsUnpermute

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        g1_alphas: torch.Tensor,
        g2_alphas: torch.Tensor,
        a1_gscale: torch.Tensor,
        a2_gscale: torch.Tensor,
        out_dtype: torch.dtype,
        quant_dtype: Union[torch.dtype, str, None],
        ep_rank: int = 0,
        ep_size: int = 1,
        tp_rank: int = 0,
        tp_size: int = 1,
    ):
        super().__init__(
            FusedMoEQuantConfig(
                quant_dtype=quant_dtype,
                per_act_token_quant=False,
                block_shape=None,
            ))
        assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
            "Only nvfp4,fp8 quantization are currently supported.")
        self.ep_rank = ep_rank
        self.ep_size = ep_size
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.g1_alphas = g1_alphas
        self.g2_alphas = g2_alphas
        self.a1_gscale = a1_gscale
        self.a2_gscale = a2_gscale
        self.out_dtype = out_dtype

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_expert_map(self) -> bool:
        return False

    def supports_chunking(self) -> bool:
        # This refers to TP chunking; DP chunking is handled separately.
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function.  Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        aq_m, aq_n = aq.shape
        workspace2 = ()
        output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
            torch.float8_e4m3fn else (aq_m, aq_n)
        workspace_dtype = a.dtype
        workspace1 = output_shape
        # The workspace is determined by `aq`, since it comes after any
        # potential communication op and is involved in the expert computation.
        return (workspace1, workspace2, output_shape, workspace_dtype)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],  # Not used
        workspace13: Optional[torch.Tensor],
        workspace2: Optional[torch.Tensor],
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: Optional[bool],
    ):
        if self.quant_dtype == torch.float8_e4m3fn:
            quant_scales = [
                self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
            ]

            a1q_scale = None  # not passing input_sf in fp8
            fc1_expert_weights = w1
            fc2_expert_weights = w2
        else:
            # Ensure w1_scale and w2_scale are not None before calling view
            assert w1_scale is not None and w2_scale is not None, (
                "w1_scale and w2_scale must not "
                "be None for FlashInferExperts")
            # Flashinfer CUTLASS kernel takes scalar global scales,
            # min because inv_scale.
            quant_scales = [
                self.a1_gscale,
                w1_scale.view(torch.int32),
                self.g1_alphas,
                self.a2_gscale,
                w2_scale.view(torch.int32),
                self.g2_alphas,
            ]
            # FlashInfer API requires weight to be long for nvfp4
            fc1_expert_weights = w1.view(torch.long)
            fc2_expert_weights = w2.view(torch.long)

        _ = flashinfer_cutlass_fused_moe(
            input=hidden_states,
            token_selected_experts=topk_ids.to(torch.int),
            token_final_scales=topk_weights,
            fc1_expert_weights=fc1_expert_weights,
            fc2_expert_weights=fc2_expert_weights,
            output_dtype=self.out_dtype,
            quant_scales=quant_scales,
            input_sf=a1q_scale,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            output=output,
        )

a1_gscale instance-attribute

a1_gscale = a1_gscale

a2_gscale instance-attribute

a2_gscale = a2_gscale

activation_formats property

ep_rank instance-attribute

ep_rank = ep_rank

ep_size instance-attribute

ep_size = ep_size

g1_alphas instance-attribute

g1_alphas = g1_alphas

g2_alphas instance-attribute

g2_alphas = g2_alphas

out_dtype instance-attribute

out_dtype = out_dtype

tp_rank instance-attribute

tp_rank = tp_rank

tp_size instance-attribute

tp_size = tp_size

__init__

__init__(
    g1_alphas: Tensor,
    g2_alphas: Tensor,
    a1_gscale: Tensor,
    a2_gscale: Tensor,
    out_dtype: dtype,
    quant_dtype: Union[dtype, str, None],
    ep_rank: int = 0,
    ep_size: int = 1,
    tp_rank: int = 0,
    tp_size: int = 1,
)
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def __init__(
    self,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    out_dtype: torch.dtype,
    quant_dtype: Union[torch.dtype, str, None],
    ep_rank: int = 0,
    ep_size: int = 1,
    tp_rank: int = 0,
    tp_size: int = 1,
):
    super().__init__(
        FusedMoEQuantConfig(
            quant_dtype=quant_dtype,
            per_act_token_quant=False,
            block_shape=None,
        ))
    assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
        "Only nvfp4,fp8 quantization are currently supported.")
    self.ep_rank = ep_rank
    self.ep_size = ep_size
    self.tp_rank = tp_rank
    self.tp_size = tp_size
    self.g1_alphas = g1_alphas
    self.g2_alphas = g2_alphas
    self.a1_gscale = a1_gscale
    self.a2_gscale = a2_gscale
    self.out_dtype = out_dtype

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[Tensor],
    w1_scale: Optional[Tensor],
    w2_scale: Optional[Tensor],
    w1_zp: Optional[Tensor],
    w2_zp: Optional[Tensor],
    a1q_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    workspace13: Optional[Tensor],
    workspace2: Optional[Tensor],
    expert_tokens_meta: Optional[ExpertTokensMetadata],
    apply_router_weight_on_input: Optional[bool],
)
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    w1_zp: Optional[torch.Tensor],
    w2_zp: Optional[torch.Tensor],
    a1q_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],  # Not used
    workspace13: Optional[torch.Tensor],
    workspace2: Optional[torch.Tensor],
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    apply_router_weight_on_input: Optional[bool],
):
    if self.quant_dtype == torch.float8_e4m3fn:
        quant_scales = [
            self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
        ]

        a1q_scale = None  # not passing input_sf in fp8
        fc1_expert_weights = w1
        fc2_expert_weights = w2
    else:
        # Ensure w1_scale and w2_scale are not None before calling view
        assert w1_scale is not None and w2_scale is not None, (
            "w1_scale and w2_scale must not "
            "be None for FlashInferExperts")
        # Flashinfer CUTLASS kernel takes scalar global scales,
        # min because inv_scale.
        quant_scales = [
            self.a1_gscale,
            w1_scale.view(torch.int32),
            self.g1_alphas,
            self.a2_gscale,
            w2_scale.view(torch.int32),
            self.g2_alphas,
        ]
        # FlashInfer API requires weight to be long for nvfp4
        fc1_expert_weights = w1.view(torch.long)
        fc2_expert_weights = w2.view(torch.long)

    _ = flashinfer_cutlass_fused_moe(
        input=hidden_states,
        token_selected_experts=topk_ids.to(torch.int),
        token_final_scales=topk_weights,
        fc1_expert_weights=fc1_expert_weights,
        fc2_expert_weights=fc2_expert_weights,
        output_dtype=self.out_dtype,
        quant_scales=quant_scales,
        input_sf=a1q_scale,
        tp_size=self.tp_size,
        tp_rank=self.tp_rank,
        ep_size=self.ep_size,
        ep_rank=self.ep_rank,
        output=output,
    )
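For quick reference, the quant_scales layout consumed by the call above differs between the two supported modes. The following is a minimal sketch derived directly from the two branches in apply(); it is not a separate API, and the argument names simply mirror the constructor attributes:

import torch

def build_quant_scales(quant_dtype, g1_alphas, g2_alphas, a1_gscale,
                       a2_gscale, w1_scale=None, w2_scale=None):
    if quant_dtype == torch.float8_e4m3fn:
        # fp8: four scalar global scales; no input_sf is passed to the kernel.
        return [g1_alphas, a2_gscale, g2_alphas, a1_gscale]
    # nvfp4: global scales interleaved with per-block weight scales, which
    # the kernel expects reinterpreted as int32 (the weights themselves are
    # viewed as long before the call).
    assert w1_scale is not None and w2_scale is not None
    return [
        a1_gscale, w1_scale.view(torch.int32), g1_alphas,
        a2_gscale, w2_scale.view(torch.int32), g2_alphas,
    ]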

finalize_weight_and_reduce_impl

finalize_weight_and_reduce_impl() -> TopKWeightAndReduce
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    return TopKWeightAndReduceNoOP()

supports_chunking

supports_chunking() -> bool
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def supports_chunking(self) -> bool:
    # This refers to TP chunking; DP chunking is handled separately.
    return True

supports_expert_map

supports_expert_map() -> bool
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def supports_expert_map(self) -> bool:
    return False

workspace_shapes

workspace_shapes(
    a: Tensor,
    aq: Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[
    tuple[int, ...], tuple[int, ...], tuple[int, ...], dtype
]

Compute the shapes for the temporary and final outputs of the two gemms and activation in the fused expert function. Since the gemms are independent, the workspace for the first gemm can be shared with the workspace for the last gemm.

Returns a tuple of:

- workspace13 shape tuple: must be large enough to hold the result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens.

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def workspace_shapes(
    self,
    a: torch.Tensor,
    aq: torch.Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
    # We use global_num_experts due to how moe_align_block_size handles
    # expert_maps.
    """
    Compute the shapes for the temporary and final outputs of the two gemms
    and activation in the fused expert function.  Since the gemms are
    independent, the workspace for the first gemm can be shared with the
    workspace for the last gemm.

    Returns a tuple of:
    - workspace13 shape tuple: must be large enough to hold the
      result of either expert gemm.
    - workspace2 shape tuple: must be large enough to hold the
      result of the activation function.
    - output shape tuple: must be exact size of the final gemm output.
    - Workspace type: The dtype to use for the workspace tensors.
    - Note: in order for activation chunking to work, the first dimension
      of each tuple must be the number of tokens.
    """
    aq_m, aq_n = aq.shape
    workspace2 = ()
    output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
        torch.float8_e4m3fn else (aq_m, aq_n)
    workspace_dtype = a.dtype
    workspace1 = output_shape
    # The workspace is determined by `aq`, since it comes after any
    # potential communication op and is involved in the expert computation.
    return (workspace1, workspace2, output_shape, workspace_dtype)
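As a concrete illustration of the shape logic above: the output (and workspace1) width is doubled for nvfp4 but left unchanged for fp8. This sketch assumes nvfp4 activations are packed two 4-bit values per byte, so the unquantized width is twice the quantized width aq_n, while fp8 stores one element per byte:

import torch

def sketch_output_shape(aq: torch.Tensor, quant_dtype) -> tuple[int, int]:
    # Mirrors workspace_shapes() above: workspace2 is empty and
    # workspace1 has the same shape as the output.
    aq_m, aq_n = aq.shape
    if quant_dtype == torch.float8_e4m3fn:
        return (aq_m, aq_n)
    return (aq_m, aq_n * 2)

# Hypothetical sizes: 32 tokens, hidden size 4096.
nvfp4_aq = torch.empty(32, 2048, dtype=torch.uint8)        # 4096 fp4 values, packed
fp8_aq = torch.empty(32, 4096, dtype=torch.float8_e4m3fn)  # one fp8 value per element
assert sketch_output_shape(nvfp4_aq, "nvfp4") == (32, 4096)
assert sketch_output_shape(fp8_aq, torch.float8_e4m3fn) == (32, 4096)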

flashinfer_cutlass_moe_fp4

flashinfer_cutlass_moe_fp4(
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    w1_scale: Tensor,
    w2_scale: Tensor,
    g1_alphas: Tensor,
    g2_alphas: Tensor,
    a1_gscale: Tensor,
    a2_gscale: Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> Tensor
Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def flashinfer_cutlass_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:

    fused_experts = mk.FusedMoEModularKernel(
        FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
                                               a1_gscale=a1_gscale),
        FlashInferExperts(
            g1_alphas=g1_alphas,
            g2_alphas=g2_alphas,
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            out_dtype=hidden_states.dtype,
            quant_dtype="nvfp4",
        ))

    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
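A hedged usage sketch of this wrapper. The weights object below is hypothetical; in practice the packed nvfp4 expert weights, per-block scales, and global scales come from the quantized checkpoint, and the routing tensors come from the model's router:

import torch
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    flashinfer_cutlass_moe_fp4)

def run_fp4_moe(hidden_states, router_logits, weights, top_k: int = 2):
    # Standard top-k routing: per-token expert ids and mixing weights.
    routing_probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(routing_probs, k=top_k, dim=-1)

    return flashinfer_cutlass_moe_fp4(
        hidden_states=hidden_states,
        w1=weights.w1,                # packed nvfp4 gate/up projection (uint8)
        w2=weights.w2,                # packed nvfp4 down projection (uint8)
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        w1_scale=weights.w1_scale,    # per-block weight scales
        w2_scale=weights.w2_scale,
        g1_alphas=weights.g1_alphas,  # global dequantization alphas
        g2_alphas=weights.g2_alphas,
        a1_gscale=weights.a1_gscale,  # activation global scales
        a2_gscale=weights.a2_gscale,
        activation="silu",
    )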

is_valid_flashinfer_cutlass_fused_moe

is_valid_flashinfer_cutlass_fused_moe(
    hidden_states: Tensor, w1: Tensor, w2: Tensor
) -> bool

Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel.

Source code in vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
                                          w1: torch.Tensor,
                                          w2: torch.Tensor) -> bool:
    """
    Check if the given problem size is supported by the FlashInfer CUTLASS MoE
    kernel.
    """
    if not has_flashinfer_cutlass_fused_moe():
        logger.debug_once("FlashInferExperts disabled: "
                          "flashinfer_cutlass_fused_moe not available.")
        return False
    # Data type checks
    if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8
            or hidden_states.dtype
            not in [torch.float32, torch.float16, torch.bfloat16]):
        logger.debug_once(
            "FlashInferExperts disabled: w1/w2 must be torch.uint8 "
            f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
            f"float32, float16, or bfloat16 (got {hidden_states.dtype}).")
        return False
    return True
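A small sketch of the intended guard pattern. fallback_fused_moe is a placeholder for whichever non-FlashInfer fused-MoE path the caller would otherwise use; it is not a function in this module:

if is_valid_flashinfer_cutlass_fused_moe(hidden_states, w1, w2):
    out = flashinfer_cutlass_moe_fp4(
        hidden_states, w1, w2, topk_weights, topk_ids,
        w1_scale, w2_scale, g1_alphas, g2_alphas, a1_gscale, a2_gscale)
else:
    # Hypothetical fallback; the debug_once logs above explain why the
    # FlashInfer path was rejected.
    out = fallback_fused_moe(hidden_states, w1, w2, topk_weights, topk_ids)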