Skip to content

vllm.model_executor.layers.fused_moe.moe_permute_unpermute

_moe_permute

_moe_permute(
    curr_hidden_states: Tensor,
    a1q_scale: Optional[Tensor],
    curr_topk_ids: Tensor,
    global_num_experts: int,
    expert_map: Optional[Tensor],
    block_m: int,
) -> tuple[
    Tensor, Optional[Tensor], Tensor, Tensor, Tensor
]

Determine the `sorted_token_ids` and `expert_ids` for the given problem size, then permute the hidden states and scales according to `sorted_token_ids`.

Source code in vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           torch.Tensor]:
    """
    Determine the sorted_token_ids, expert_ids for the given problem size.
    Permute the hidden states and scales according to `sorted_token_ids`.

    Args:
        curr_hidden_states: Hidden states of the current chunk; dim 0 is the
            token dimension.
        a1q_scale: Optional quantization scale, gathered along dim 0 by
            source token together with the hidden states.
        curr_topk_ids: Top-k expert ids; dim 1 is the top-k dimension.
        global_num_experts: Total number of experts, forwarded to
            ``moe_align_block_size``.
        expert_map: Optional expert map, forwarded to
            ``moe_align_block_size``.
        block_m: Alignment block size; each entry of the aligned
            ``expert_ids`` is repeated ``block_m`` times.

    Returns:
        ``(hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)``:
        hidden states / scales gathered in sorted order, the sorted ids with
        values above ``num_tokens - 1`` clamped, one expert id per sorted
        row, and the first ``num_tokens`` entries of the argsort of the
        unclamped sorted ids.
    """
    top_k_num = curr_topk_ids.size(1)
    tokens_in_chunk = curr_hidden_states.size(0)
    # Number of (token, top-k slot) pairs.
    num_tokens = top_k_num * tokens_in_chunk

    sorted_token_ids, expert_ids, _num_tokens_post_padded = (
        moe_align_block_size(curr_topk_ids,
                             block_m,
                             global_num_experts,
                             expert_map,
                             pad_sorted_ids=True))

    # Expand from one expert id per block of block_m rows to one per row.
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
    # Computed on the unclamped ids; presumably padding ids are >= num_tokens
    # so they sort past the slice and are dropped.
    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]

    # Permute according to sorted token ids. Clamp out-of-range ids so they
    # are valid gather indices.
    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)

    # Each sorted id encodes a (token, slot) pair; dividing by top_k recovers
    # the source token row. Computed once and shared by both gathers.
    src_rows = sorted_token_ids // top_k_num

    curr_hidden_states = _fp8_perm(curr_hidden_states, src_rows)

    if a1q_scale is not None:
        a1q_scale = a1q_scale[src_rows]

    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
            inv_perm)

_moe_unpermute_and_reduce

_moe_unpermute_and_reduce(
    out: Tensor,
    curr_hidden: Tensor,
    inv_perm: Optional[Tensor],
    topk_weight: Tensor,
    apply_router_weight_on_input: bool,
) -> None

Unpermute the final result and apply topk_weights, then perform the final reduction on the hidden states.

Source code in vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: Optional[torch.Tensor],
    topk_weight: torch.Tensor,
    apply_router_weight_on_input: bool,
) -> None:
    """
    Restore token order, optionally scale by the router weights, and reduce
    over the top-k dimension into ``out`` via ``ops.moe_sum``.

    Args:
        out: Destination tensor handed to ``ops.moe_sum``.
        curr_hidden: Expert outputs in permuted order; the last dimension is
            the hidden size.
        inv_perm: Optional row permutation applied along dim 0. When None the
            rows are used in their current order.
        topk_weight: Router weights, a 2-D tensor of shape ``(M, topk)``.
        apply_router_weight_on_input: When True the weights are not
            multiplied in here (presumably they were applied upstream).

    Note:
        When ``inv_perm`` is None, ``curr_hidden`` is only reshaped with
        ``view`` before ``mul_``, so the caller's tensor is scaled in place.
    """
    num_rows, topk = topk_weight.size()
    hidden_size = curr_hidden.size(-1)

    rows = curr_hidden if inv_perm is None else curr_hidden[inv_perm, ...]
    rows = rows.view(-1, topk, hidden_size)

    if not apply_router_weight_on_input:
        # Broadcast the (M, topk) weights across the hidden dimension.
        rows.mul_(topk_weight.view(num_rows, -1, 1))

    ops.moe_sum(rows, out)

moe_permute

moe_permute(
    hidden_states: Tensor,
    a1q_scale: Optional[Tensor],
    topk_ids: Tensor,
    n_expert: int,
    n_local_expert: int = -1,
    expert_map: Optional[Tensor] = None,
    align_block_size: Optional[int] = None,
    fill_invalid_expert: int = -1,
    permuted_hidden_states: Optional[Tensor] = None,
) -> tuple[
    Tensor, Optional[Tensor], Tensor, Tensor, Tensor
]

This function expands and permutes the activations to gather non-contiguous tokens for each expert.

Parameters:

- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- a1q_scale (Optional[torch.Tensor]): Quantization scale for hidden_states.
- topk_ids (torch.Tensor): Top-k expert routing ids for each token.
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts on the current EP rank (-1 means use n_expert).
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert-parallel shard.
- align_block_size (Optional[int]): Grouped-GEMM block size to align to for DeepGEMM.
- fill_invalid_expert (int): Expert id written into m_indices for invalid experts, to work around DeepGEMM not supporting -1 in m_indices.
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. If None, the output tensor is created in this function.

Returns:

- permuted_hidden_states (torch.Tensor): The permuted activations.
- a1q_scale (Optional[torch.Tensor]): The permuted quantization scale for hidden_states if the original scale is not per-tensor; otherwise it is returned unchanged.
- expert_first_token_offset (torch.Tensor): Offset of the first token of each expert for standard grouped GEMM. If align_block_size is set, expert_first_token_offset is aligned up to align_block_size.
- inv_permuted_idx (torch.Tensor): Index map for moe_unpermute.
- m_indices: m_indices for grouped GEMM in DeepGEMM; `m_indices[i]` records the group that the i-th row of the LHS belongs to.

The index map from hidden to permuted_hidden (permuted_idx) is computed internally, but it is not returned.

Source code in vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
def moe_permute(
    hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    topk_ids: torch.Tensor,
    n_expert: int,
    n_local_expert: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    align_block_size: Optional[int] = None,
    fill_invalid_expert: int = -1,
    permuted_hidden_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           torch.Tensor]:
    """
    This function expands and permutes activation to gather non-contiguous
      tokens for each expert.
    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer, shape
        (n_token, n_hidden).
    - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
    - topk_ids (torch.Tensor): topk expert route id for each token, shape
        (n_token, topk).
    - n_expert (int): The number of experts.
    - n_local_expert (int): The number of experts in the current EP rank;
        -1 means use `n_expert`.
    - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - align_block_size (Optional[int]): align group gemm block size for deepgemm
    - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
      to workaround DeepGemm unsupported -1 in m_indices
    - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
        If None, the output tensor will be created in this function.
    Returns (a 5-tuple, in this order):
    - permuted_hidden_states (torch.Tensor): permuted activation.
    - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
        if the original scale is not per-tensor (has more than one dim);
        otherwise returned unchanged.
    - expert_first_token_offset (torch.Tensor): int64 tensor with
       n_local_expert + 1 entries; offset of the first token of each expert
       for standard grouped gemm. If 'align_block_size' is set,
       expert_first_token_offset will align up to 'align_block_size'.
    - inv_permuted_idx (torch.Tensor): flattened idx map for moe_unpermute.
    - m_indices (torch.Tensor): m_indices for grouped gemm in deepgemm;
        `m_indices[i]` records the group which the i-th row of the LHS
        belongs to. Entries the kernel does not write keep
        `fill_invalid_expert`.
    Note: the internal `permuted_idx` map (permuted row -> expanded
    token-major row id) is only used to permute `a1q_scale`; it is NOT
    returned.
    """
    n_token, n_hidden = hidden_states.size()
    topk = topk_ids.size(1)
    assert (n_hidden * hidden_states.element_size()
            ) % 16 == 0, "permue kernel need hidden dim align to 16B"
    device = hidden_states.device
    # Number of expanded (token, top-k slot) rows before alignment padding.
    n_expanded = n_token * topk

    permuted_row_size = n_expanded
    if align_block_size is not None:
        # Reserve up to (align_block_size - 1) padding rows per expert, then
        # round the total up to a multiple of align_block_size.
        permuted_row_size = (permuted_row_size + n_expert *
                             (align_block_size - 1) + align_block_size -
                             1) // align_block_size * align_block_size
    if n_local_expert == -1:
        n_local_expert = n_expert
    if permuted_hidden_states is None:
        permuted_hidden_states = torch.empty(
            (permuted_row_size, n_hidden),
            dtype=hidden_states.dtype,
            device=device,
        )
    assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
        f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
        f" but got {permuted_hidden_states.size()}")

    # Token-major expanded row ids: entry [t, k] holds t * topk + k.
    token_expert_indices = torch.arange(0,
                                        n_expanded,
                                        dtype=torch.int32,
                                        device=device).reshape(
                                            (n_token, topk))

    # Output buffers populated by the custom kernel below.
    m_indices = torch.full((permuted_row_size, ),
                           fill_invalid_expert,
                           dtype=torch.int32,
                           device=device)
    expert_first_token_offset = torch.empty(n_local_expert + 1,
                                            dtype=torch.int64,
                                            device=device)
    # Pre-filled with n_expanded (one past the last valid row id); rows the
    # kernel leaves untouched (presumably padding) keep this sentinel.
    permuted_idx = torch.full((permuted_row_size, ),
                              n_expanded,
                              dtype=torch.int32,
                              device=device)
    inv_permuted_idx = torch.empty((n_token, topk),
                                   dtype=torch.int32,
                                   device=device)
    topk_ids = topk_ids.to(torch.int32)
    torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices,
                                 expert_map, n_expert, n_local_expert, topk,
                                 align_block_size, permuted_hidden_states,
                                 expert_first_token_offset, inv_permuted_idx,
                                 permuted_idx, m_indices)

    # Non-per-tensor scales must follow their rows. The clamp turns the
    # sentinel into a valid index, and `// topk` maps an expanded row id back
    # to its source token.
    if a1q_scale is not None and a1q_scale.dim() > 1:
        a1q_scale = a1q_scale[permuted_idx.clamp(max=n_expanded - 1) // topk]
    return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
            inv_permuted_idx.flatten(), m_indices)

moe_permute_unpermute_supported

moe_permute_unpermute_supported()
Source code in vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
def moe_permute_unpermute_supported():
    """
    Query the compiled `_moe_C` extension for permute/unpermute support.

    Returns exactly what `torch.ops._moe_C.moe_permute_unpermute_supported()`
    returns (presumably a bool telling callers whether `moe_permute` /
    `moe_unpermute` can be used).
    """
    moe_ops = torch.ops._moe_C
    return moe_ops.moe_permute_unpermute_supported()

moe_unpermute

moe_unpermute(
    out: Tensor,
    permuted_hidden_states: Tensor,
    topk_weights: Tensor,
    inv_permuted_idx: Tensor,
    expert_first_token_offset: Optional[Tensor] = None,
) -> None

This function unpermutes the permuted activations produced by moe_permute and reduces them back to one row per token, writing the result into `out`.

Parameters:

- out (torch.Tensor): Output tensor; the result is written into it.
- permuted_hidden_states (torch.Tensor): The permuted activations.
- topk_weights (torch.Tensor): Top-k expert routing weight for each token.
- inv_permuted_idx (torch.Tensor): Row index map for moe_unpermute.
- expert_first_token_offset (Optional[torch.Tensor]): Offset of the first token of each expert for grouped GEMM.

Returns:

- None. The reduced and unpermuted activation tensor is written into `out`.

Source code in vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
def moe_unpermute(
    out: torch.Tensor,
    permuted_hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    inv_permuted_idx: torch.Tensor,
    expert_first_token_offset: Optional[torch.Tensor] = None,
) -> None:
    """
    Unpermute the expert-ordered activations produced by `moe_permute` and
      reduce them back to one row per token, writing the result into `out`.
    Parameters:
    - out (torch.Tensor): output tensor; the kernel writes the result here.
    - permuted_hidden_states (torch.Tensor): permuted activation; its last
      dimension is the hidden size.
    - topk_weights (torch.Tensor): topk expert route weight for each token;
      dim 1 is the top-k dimension.
    - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute, as
      returned by `moe_permute`.
    - expert_first_token_offset (Optional[torch.Tensor]): offset of the first
      token of each expert for grouped gemm.
    Returns:
    - None. The reduced and unpermuted activation tensor is written into
      `out`.
    """
    topk = topk_weights.size(1)
    n_hidden = permuted_hidden_states.size(-1)
    # The kernel requires each hidden row to be a multiple of 16 bytes.
    assert (n_hidden * permuted_hidden_states.element_size()
            ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"

    torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
                                   inv_permuted_idx, expert_first_token_offset,
                                   topk, out)