Fused MoE Modular Kernel¶
Introduction¶
FusedMoEModularKernel is implemented here
Based on the format of the input activations, FusedMoE implementations are broadly classified into two types:
- Contiguous / Standard / Non-Batched, and
- Batched
Note
The terms Contiguous, Standard, and Non-Batched are used interchangeably throughout the document.
The input activation format completely depends on the All2All Dispatch being used.
- In the Contiguous variant, the All2All Dispatch returns the activations as a contiguous tensor of shape (M, K), along with TopK Ids and TopK weights of shape (M, num_topk). Look at DeepEPHTPrepareAndFinalize for an example.
- In the Batched variant, the All2All Dispatch returns the activations as a tensor of shape (num_experts, max_tokens, K). Here, the activations/tokens that subscribe to the same expert are batched together. Note that not all entries of the tensor are valid. The activations tensor is typically accompanied by an expert_num_tokens tensor of size num_experts, where expert_num_tokens[i] indicates the number of valid tokens that subscribe to the ith expert. Look at PplxPrepareAndFinalize or DeepEPLLPrepareAndFinalize for an example.
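As a concrete illustration of the two formats, here is a minimal sketch. The sizes and the use of plain torch ops are purely illustrative and are not part of any vLLM API:

```python
import torch

M, K, num_topk = 8, 16, 2       # hypothetical: tokens, hidden size, experts per token
num_experts, max_tokens = 4, M  # hypothetical: expert count, per-expert capacity

# Contiguous / Standard format: one row per token.
contig_act = torch.randn(M, K)
router_probs = torch.softmax(torch.randn(M, num_experts), dim=-1)
topk_weights, topk_ids = torch.topk(router_probs, num_topk)   # both of shape (M, num_topk)

# Batched format: tokens grouped per expert; not every slot is valid.
batched_act = torch.zeros(num_experts, max_tokens, K)
expert_num_tokens = torch.zeros(num_experts, dtype=torch.int32)
for t in range(M):
    for e in topk_ids[t].tolist():
        slot = int(expert_num_tokens[e])
        batched_act[e, slot] = contig_act[t]   # place the token in its expert's batch
        expert_num_tokens[e] += 1
# expert_num_tokens[i] now records how many rows of batched_act[i] are valid.
```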
The FusedMoE operation is generally made up of multiple operations, in both the Contiguous and Batched variants, as described in the diagrams below.
Note
The main difference, in terms of operations, between the Batched and Non-Batched cases is the Permute / Unpermute step. All other operations remain the same.
Motivation¶
As can be seen from the diagrams, there are a lot of operations and there can be a variety of implementations for each operation. The set of ways the operations can be put together to make a valid FusedMoE implementation quickly becomes intractable. The Modular Kernel framework addresses this issue by grouping the operations into logical components. This broad categorization makes the combinations manageable and prevents code duplication. It also decouples the All2All Dispatch & Combine implementations from the FusedMoE implementations, allowing them to be developed and tested independently. Furthermore, the Modular Kernel framework introduces abstract classes for the different components, thus providing a well-defined skeleton for future implementations.
The rest of the document focuses on the Contiguous / Non-Batched case. Extrapolating to the Batched case should be straightforward.
ModularKernel Components¶
FusedMoEModularKernel splits the FusedMoE operation into three parts:
- TopKWeightAndReduce
- FusedMoEPrepareAndFinalize
- FusedMoEPermuteExpertsUnpermute
TopKWeightAndReduce¶
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that FusedMoEPermuteExpertsUnpermute is responsible for the Unpermute and FusedMoEPrepareAndFinalize is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in FusedMoEPermuteExpertsUnpermute, but some implementations choose to do it in FusedMoEPrepareAndFinalize. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
Please find the implementations of TopKWeightAndReduce here.
The FusedMoEPrepareAndFinalize::finalize() method accepts a TopKWeightAndReduce argument that is invoked inside the method. The FusedMoEModularKernel acts as a bridge between the FusedMoEPermuteExpertsUnpermute and FusedMoEPrepareAndFinalize implementations to determine where the TopK Weight Application and Reduction happens.
- FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl returns TopKWeightAndReduceNoOp if the FusedMoEPermuteExpertsUnpermute implementation does the weight application and reduction itself.
- FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl returns TopKWeightAndReduceContiguous / TopKWeightAndReduceNaiveBatched / TopKWeightAndReduceDelegate if the FusedMoEPermuteExpertsUnpermute implementation needs FusedMoEPrepareAndFinalize::finalize() to do the weight application and reduction.
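For example, an experts implementation signals where the weight application and reduction should happen roughly like this. This is a sketch only; the constructor arguments of the real TopKWeightAndReduce classes are omitted and may differ:

```python
# Sketch: two hypothetical FusedMoEPermuteExpertsUnpermute implementations.
class ExpertsThatReduceInternally(FusedMoEPermuteExpertsUnpermute):
    def finalize_weight_and_reduce_impl(self):
        # apply() already multiplied by the TopK weights and reduced across TopK,
        # so FusedMoEPrepareAndFinalize::finalize() must not do it again.
        return TopKWeightAndReduceNoOp()

class ExpertsThatDeferReduction(FusedMoEPermuteExpertsUnpermute):
    def finalize_weight_and_reduce_impl(self):
        # Ask FusedMoEPrepareAndFinalize::finalize() to apply the TopK weights
        # and reduce across the TopK dimension.
        return TopKWeightAndReduceDelegate()
```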
FusedMoEPrepareAndFinalize¶
The FusedMoEPrepareAndFinalize abstract class exposes prepare and finalize functions. The prepare function is responsible for input activation Quantization and the All2All Dispatch. The finalize function is responsible for invoking the All2All Combine. Additionally, the finalize function may or may not do the TopK weight application and reduction (please refer to the TopKWeightAndReduce section).
FusedMoEPermuteExpertsUnpermute¶
The FusedMoEPermuteExpertsUnpermute class is where the crux of the MoE operations happens. The abstract class exposes a few important functions:
- apply()
- workspace_shapes()
- finalize_weight_and_reduce_impl()
apply()¶
The apply method is where the implementation performs the following operations:
- Permute
- Matmul with weight W1
- Act + Mul
- Quantization
- Matmul with weight W2
- Unpermute
- Maybe TopK Weight Application + Reduction
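Conceptually, the per-expert portion of this sequence looks like the naive reference below. This is a sketch with an assumed fused gate/up weight layout; real implementations fuse these steps, operate on the permuted/batched layout, and handle quantization:

```python
import torch

def naive_expert_compute(a_perm: torch.Tensor,  # (tokens_for_expert, K) permuted activations
                         w1: torch.Tensor,      # (2 * N, K) fused gate/up projection
                         w2: torch.Tensor):     # (K, N) down projection
    gate_up = a_perm @ w1.t()                      # Matmul with W1 -> (tokens, 2 * N)
    gate, up = gate_up.chunk(2, dim=-1)
    hidden = torch.nn.functional.silu(gate) * up   # Act + Mul -> (tokens, N)
    # A quantized implementation would quantize `hidden` here before the W2 matmul.
    return hidden @ w2.t()                         # Matmul with W2 -> (tokens, K)
```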
workspace_shapes()¶
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that end, implementations are required to declare two workspace shapes, the workspace datatype, and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in FusedMoEModularKernel::forward() and is passed on to the FusedMoEPermuteExpertsUnpermute::apply() method. The workspaces can then be used as intermediate buffers in the FusedMoE implementation.
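As an illustration, a hypothetical implementation might declare its workspaces as below. This is not the exact vLLM signature; N is assumed here to denote the per-expert intermediate size:

```python
import torch

# Hypothetical sketch of a workspace_shapes() implementation.
def workspace_shapes(self, M, N, K, topk, num_experts):
    workspace13 = (M * topk, max(2 * N, K))  # reused for both the W1 and W2 matmul outputs
    workspace2 = (M * topk, N)               # holds the "Act + Mul" output
    output = (M, K)                          # final FusedMoE output shape
    return workspace13, workspace2, output, torch.bfloat16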
finalize_weight_and_reduce_impl()¶
It is sometimes efficient to perform TopK weight application and Reduction inside the FusedMoEPermuteExpertsUnpermute::apply()
. Find an example here. We have a TopKWeightAndReduce
abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section. FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()
returns the TopKWeightAndReduce
object that the implementation wants the FusedMoEPrepareAndFinalize::finalize()
to use.
FusedMoEModularKernel¶
FusedMoEModularKernel
is composed of the FusedMoEPrepareAndFinalize
and FusedMoEPermuteExpertsUnpermute
objects. A pseudocode sketch of FusedMoEModularKernel:
class FusedMoEModularKernel:
    def __init__(self,
                 prepare_finalize: FusedMoEPrepareAndFinalize,
                 fused_experts: FusedMoEPermuteExpertsUnpermute):
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts

    def forward(self, DP_A):
        Aq, A_scale, _, _, _ = self.prepare_finalize.prepare(DP_A, ...)
        workspace13_shape, workspace2_shape, _, _ = self.fused_experts.workspace_shapes(...)

        # allocate workspaces
        workspace13 = torch.empty(workspace13_shape, ...)
        workspace2 = torch.empty(workspace2_shape, ...)

        # execute fused_experts
        fe_out = self.fused_experts.apply(Aq, A_scale, workspace13, workspace2, ...)

        # war_impl is an object of type TopKWeightAndReduceNoOp if the fused_experts
        # implementation performs the TopK Weight Application and Reduction itself.
        war_impl = self.fused_experts.finalize_weight_and_reduce_impl()

        output = self.prepare_finalize.finalize(fe_out, war_impl, ...)
        return output
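Wiring the two components together then looks roughly like this (constructor arguments of the concrete classes are elided; the class names are taken from the tables at the end of this document):

```python
# Sketch: build a modular kernel from any compatible pair of components.
prepare_finalize = MoEPrepareAndFinalizeNoEP(...)   # a FusedMoEPrepareAndFinalize
fused_experts = TritonExperts(...)                  # a FusedMoEPermuteExpertsUnpermute
mk = FusedMoEModularKernel(prepare_finalize, fused_experts)
output = mk.forward(DP_A)                           # DP_A: this rank's input activations
```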
How-To¶
How To Add a FusedMoEPrepareAndFinalize Type¶
Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
- PplxPrepareAndFinalize type is backed by Pplx All2All kernels,
- DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and
- DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels.
Step 1: Add an All2All manager¶
The purpose of the All2All Manager is to set up the All2All kernel implementations. The FusedMoEPrepareAndFinalize implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations here.
Step 2: Add a FusedMoEPrepareAndFinalize Type¶
This section describes the significance of the various functions exposed by the FusedMoEPrepareAndFinalize
abstract class.
- FusedMoEPrepareAndFinalize::prepare(): The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
- FusedMoEPrepareAndFinalize::finalize(): Maybe perform TopK Weight Application and Reduction, then the All2All Combine. Typically the Combine function from the relevant All2All Manager is invoked.
- FusedMoEPrepareAndFinalize::activation_format(): Return FusedMoEActivationFormat.BatchedExperts if the output of the prepare method (i.e. the All2All Dispatch) is Batched. Return FusedMoEActivationFormat.Standard otherwise.
- FusedMoEPrepareAndFinalize::topk_indices_dtype(): Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the FusedMoE::select_experts function so it can be respected. If there are no strict requirements, return None.
- FusedMoEPrepareAndFinalize::max_num_tokens_per_rank(): The maximum number of tokens that would be submitted to the All2All Dispatch at once.
- FusedMoEPrepareAndFinalize::num_dispatchers(): Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K), where max_num_tokens = num_dispatchers() * max_num_tokens_per_rank().
We suggest picking an already existing FusedMoEPrepareAndFinalize
implementation that matches your All2All implementation closely and using it as a reference.
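As a rough starting point, a new implementation might look like the sketch below. The method names follow the abstract class described above, but the argument lists, the all2all handle, and the my_quantize helper are placeholders rather than real vLLM APIs:

```python
import torch

class MyPrepareAndFinalize(FusedMoEPrepareAndFinalize):  # base class from vLLM's modular kernel module
    def __init__(self, all2all_handle, max_tokens_per_rank, num_dispatchers):
        self.handle = all2all_handle          # "handle" fetched from the All2All Manager
        self._max_tokens = max_tokens_per_rank
        self._num_dispatchers = num_dispatchers

    def prepare(self, a, topk_ids, topk_weights, *args, **kwargs):
        aq, a_scale = my_quantize(a)          # placeholder quantization helper
        # Typically the Dispatch function of the All2All kernel is invoked here.
        return self.handle.dispatch(aq, a_scale, topk_ids, topk_weights)

    def finalize(self, fused_out, topk_weights, weight_and_reduce, *args, **kwargs):
        # Optionally apply the TopK weights and reduce here, as directed by the
        # TopKWeightAndReduce object handed over by the modular kernel, then Combine.
        return self.handle.combine(fused_out)

    def activation_format(self):
        return FusedMoEActivationFormat.BatchedExperts   # dispatch output is Batched

    def topk_indices_dtype(self):
        return torch.int32                    # or None if there is no strict requirement

    def max_num_tokens_per_rank(self):
        return self._max_tokens

    def num_dispatchers(self):
        return self._num_dispatchers
```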
How To Add a FusedMoEPermuteExpertsUnpermute Type¶
FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance are as follows:
- FusedMoEPermuteExpertsUnpermute::activation_formats(): Return the supported input and output activation formats, i.e. Contiguous / Batched format.
- FusedMoEPermuteExpertsUnpermute::supports_chunking(): Return True if the implementation supports chunking. Typically implementations that take FusedMoEActivationFormat.Standard inputs support chunking, while FusedMoEActivationFormat.BatchedExperts implementations do not.
- FusedMoEPermuteExpertsUnpermute::supports_expert_map(): Return True if the implementation supports an expert map.
- FusedMoEPermuteExpertsUnpermute::workspace_shapes() / FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl() / FusedMoEPermuteExpertsUnpermute::apply(): Refer to the FusedMoEPermuteExpertsUnpermute section above.
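A skeleton for a new experts implementation might look like the following sketch (method names follow the abstract class; argument lists are simplified and not the exact vLLM signatures):

```python
import torch

class MyExperts(FusedMoEPermuteExpertsUnpermute):  # base class from vLLM's modular kernel module
    def activation_formats(self):
        # (input format, output format)
        return (FusedMoEActivationFormat.Standard, FusedMoEActivationFormat.Standard)

    def supports_chunking(self):
        return True          # Standard-format implementations typically support chunking

    def supports_expert_map(self):
        return True

    def workspace_shapes(self, M, N, K, topk, num_experts):
        # Two workspace shapes, the output shape, and the workspace dtype.
        return (M * topk, max(2 * N, K)), (M * topk, N), (M, K), torch.bfloat16

    def finalize_weight_and_reduce_impl(self):
        # Defer TopK weight application and reduction to finalize().
        return TopKWeightAndReduceDelegate()

    def apply(self, aq, a_scale, workspace13, workspace2, *args, **kwargs):
        # Permute -> W1 matmul -> Act + Mul -> (quantize) -> W2 matmul -> Unpermute
        ...
```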
FusedMoEModularKernel Initialization¶
The FusedMoEMethodBase class has three methods that are collectively responsible for creating the FusedMoEModularKernel object. They are:
- maybe_make_prepare_finalize,
- select_gemm_impl, and
- init_prepare_finalize
maybe_make_prepare_finalize¶
The maybe_make_prepare_finalize method is responsible for constructing an instance of FusedMoEPrepareAndFinalize when appropriate, based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the FusedMoEPrepareAndFinalize objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. ModelOptNvFp4FusedMoE can construct a FlashInferCutlassMoEPrepareAndFinalize for the EP+TP case. Please refer to the implementation in ModelOptNvFp4FusedMoE.
select_gemm_impl¶
The select_gemm_impl method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate FusedMoEPermuteExpertsUnpermute object. Please refer to the implementations in the following derived classes:
- UnquantizedFusedMoEMethod
- CompressedTensorsW8A8Fp8MoEMethod
- CompressedTensorsW8A8Fp8MoECutlassMethod
- Fp8MoEMethod
- ModelOptNvFp4FusedMoE
init_prepare_finalize¶
Based on the input and env settings, the init_prepare_finalize method creates the appropriate FusedMoEPrepareAndFinalize object. The method then queries select_gemm_impl for the appropriate FusedMoEPermuteExpertsUnpermute object and builds the FusedMoEModularKernel object.
Please take a look at init_prepare_finalize.
Important: The FusedMoEMethodBase derived classes use the FusedMoEMethodBase::fused_experts object in their apply methods. When settings permit the construction of a valid FusedMoEModularKernel object, we override FusedMoEMethodBase::fused_experts with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used.
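The override pattern described above looks roughly like this. This is a simplified sketch, not the actual vLLM code; attribute names such as w13_weight / w2_weight are illustrative:

```python
# Sketch of how a FusedMoEMethodBase derived class ends up using the modular kernel.
class MyMoEMethod(FusedMoEMethodBase):
    def init_prepare_finalize(self, *args, **kwargs):
        prepare_finalize = self.maybe_make_prepare_finalize(...)
        if prepare_finalize is not None:
            experts = self.select_gemm_impl(prepare_finalize, ...)
            # Override fused_experts so apply() transparently routes through the
            # modular kernel when the settings allow it.
            self.fused_experts = FusedMoEModularKernel(prepare_finalize, experts)

    def apply(self, layer, x, router_logits, *args, **kwargs):
        # The same call works whether fused_experts is the default implementation
        # or a FusedMoEModularKernel.
        return self.fused_experts(x, layer.w13_weight, layer.w2_weight, ...)
```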
How To Unit Test¶
We have FusedMoEModularKernel unit tests at test_modular_kernel_combinations.py.
The unit test iterates through all combinations of FusedMoEPrepareAndFinalize and FusedMoEPermuteExpertsUnpermute types and, if they are compatible, runs some correctness tests. If you are adding a new FusedMoEPrepareAndFinalize / FusedMoEPermuteExpertsUnpermute implementation:
- Add the implementation type to MK_ALL_PREPARE_FINALIZE_TYPES and MK_FUSED_EXPERT_TYPES in mk_objects.py, respectively.
- Update the Config::is_batched_prepare_finalize(), Config::is_batched_fused_experts(), Config::is_standard_fused_experts(), Config::is_fe_16bit_supported(), Config::is_fe_fp8_supported(), Config::is_fe_block_fp8_supported(), and Config::is_fe_supports_chunking() methods in /tests/kernels/moe/modular_kernel_tools/common.py.
Doing this will add the new implementation to the test suite.
How To Check FusedMoEPrepareAndFinalize & FusedMoEPermuteExpertsUnpermute Compatibility¶
The unit test file test_modular_kernel_combinations.py can also be executed as a standalone script. Example: python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts
As a side effect, this script can be used to test FusedMoEPrepareAndFinalize
& FusedMoEPermuteExpertsUnpermute
compatibility. When invoked with incompatible types, the script will error.
How To Profile¶
Please take a look at profile_modular_kernel.py. The script can be used to generate Torch traces for a single FusedMoEModularKernel::forward() call for any compatible FusedMoEPrepareAndFinalize and FusedMoEPermuteExpertsUnpermute types. Example: python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts
FusedMoEPrepareAndFinalize Implementations¶
The following table lists the FusedMoEPrepareAndFinalize
implementations at the time of writing:
| Implementation | Type | Comments |
|---|---|---|
| DeepEPHTPrepareAndFinalize | Contiguous / Non-Batched | Uses the DeepEP High-Throughput all2all kernels. |
| DeepEPLLPrepareAndFinalize | Batched | Uses the DeepEP Low-Latency all2all kernels. |
| PplxPrepareAndFinalize | Batched | Uses the Perplexity all2all kernels. |
| FlashInferCutlassMoEPrepareAndFinalize | Contiguous | |
| MoEPrepareAndFinalizeNoEP | Contiguous | This implementation is used when there is no EP, i.e. no all2all kernels are invoked. |
| BatchedPrepareAndFinalize | Batched | A reference prepare/finalize class that reorganizes the tokens into expert-batched format, i.e. E x max_num_tokens x K. (Doesn't use any all2all kernels. This is primarily used in unit testing.) |
FusedMoEPermuteExpertsUnpermute Implementations¶
The following table lists the FusedMoEPermuteExpertsUnpermute
implementations at the time of writing:
| Implementation | Type | Comment |
|---|---|---|
| BatchedDeepGemmExperts | Batched | Uses DeepGemm's Masked Grouped Gemm kernels for the fused_moe operation. |
| BatchedTritonExperts | Batched | Uses a Triton Kernel for the Batched matmuls. |
| BatchedTritonOrDeepGemmExperts | Batched | Chooses either BatchedDeepGemmExperts or BatchedTritonExperts based on environment settings. |
| DeepGemmExperts | Contiguous / Non-Batched | Uses DeepGemm's Grouped Gemm kernels for the fused_moe operation. |
| TritonExperts | Contiguous / Non-Batched | Uses a Triton Kernel for the fused_moe matmuls. |
| TritonOrDeepGemmExperts | Contiguous / Non-Batched | Chooses either DeepGemmExperts or TritonExperts based on the fused_moe inputs. |
| CutlassExpertsFP8 | Supports both Batched and Contiguous formats | Uses Cutlass Grouped Gemm implementations for the fp8 matmuls. |
| CutlassExpertsFP4 | Supports both Batched and Contiguous formats | Uses Cutlass Grouped Gemm implementations for the fp4 matmuls. |
| FlashInferExperts | Contiguous | Uses the fused_moe operation from FlashInfer. |
| NaiveBatchedExperts | Batched | Reference Batched Experts implementation. Primarily used in unit tests. |