vllm.model_executor.models.minimax_text_01

Inference-only MiniMaxText01 model.

MiniMaxText01Attention

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rotary_dim: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        sliding_window: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "mha",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx

        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        self.prefix = prefix

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        return

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                **kwargs) -> torch.Tensor:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
                    positions, q, k)
        else:
            q, k = attn_metadata.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
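
A quick way to see what the constructor's tensor-parallel sharding does is to run the same arithmetic on hypothetical sizes. The sketch below is illustrative only; the concrete values (64 query heads, 8 KV heads, head_dim 128, tp_size 8) are assumptions, not taken from any particular checkpoint:

# Hypothetical sizes, used only to illustrate the per-rank sharding arithmetic.
total_num_heads, total_num_kv_heads, head_dim = 64, 8, 128
tp_size = 8  # tensor-parallel world size

assert total_num_heads % tp_size == 0
num_heads = total_num_heads // tp_size                 # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)   # 1 KV head per rank
q_size = num_heads * head_dim                          # 1024: per-rank query width
kv_size = num_kv_heads * head_dim                      # 128: per-rank key/value width
scaling = head_dim ** -0.5                             # ~0.0884 softmax scale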

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

layer_idx instance-attribute

layer_idx = layer_idx

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

prefix instance-attribute

prefix = prefix

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rope_theta instance-attribute

rope_theta = rope_theta

scaling instance-attribute

scaling = head_dim ** -0.5

sliding_window instance-attribute

sliding_window = sliding_window

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

__init__

__init__(
    hidden_size: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    rotary_dim: int,
    max_position: int = 4096 * 32,
    rope_theta: float = 10000,
    sliding_window: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "mha",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    head_dim: int,
    num_kv_heads: int,
    rotary_dim: int,
    max_position: int = 4096 * 32,
    rope_theta: float = 10000,
    sliding_window: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "mha",
) -> None:
    super().__init__()
    self.layer_idx = layer_idx

    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        assert self.total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    self.head_dim = head_dim

    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.sliding_window = sliding_window
    self.prefix = prefix

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.total_num_heads * self.head_dim,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )
    return

forward

forward(
    hidden_states: Tensor, positions: Tensor, **kwargs
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
            **kwargs) -> torch.Tensor:
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    if envs.VLLM_USE_V1:
        if attn_metadata is not None:
            q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
                positions, q, k)
    else:
        q, k = attn_metadata.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
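
forward splits the fused qkv_proj output by the per-rank widths computed in the constructor; under V1 the per-layer rotary embedding is fetched from the forward-context attention metadata keyed by f"{prefix}.attn", otherwise it is read directly off attn_metadata. A minimal sketch of just the split, with assumed per-rank sizes matching the example above:

import torch

# Toy stand-in for the fused QKV projection output on one rank (assumed sizes).
num_tokens, q_size, kv_size = 4, 1024, 128
qkv = torch.randn(num_tokens, q_size + 2 * kv_size)

q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)  # torch.Size([4, 1024]) torch.Size([4, 128]) torch.Size([4, 128])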

MiniMaxText01DecoderLayer

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01DecoderLayer(nn.Module):

    def __init__(
        self,
        config: MiniMaxConfig,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        expert_num: int = 1,
        layer_id: int = None,
        linear_layer_id: Optional[int] = None,
        prefix: str = "decoder",
    ) -> None:
        self._ilayer = layer_id
        self._irank = get_tensor_model_parallel_rank()
        self.prefix = prefix
        super().__init__()

        self.hidden_size = config.hidden_size
        self.expert_num = expert_num

        rope_theta = getattr(config, "rope_theta", 10000)

        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        if hasattr(config, "max_model_len") and isinstance(
                config.max_model_len, int):
            max_position_embeddings = min(config.max_position_embeddings,
                                          config.max_model_len)
        if config.attention_type == 0:
            use_headxdim = True
            hidden_inner = (head_dim * config.num_attention_heads
                            if use_headxdim else config.hidden_size)
            self.self_attn = MiniMaxText01LinearAttention(
                hidden_size=self.hidden_size,
                hidden_inner_size=hidden_inner,
                num_heads=config.num_attention_heads,
                head_dim=head_dim,
                max_position=max_position_embeddings,
                block_size=config.block if hasattr(config, "block") else 256,
                num_hidden_layer=config.num_hidden_layers,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                linear_layer_idx=linear_layer_id,
                prefix=prefix)
        elif config.attention_type == 1:
            self.self_attn = MiniMaxText01Attention(
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                head_dim=head_dim,
                rotary_dim=config.rotary_dim
                if hasattr(config, "rotary_dim") else head_dim,
                num_kv_heads=config.num_key_value_heads,
                max_position=max_position_embeddings,
                rope_theta=rope_theta,
                sliding_window=config.sliding_window,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                cache_config=cache_config,
                prefix=prefix)
        else:
            raise ValueError(
                f"Unsupported attention type: {self.config.attention_type}")

        if expert_num == 1:
            self.mlp = MiniMaxText01MLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                prefix=prefix)
        else:
            self.block_sparse_moe = MiniMaxText01MoE(
                num_experts=expert_num,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                layer_idx=self._ilayer,
                quant_config=quant_config,
                prefix=prefix)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        if config.attention_type == 0:
            self.layernorm_attention_alpha = getattr(
                config, 'layernorm_linear_attention_alpha',
                getattr(config, 'linear_attn_alpha_factor', 1))
            self.layernorm_attention_beta = getattr(
                config, 'layernorm_linear_attention_beta',
                getattr(config, 'linear_attn_beta_factor', 1))
        else:
            self.layernorm_attention_alpha = getattr(
                config, 'layernorm_full_attention_alpha',
                getattr(config, 'full_attn_alpha_factor', 1))
            self.layernorm_attention_beta = getattr(
                config, 'layernorm_full_attention_beta',
                getattr(config, 'full_attn_beta_factor', 1))
        self.layernorm_mlp_alpha = getattr(
            config, 'layernorm_mlp_alpha',
            getattr(config, 'mlp_alpha_factor', 1))
        self.layernorm_mlp_beta = getattr(
            config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
                                                  1))
        self.postnorm = getattr(config, 'postnorm', False)
        self.shared_moe = False

        shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
        if isinstance(shared_intermediate, list):
            shared_intermediate = shared_intermediate[
                layer_id] if layer_id < len(shared_intermediate) else 0
        if shared_intermediate > 0:
            self.shared_moe = True
            self.shared_mlp = MiniMaxText01MLP(
                hidden_size=self.hidden_size,
                intermediate_size=shared_intermediate,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                prefix=prefix)
            self.coefficient = ReplicatedLinear(
                self.hidden_size,
                1,
                bias=False,
                quant_config=quant_config,
                params_dtype=torch.float32,
            )
            self.coefficient.weight.weight_loader = (
                self.shared_moe_coefficient_loader)
            self.shared_moe_mode = getattr(config, 'shared_moe_mode',
                                           'softmax')
        return

    def forward(self,
                hidden_states: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: Union[list[dict], Optional[torch.Tensor]],
                attn_metadata: AttentionMetadata,
                residual: Optional[torch.Tensor],
                is_warmup: bool = False,
                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        layernorm_input = hidden_states
        layernorm_output = self.input_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input
        self_attention_output = self.self_attn(
            hidden_states=layernorm_output,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        residual = residual * self.layernorm_attention_alpha
        self_attention_output = (self_attention_output *
                                 self.layernorm_attention_beta)

        layernorm_input = residual + self_attention_output
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input

        if self.expert_num == 1:
            hidden_states = self.mlp(layernorm_output)
        else:
            moe_hidden_states = self.block_sparse_moe(
                copy.deepcopy(layernorm_output))
            if self.shared_moe:
                before_moe_dtype = layernorm_output.dtype
                moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
                output_mlp = self.shared_mlp(layernorm_output).to(
                    torch.float32)

                coef, _ = self.coefficient(layernorm_output.to(torch.float32))

                if self.shared_moe_mode == 'softmax':
                    coef = torch.nn.functional.softmax(coef, dim=-1)
                    hidden_states = moe_hidden_fp32 * (
                        1 - coef) + output_mlp * coef
                elif self.shared_moe_mode == 'sigmoid':
                    coef = torch.nn.functional.sigmoid(coef)
                    hidden_states = moe_hidden_fp32 * (
                        1 - coef) + output_mlp * coef

                hidden_states = hidden_states.to(before_moe_dtype)
            else:
                hidden_states = moe_hidden_states

        residual = residual * self.layernorm_mlp_alpha
        hidden_states = hidden_states * self.layernorm_mlp_beta

        hidden_states = residual + hidden_states

        return hidden_states, None

    @staticmethod
    def shared_moe_coefficient_loader(param: torch.Tensor,
                                      loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()

        param.data.copy_(loaded_weight.to(torch.float32))
        return
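
The residual connections in this layer are weighted: the residual branch is scaled by a config-driven alpha and the sublayer output by a beta (both default to 1, which reduces to a plain residual add), and postnorm decides whether the residual is taken before or after the norm. A minimal sketch of the attention-side combination with illustrative tensors; layer_norm stands in for RMSNorm here:

import torch
import torch.nn.functional as F

hidden_size = 8
alpha, beta = 1.0, 1.0   # layernorm_attention_alpha / layernorm_attention_beta
postnorm = False

layernorm_input = torch.randn(2, hidden_size)                     # incoming hidden states
layernorm_output = F.layer_norm(layernorm_input, (hidden_size,))  # stand-in for input_layernorm
self_attention_output = torch.randn(2, hidden_size)               # stand-in for self_attn(...)

residual = layernorm_output if postnorm else layernorm_input
layernorm_input = residual * alpha + self_attention_output * beta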

_ilayer instance-attribute

_ilayer = layer_id

_irank instance-attribute

block_sparse_moe instance-attribute

block_sparse_moe = MiniMaxText01MoE(
    num_experts=expert_num,
    top_k=num_experts_per_tok,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    layer_idx=_ilayer,
    quant_config=quant_config,
    prefix=prefix,
)

coefficient instance-attribute

coefficient = ReplicatedLinear(
    hidden_size,
    1,
    bias=False,
    quant_config=quant_config,
    params_dtype=float32,
)

expert_num instance-attribute

expert_num = expert_num

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

layernorm_attention_alpha instance-attribute

layernorm_attention_alpha = getattr(
    config,
    "layernorm_linear_attention_alpha",
    getattr(config, "linear_attn_alpha_factor", 1),
)

layernorm_attention_beta instance-attribute

layernorm_attention_beta = getattr(
    config,
    "layernorm_linear_attention_beta",
    getattr(config, "linear_attn_beta_factor", 1),
)

layernorm_mlp_alpha instance-attribute

layernorm_mlp_alpha = getattr(
    config,
    "layernorm_mlp_alpha",
    getattr(config, "mlp_alpha_factor", 1),
)

layernorm_mlp_beta instance-attribute

layernorm_mlp_beta = getattr(
    config,
    "layernorm_mlp_beta",
    getattr(config, "mlp_beta_factor", 1),
)

mlp instance-attribute

mlp = MiniMaxText01MLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    quant_config=quant_config,
    layer_idx=_ilayer,
    prefix=prefix,
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

postnorm instance-attribute

postnorm = getattr(config, 'postnorm', False)

prefix instance-attribute

prefix = prefix

self_attn instance-attribute

self_attn = MiniMaxText01LinearAttention(
    hidden_size=hidden_size,
    hidden_inner_size=hidden_inner,
    num_heads=num_attention_heads,
    head_dim=head_dim,
    max_position=max_position_embeddings,
    block_size=block if hasattr(config, "block") else 256,
    num_hidden_layer=num_hidden_layers,
    model_config=model_config,
    cache_config=cache_config,
    quant_config=quant_config,
    layer_idx=_ilayer,
    linear_layer_idx=linear_layer_id,
    prefix=prefix,
)

shared_mlp instance-attribute

shared_mlp = MiniMaxText01MLP(
    hidden_size=hidden_size,
    intermediate_size=shared_intermediate,
    quant_config=quant_config,
    layer_idx=_ilayer,
    prefix=prefix,
)

shared_moe instance-attribute

shared_moe = False

shared_moe_mode instance-attribute

shared_moe_mode = getattr(
    config, "shared_moe_mode", "softmax"
)

__init__

__init__(
    config: MiniMaxConfig,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    expert_num: int = 1,
    layer_id: int = None,
    linear_layer_id: Optional[int] = None,
    prefix: str = "decoder",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    config: MiniMaxConfig,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    expert_num: int = 1,
    layer_id: int = None,
    linear_layer_id: Optional[int] = None,
    prefix: str = "decoder",
) -> None:
    self._ilayer = layer_id
    self._irank = get_tensor_model_parallel_rank()
    self.prefix = prefix
    super().__init__()

    self.hidden_size = config.hidden_size
    self.expert_num = expert_num

    rope_theta = getattr(config, "rope_theta", 10000)

    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    if hasattr(config, "max_model_len") and isinstance(
            config.max_model_len, int):
        max_position_embeddings = min(config.max_position_embeddings,
                                      config.max_model_len)
    if config.attention_type == 0:
        use_headxdim = True
        hidden_inner = (head_dim * config.num_attention_heads
                        if use_headxdim else config.hidden_size)
        self.self_attn = MiniMaxText01LinearAttention(
            hidden_size=self.hidden_size,
            hidden_inner_size=hidden_inner,
            num_heads=config.num_attention_heads,
            head_dim=head_dim,
            max_position=max_position_embeddings,
            block_size=config.block if hasattr(config, "block") else 256,
            num_hidden_layer=config.num_hidden_layers,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            linear_layer_idx=linear_layer_id,
            prefix=prefix)
    elif config.attention_type == 1:
        self.self_attn = MiniMaxText01Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            head_dim=head_dim,
            rotary_dim=config.rotary_dim
            if hasattr(config, "rotary_dim") else head_dim,
            num_kv_heads=config.num_key_value_heads,
            max_position=max_position_embeddings,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            cache_config=cache_config,
            prefix=prefix)
    else:
        raise ValueError(
            f"Unsupported attention type: {self.config.attention_type}")

    if expert_num == 1:
        self.mlp = MiniMaxText01MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            prefix=prefix)
    else:
        self.block_sparse_moe = MiniMaxText01MoE(
            num_experts=expert_num,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_idx=self._ilayer,
            quant_config=quant_config,
            prefix=prefix)

    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
    if config.attention_type == 0:
        self.layernorm_attention_alpha = getattr(
            config, 'layernorm_linear_attention_alpha',
            getattr(config, 'linear_attn_alpha_factor', 1))
        self.layernorm_attention_beta = getattr(
            config, 'layernorm_linear_attention_beta',
            getattr(config, 'linear_attn_beta_factor', 1))
    else:
        self.layernorm_attention_alpha = getattr(
            config, 'layernorm_full_attention_alpha',
            getattr(config, 'full_attn_alpha_factor', 1))
        self.layernorm_attention_beta = getattr(
            config, 'layernorm_full_attention_beta',
            getattr(config, 'full_attn_beta_factor', 1))
    self.layernorm_mlp_alpha = getattr(
        config, 'layernorm_mlp_alpha',
        getattr(config, 'mlp_alpha_factor', 1))
    self.layernorm_mlp_beta = getattr(
        config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
                                              1))
    self.postnorm = getattr(config, 'postnorm', False)
    self.shared_moe = False

    shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
    if isinstance(shared_intermediate, list):
        shared_intermediate = shared_intermediate[
            layer_id] if layer_id < len(shared_intermediate) else 0
    if shared_intermediate > 0:
        self.shared_moe = True
        self.shared_mlp = MiniMaxText01MLP(
            hidden_size=self.hidden_size,
            intermediate_size=shared_intermediate,
            quant_config=quant_config,
            layer_idx=self._ilayer,
            prefix=prefix)
        self.coefficient = ReplicatedLinear(
            self.hidden_size,
            1,
            bias=False,
            quant_config=quant_config,
            params_dtype=torch.float32,
        )
        self.coefficient.weight.weight_loader = (
            self.shared_moe_coefficient_loader)
        self.shared_moe_mode = getattr(config, 'shared_moe_mode',
                                       'softmax')
    return

forward

forward(
    hidden_states: Tensor,
    positions: Tensor,
    kv_caches: Union[list[dict], Optional[Tensor]],
    attn_metadata: AttentionMetadata,
    residual: Optional[Tensor],
    is_warmup: bool = False,
    **kwargs,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self,
            hidden_states: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: Union[list[dict], Optional[torch.Tensor]],
            attn_metadata: AttentionMetadata,
            residual: Optional[torch.Tensor],
            is_warmup: bool = False,
            **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    layernorm_input = hidden_states
    layernorm_output = self.input_layernorm(layernorm_input)
    residual = layernorm_output if self.postnorm else layernorm_input
    self_attention_output = self.self_attn(
        hidden_states=layernorm_output,
        positions=positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
    )

    residual = residual * self.layernorm_attention_alpha
    self_attention_output = (self_attention_output *
                             self.layernorm_attention_beta)

    layernorm_input = residual + self_attention_output
    layernorm_output = self.post_attention_layernorm(layernorm_input)
    residual = layernorm_output if self.postnorm else layernorm_input

    if self.expert_num == 1:
        hidden_states = self.mlp(layernorm_output)
    else:
        moe_hidden_states = self.block_sparse_moe(
            copy.deepcopy(layernorm_output))
        if self.shared_moe:
            before_moe_dtype = layernorm_output.dtype
            moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
            output_mlp = self.shared_mlp(layernorm_output).to(
                torch.float32)

            coef, _ = self.coefficient(layernorm_output.to(torch.float32))

            if self.shared_moe_mode == 'softmax':
                coef = torch.nn.functional.softmax(coef, dim=-1)
                hidden_states = moe_hidden_fp32 * (
                    1 - coef) + output_mlp * coef
            elif self.shared_moe_mode == 'sigmoid':
                coef = torch.nn.functional.sigmoid(coef)
                hidden_states = moe_hidden_fp32 * (
                    1 - coef) + output_mlp * coef

            hidden_states = hidden_states.to(before_moe_dtype)
        else:
            hidden_states = moe_hidden_states

    residual = residual * self.layernorm_mlp_alpha
    hidden_states = hidden_states * self.layernorm_mlp_beta

    hidden_states = residual + hidden_states

    return hidden_states, None
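
When a shared expert is configured (shared_intermediate_size > 0), the MoE output and the shared MLP output are blended in float32 by a learned scalar coefficient and then cast back to the original dtype. A minimal sketch of the 'sigmoid' blend with toy tensors; the shapes and the bfloat16 target dtype are assumptions for illustration:

import torch

num_tokens, hidden_size = 4, 8
moe_hidden_fp32 = torch.randn(num_tokens, hidden_size, dtype=torch.float32)  # block_sparse_moe output
output_mlp = torch.randn(num_tokens, hidden_size, dtype=torch.float32)       # shared_mlp output
coef = torch.randn(num_tokens, 1, dtype=torch.float32)                       # self.coefficient(...) logit

coef = torch.sigmoid(coef)                                   # 'sigmoid' shared_moe_mode
blended = moe_hidden_fp32 * (1 - coef) + output_mlp * coef   # convex combination
blended = blended.to(torch.bfloat16)                         # back to the pre-MoE dtype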

shared_moe_coefficient_loader staticmethod

shared_moe_coefficient_loader(
    param: Tensor, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def shared_moe_coefficient_loader(param: torch.Tensor,
                                  loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()

    param.data.copy_(loaded_weight.to(torch.float32))
    return

MiniMaxText01ForCausalLM

Bases: Module, HasInnerState, IsHybrid

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        if not hasattr(config, "sliding_window"):
            config.sliding_window = None

        self.CONCAT_FFN = True

        self.unpadded_vocab_size = self.config.vocab_size
        if hasattr(vllm_config.model_config, "max_model_len"):
            self.config.max_model_len = vllm_config.model_config.max_model_len
        self.model = MiniMaxText01Model(
            self.config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            scheduler_config=vllm_config.scheduler_config,
            prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                self.config.hidden_size,
                org_num_embeddings=self.config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )

            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    self.config.vocab_size)

        else:
            self.lm_head = PPMissingLayer()
        self.lm_head.float()
        flash_layer_count = sum(
            1 for attn_type in self.model.decoder_attention_types
            if attn_type == 1)
        self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
        return

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
            batch_size)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states.float(),
                                       sampling_metadata)

        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def which_layer(name: str) -> int:
            if "layers" in name:
                after_layer = name.split("layers")[-1]
                return int(after_layer.split(".")[1])
            return None

        def is_linear_attn_layer(layer_idx: int) -> bool:
            if layer_idx is None or layer_idx >= len(
                    self.model.decoder_attention_types):
                return False
            return self.model.decoder_attention_types[layer_idx] == 0

        def is_moe_weight(name: str) -> bool:
            return "block_sparse_moe" in name and not name.endswith(".bias")

        def get_expert_id(param_name):
            pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.'
            match = re.search(pattern, param_name)
            if match:
                return match.group(1)
            return None

        def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:
            if isinstance(self.config.num_local_experts, list):
                expert_params_mapping = [
                    ("w13_weight"
                     if weight_name in ["w1", "w3"] else "w2_weight",
                     f"experts.{expert_id}.{weight_name}.weight", expert_id)
                    for expert_id in range(max(self.config.num_local_experts))
                    for weight_name in ["w1", "w2", "w3"]
                ]
            else:
                expert_params_mapping = [
                    ("w13_scale" if weight_name in ["w1", "w3"] else
                     "w2_scale", f"{expert_id}.{weight_name}.weight_scale",
                     expert_id, weight_name)
                    for expert_id in range(self.config.num_local_experts)
                    for weight_name in ["w1", "w2", "w3"]
                ] + [("w13_weight" if weight_name in ["w1", "w3"] else
                      "w2_weight", f"{expert_id}.{weight_name}.weight",
                      expert_id, weight_name)
                     for expert_id in range(self.config.num_local_experts)
                     for weight_name in ["w1", "w2", "w3"]]
            for (param_name, weight_name, expert_id,
                 shard_id) in expert_params_mapping:
                name_expert_id = get_expert_id(name)
                if name_expert_id is not None and int(name_expert_id) != int(
                        expert_id):
                    continue
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param,
                              loaded_weight,
                              weight_name,
                              expert_id=expert_id,
                              shard_id=shard_id)
                loaded_params.add(name)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            return

        def is_shared_mlp_weight(name: str) -> bool:
            return "shared_mlp" in name and not name.endswith(".bias")

        def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:
            if not self.CONCAT_FFN:
                if "gate_proj" in name:
                    name = name.replace("gate_proj", "w1", 1)
                elif "up_proj" in name:
                    name = name.replace("up_proj", "w3", 1)
                elif "down_proj" in name:
                    name = name.replace("down_proj", "w2", 1)
            else:
                if "gate_proj" in name:
                    name = name.replace("gate_proj", "gate_up_proj", 1)
                    loaded_shard_id = 0
                elif "up_proj" in name:
                    name = name.replace("up_proj", "gate_up_proj", 1)
                    loaded_shard_id = 1
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            if not self.CONCAT_FFN:
                weight_loader(param, loaded_weight)
            else:
                if "gate_up_proj" in name:
                    weight_loader(param, loaded_weight, loaded_shard_id)
                elif "down_proj" in name:
                    weight_loader(param, loaded_weight)
                else:
                    raise AssertionError(
                        "MLP weight not in [gate_up_proj, down_proj]")
            loaded_params.add(name)
            return

        def is_mha_weight(name: str) -> bool:
            return "self_attn" in name and not name.endswith(".bias")

        def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
                                    self) -> None:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]

            weight_loader = getattr(
                param, "weight_loader",
                MiniMaxText01LinearAttention.weight_direct_load)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:

            flash_mha_params_mapping = [
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
            for (param_name, weight_name,
                 shard_id) in flash_mha_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]

                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            return

        def is_layer_norm_weight(name: str) -> bool:
            return "norm" in name and not name.endswith(
                ".bias") and name in params_dict

        def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
                                   self) -> None:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                              self) -> None:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        for name, loaded_weight in weights:
            weight_at_layer = which_layer(name)
            if weight_at_layer and weight_at_layer >= len(
                    self.model.decoder_attention_types):
                continue

            if is_layer_norm_weight(name):
                load_layer_norm_weight(name, loaded_weight, self)
                continue
            if is_mha_weight(name):
                if is_linear_attn_layer(weight_at_layer):
                    load_linear_attn_weight(name, loaded_weight, self)
                else:
                    load_flash_attn_weight(name, loaded_weight, self)
                continue
            if is_moe_weight(name):
                load_sparse_moe_weight(name, loaded_weight, self)
                continue
            if is_shared_mlp_weight(name):
                load_shared_mlp_weight(name, loaded_weight, self)
                continue

            if "rotary_emb.inv_freq" in name:
                continue

            load_basic_weight(name, loaded_weight, self)
        return loaded_params

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, ...], ...]:
        """Calculate shape for MiniMaxText01LinearAttention cache.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - state_shape: Shape of the cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config

        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=hf_config.num_attention_heads,
            tp_size=parallel_config.tensor_parallel_size,
            head_dim=hf_config.head_dim,
        )
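
load_weights routes each checkpoint tensor by name: the layer index is parsed from the "layers.<i>." segment, linear-attention layers (attention type 0) use the direct loader, and full-attention q/k/v and MLP gate/up weights are remapped onto the fused qkv_proj and gate_up_proj parameters. A minimal sketch of the name parsing only, using hypothetical weight names:

import re
from typing import Optional

def which_layer(name: str) -> Optional[int]:
    # Mirrors the helper inside load_weights: index after the "layers" segment.
    if "layers" in name:
        after_layer = name.split("layers")[-1]
        return int(after_layer.split(".")[1])
    return None

def get_expert_id(param_name: str) -> Optional[str]:
    # Mirrors the expert-id regex used for block_sparse_moe weights.
    match = re.search(r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.', param_name)
    return match.group(1) if match else None

# Hypothetical checkpoint names, for illustration only.
print(which_layer("model.layers.12.self_attn.q_proj.weight"))                # 12
print(get_expert_id("model.layers.3.block_sparse_moe.experts.7.w1.weight"))  # '7'
print(which_layer("model.embed_tokens.weight"))                              # None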

CONCAT_FFN instance-attribute

CONCAT_FFN = True

config instance-attribute

config = config

kv_cache instance-attribute

kv_cache = [
    (tensor([])) for _ in (range(flash_layer_count))
]

lm_head instance-attribute

lm_head = ParallelLMHead(
    unpadded_vocab_size,
    hidden_size,
    org_num_embeddings=vocab_size,
    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(
    unpadded_vocab_size, vocab_size
)

lora_config instance-attribute

lora_config = lora_config

model instance-attribute

model = MiniMaxText01Model(
    config,
    model_config=model_config,
    cache_config=cache_config,
    quant_config=quant_config,
    scheduler_config=scheduler_config,
    prefix=maybe_prefix(prefix, "model"),
)

unpadded_vocab_size instance-attribute

unpadded_vocab_size = vocab_size

__init__

__init__(
    *, vllm_config: VllmConfig, prefix: str = ""
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config
    self.config = config
    self.lora_config = lora_config

    if not hasattr(config, "sliding_window"):
        config.sliding_window = None

    self.CONCAT_FFN = True

    self.unpadded_vocab_size = self.config.vocab_size
    if hasattr(vllm_config.model_config, "max_model_len"):
        self.config.max_model_len = vllm_config.model_config.max_model_len
    self.model = MiniMaxText01Model(
        self.config,
        model_config=vllm_config.model_config,
        cache_config=vllm_config.cache_config,
        quant_config=quant_config,
        scheduler_config=vllm_config.scheduler_config,
        prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.config.vocab_size)

    else:
        self.lm_head = PPMissingLayer()
    self.lm_head.float()
    flash_layer_count = sum(
        1 for attn_type in self.model.decoder_attention_types
        if attn_type == 1)
    self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
    return

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    logits = self.logits_processor(self.lm_head, hidden_states.float(),
                                   sampling_metadata)

    return logits

copy_inputs_before_cuda_graphs

copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
Source code in vllm/model_executor/models/minimax_text_01.py
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
        input_buffers, **kwargs)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs) -> torch.Tensor:
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds, **kwargs)

    return hidden_states

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    return self.model.get_input_embeddings(input_ids)

get_mamba_state_dtype_from_config classmethod

get_mamba_state_dtype_from_config(
    vllm_config: VllmConfig,
) -> tuple[dtype, dtype]
Source code in vllm/model_executor/models/minimax_text_01.py
@classmethod
def get_mamba_state_dtype_from_config(
    cls,
    vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:

    return MambaStateDtypeCalculator.linear_attention_state_dtype(
        vllm_config.model_config.dtype,
        vllm_config.cache_config.mamba_cache_dtype,
    )

get_mamba_state_shape_from_config classmethod

get_mamba_state_shape_from_config(
    vllm_config: VllmConfig, use_v1: bool = True
) -> tuple[tuple[int, ...], ...]

Calculate shape for MiniMaxText01LinearAttention cache.

Parameters:

    Name         Type        Description                Default
    vllm_config  VllmConfig  vLLM config                required
    use_v1       bool        Get shapes for V1 (or V0)  True

Returns:

    Type                         Description
    tuple[tuple[int, ...], ...]  Tuple containing:
                                 - state_shape: Shape of the cache
Source code in vllm/model_executor/models/minimax_text_01.py
@classmethod
def get_mamba_state_shape_from_config(
    cls,
    vllm_config: "VllmConfig",
    use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
    """Calculate shape for MiniMaxText01LinearAttention cache.

    Args:
        vllm_config: vLLM config
        use_v1: Get shapes for V1 (or V0)

    Returns:
        Tuple containing:
        - state_shape: Shape of the cache
    """
    parallel_config = vllm_config.parallel_config
    hf_config = vllm_config.model_config.hf_config

    return MambaStateShapeCalculator.linear_attention_state_shape(
        num_heads=hf_config.num_attention_heads,
        tp_size=parallel_config.tensor_parallel_size,
        head_dim=hf_config.head_dim,
    )
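
MambaStateShapeCalculator.linear_attention_state_shape owns the actual shape computation; only the inputs (attention head count, TP size, head dim) come from the configs here. As a rough sketch of the arithmetic involved, assuming the per-rank linear-attention state is a (heads per rank, head_dim, head_dim) decay-state matrix (that layout is an assumption for illustration, not a guarantee):

# Assumed layout for illustration only; the authoritative shape is whatever
# MambaStateShapeCalculator.linear_attention_state_shape returns.
num_attention_heads, head_dim, tp_size = 64, 128, 8
state_shape = (num_attention_heads // tp_size, head_dim, head_dim)  # (8, 128, 128)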

get_seqlen_agnostic_capture_inputs

get_seqlen_agnostic_capture_inputs(batch_size: int)
Source code in vllm/model_executor/models/minimax_text_01.py
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
        batch_size)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/minimax_text_01.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    def which_layer(name: str) -> int:
        if "layers" in name:
            after_layer = name.split("layers")[-1]
            return int(after_layer.split(".")[1])
        return None

    def is_linear_attn_layer(layer_idx: int) -> bool:
        if layer_idx is None or layer_idx >= len(
                self.model.decoder_attention_types):
            return False
        return self.model.decoder_attention_types[layer_idx] == 0

    def is_moe_weight(name: str) -> bool:
        return "block_sparse_moe" in name and not name.endswith(".bias")

    def get_expert_id(param_name):
        pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.'
        match = re.search(pattern, param_name)
        if match:
            return match.group(1)
        return None

    def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:
        if isinstance(self.config.num_local_experts, list):
            expert_params_mapping = [
                ("w13_weight"
                 if weight_name in ["w1", "w3"] else "w2_weight",
                 f"experts.{expert_id}.{weight_name}.weight", expert_id)
                for expert_id in range(max(self.config.num_local_experts))
                for weight_name in ["w1", "w2", "w3"]
            ]
        else:
            expert_params_mapping = [
                ("w13_scale" if weight_name in ["w1", "w3"] else
                 "w2_scale", f"{expert_id}.{weight_name}.weight_scale",
                 expert_id, weight_name)
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ] + [("w13_weight" if weight_name in ["w1", "w3"] else
                  "w2_weight", f"{expert_id}.{weight_name}.weight",
                  expert_id, weight_name)
                 for expert_id in range(self.config.num_local_experts)
                 for weight_name in ["w1", "w2", "w3"]]
        for (param_name, weight_name, expert_id,
             shard_id) in expert_params_mapping:
            name_expert_id = get_expert_id(name)
            if name_expert_id is not None and int(name_expert_id) != int(
                    expert_id):
                continue
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param,
                          loaded_weight,
                          weight_name,
                          expert_id=expert_id,
                          shard_id=shard_id)
            loaded_params.add(name)
            break
        else:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return

    def is_shared_mlp_weight(name: str) -> bool:
        return "shared_mlp" in name and not name.endswith(".bias")

    def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:
        if not self.CONCAT_FFN:
            if "gate_proj" in name:
                name = name.replace("gate_proj", "w1", 1)
            elif "up_proj" in name:
                name = name.replace("up_proj", "w3", 1)
            elif "down_proj" in name:
                name = name.replace("down_proj", "w2", 1)
        else:
            if "gate_proj" in name:
                name = name.replace("gate_proj", "gate_up_proj", 1)
                loaded_shard_id = 0
            elif "up_proj" in name:
                name = name.replace("up_proj", "gate_up_proj", 1)
                loaded_shard_id = 1
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        if not self.CONCAT_FFN:
            weight_loader(param, loaded_weight)
        else:
            if "gate_up_proj" in name:
                weight_loader(param, loaded_weight, loaded_shard_id)
            elif "down_proj" in name:
                weight_loader(param, loaded_weight)
            else:
                raise AssertionError(
                    "MLP weight not in [gate_up_proj, down_proj]")
        loaded_params.add(name)
        return

    def is_mha_weight(name: str) -> bool:
        return "self_attn" in name and not name.endswith(".bias")

    def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
                                self) -> None:
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]

        weight_loader = getattr(
            param, "weight_loader",
            MiniMaxText01LinearAttention.weight_direct_load)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
        return

    def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:

        flash_mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        for (param_name, weight_name,
             shard_id) in flash_mha_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
            break
        else:
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return

    def is_layer_norm_weight(name: str) -> bool:
        return "norm" in name and not name.endswith(
            ".bias") and name in params_dict

    def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
                               self) -> None:
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
        return

    def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                          self) -> None:
        if is_pp_missing_parameter(name, self):
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader = weight_loader_with_alias(name)(weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)
        return

    for name, loaded_weight in weights:
        weight_at_layer = which_layer(name)
        if weight_at_layer and weight_at_layer >= len(
                self.model.decoder_attention_types):
            continue

        if is_layer_norm_weight(name):
            load_layer_norm_weight(name, loaded_weight, self)
            continue
        if is_mha_weight(name):
            if is_linear_attn_layer(weight_at_layer):
                load_linear_attn_weight(name, loaded_weight, self)
            else:
                load_flash_attn_weight(name, loaded_weight, self)
            continue
        if is_moe_weight(name):
            load_sparse_moe_weight(name, loaded_weight, self)
            continue
        if is_shared_mlp_weight(name):
            load_shared_mlp_weight(name, loaded_weight, self)
            continue

        if "rotary_emb.inv_freq" in name:
            continue

        load_basic_weight(name, loaded_weight, self)
    return loaded_params
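
For illustration, a minimal sketch of the expert-weight renaming performed by the loader above. The checkpoint key below is hypothetical and shown only to demonstrate the string mapping; the actual key layout depends on the checkpoint.

ckpt_key = "model.layers.0.block_sparse_moe.experts.3.w1.weight"  # hypothetical key
expert_id, shard_id = 3, "w1"
param_key = ckpt_key.replace(f"{expert_id}.{shard_id}.weight", "w13_weight")
print(param_key)  # model.layers.0.block_sparse_moe.experts.w13_weight
# The fused parameter is then filled via
# weight_loader(param, loaded_weight, weight_name, expert_id=3, shard_id="w1").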

make_empty_intermediate_tensors

make_empty_intermediate_tensors(
    batch_size: int, dtype: dtype, device: device
) -> IntermediateTensors
Source code in vllm/model_executor/models/minimax_text_01.py
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    return IntermediateTensors({
        "hidden_states":
        torch.zeros((batch_size, self.config.hidden_size),
                    dtype=dtype,
                    device=device),
        "residual":
        torch.zeros((batch_size, self.config.hidden_size),
                    dtype=dtype,
                    device=device),
    })

MiniMaxText01LinearAttention

Bases: Module, MambaBase

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01LinearAttention(nn.Module, MambaBase):

    @property
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.linear_attn import (
            LinearAttentionBackend)
        return LinearAttentionBackend

    def get_state_dtype(self) -> tuple[torch.dtype]:
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=self.num_heads,
            tp_size=self.tp_size,
            head_dim=self.head_dim)

    def __init__(
        self,
        hidden_size: int,
        hidden_inner_size: int,
        num_heads: int,
        head_dim: int,
        max_position: int,
        block_size: int,
        num_hidden_layer: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = 0,
        linear_layer_idx: int = 0,
        prefix: str = "linear_attn",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.BLOCK = block_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.hidden_inner_size = hidden_inner_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        assert self.total_num_heads % self.tp_size == 0
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads
        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size * 3,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.output_gate = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_gate",
        )
        self.out_proj = RowParallelLinear(
            self.hidden_inner_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.norm = MiniMaxText01RMSNormTP(
            self.hidden_inner_size,
            eps=1e-5,
        )

        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
            self.num_heads)
        if num_hidden_layer <= 1:
            self.slope_rate = slope_rate * (1 + 1e-5)
        else:
            self.slope_rate = slope_rate * (1 - layer_idx /
                                            (num_hidden_layer - 1) + 1e-5)
        self.tp_slope = self.slope_rate[self.tp_rank *
                                        self.tp_heads:(self.tp_rank + 1) *
                                        self.tp_heads].contiguous()

        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self

    @staticmethod
    def weight_direct_load(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)
        return

    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):

        def get_slopes(n):

            def get_slopes_power_of_2(n):
                start = 2**(-(2**-(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            else:
                closest_power_of_2 = 2**math.floor(math.log2(n))
                return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                    2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

        slopes = torch.tensor(get_slopes(n_attention_heads),
                              dtype=torch.float32).reshape(
                                  n_attention_heads, 1, 1)
        return slopes

    def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                               attn_metadata):
        hidden = []
        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_idx >= len(attn_metadata.query_start_loc):
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
            # prefills are packed at end of batch in V1
            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
            slot_id = state_indices_tensor[offset + _prefill_idx]
            qs = q[_start:_end].transpose(0, 1).contiguous()
            ks = k[_start:_end].transpose(0, 1).contiguous()
            vs = v[_start:_end].transpose(0, 1).contiguous()
            slice_layer_cache = kv_cache[slot_id, ...]

            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
                qs,
                ks,
                vs,
                slice_layer_cache,
                self.tp_slope,
                self.BLOCK,
                layer_idx=self.layer_idx)
            hidden.append(out_slice.contiguous())
        if attn_metadata.num_decode_tokens > 0:
            hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                               state_indices_tensor,
                                               attn_metadata)
            if envs.VLLM_USE_V1:
                hidden.insert(0, hidden_decode)
            else:
                hidden.append(hidden_decode)

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

        hidden = torch.concat(hidden, dim=0).contiguous()
        return hidden

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                      attn_metadata):
        if not envs.VLLM_USE_V1:
            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            slot_id = state_indices_tensor[num_prefills:]
        else:
            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                              slot_id, 32)
        return hidden

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                assert isinstance(attn_metadata, dict)
                attn_metadata = attn_metadata[self.prefix]
                assert isinstance(attn_metadata, LinearAttentionMetadata)
                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                state_indices_tensor = attn_metadata.state_indices_tensor

                num_prefills = getattr(attn_metadata, "num_prefills", 0)
                if num_prefills > 0:
                    num_decode_tokens = getattr(attn_metadata,
                                                "num_decode_tokens", 0)
                    for prefill_idx in range(num_prefills):
                        q_start = attn_metadata.query_start_loc[
                            num_decode_tokens + prefill_idx]
                        q_end = attn_metadata.query_start_loc[num_decode_tokens
                                                              + prefill_idx +
                                                              1]
                        query_len = q_end - q_start
                        context_len = attn_metadata.seq_lens[
                            num_decode_tokens + prefill_idx] - query_len
                        if context_len == 0:
                            block_to_clear = state_indices_tensor[
                                num_decode_tokens + prefill_idx]
                            kv_cache[block_to_clear, ...] = 0
        else:
            kv_cache = kv_caches.minimax_cache
            state_indices_tensor = kv_caches.state_indices_tensor

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if attn_metadata is None:
            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
                                 device=q.device,
                                 dtype=q.dtype)
        else:
            if not decode_only:
                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                                     state_indices_tensor,
                                                     attn_metadata)
            else:
                hidden = self._decode_infer(q, k, v, kv_cache,
                                            state_indices_tensor,
                                            attn_metadata)

        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states)
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
        hidden, _ = self.out_proj(hidden)
        return hidden

BLOCK instance-attribute

BLOCK = block_size

cache_config instance-attribute

cache_config = cache_config

head_dim instance-attribute

head_dim = head_dim

hidden_inner_size instance-attribute

hidden_inner_size = hidden_inner_size

hidden_size instance-attribute

hidden_size = hidden_size

layer_idx instance-attribute

layer_idx = layer_idx

mamba_type property

mamba_type: str

model_config instance-attribute

model_config = model_config

norm instance-attribute

norm = MiniMaxText01RMSNormTP(hidden_inner_size, eps=1e-05)

num_heads instance-attribute

num_heads = num_heads

out_proj instance-attribute

out_proj = RowParallelLinear(
    hidden_inner_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

output_gate instance-attribute

output_gate = ColumnParallelLinear(
    hidden_size,
    hidden_inner_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.output_gate",
)

prefix instance-attribute

prefix = prefix

qkv_proj instance-attribute

qkv_proj = ColumnParallelLinear(
    hidden_size,
    hidden_inner_size * 3,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

qkv_size instance-attribute

qkv_size = num_heads * head_dim

slope_rate instance-attribute

slope_rate = slope_rate * (1 + 1e-05)

total_num_heads instance-attribute

total_num_heads = num_heads

tp_heads instance-attribute

tp_heads = total_num_heads // tp_size

tp_hidden instance-attribute

tp_hidden = head_dim * tp_heads

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

tp_slope instance-attribute

tp_slope = slope_rate[tp_rank * tp_heads:(tp_rank + 1) * tp_heads].contiguous()

__init__

__init__(
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None:
    super().__init__()

    self.layer_idx = layer_idx
    self.BLOCK = block_size
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.total_num_heads = num_heads
    self.hidden_inner_size = hidden_inner_size
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()

    assert self.total_num_heads % self.tp_size == 0
    self.tp_heads = self.total_num_heads // self.tp_size
    self.qkv_size = self.num_heads * self.head_dim
    self.tp_hidden = self.head_dim * self.tp_heads
    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix

    self.qkv_proj = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size * 3,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.output_gate = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.output_gate",
    )
    self.out_proj = RowParallelLinear(
        self.hidden_inner_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )
    self.norm = MiniMaxText01RMSNormTP(
        self.hidden_inner_size,
        eps=1e-5,
    )

    slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
        self.num_heads)
    if num_hidden_layer <= 1:
        self.slope_rate = slope_rate * (1 + 1e-5)
    else:
        self.slope_rate = slope_rate * (1 - layer_idx /
                                        (num_hidden_layer - 1) + 1e-5)
    self.tp_slope = self.slope_rate[self.tp_rank *
                                    self.tp_heads:(self.tp_rank + 1) *
                                    self.tp_heads].contiguous()

    if envs.VLLM_USE_V1:
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

_build_slope_tensor staticmethod

_build_slope_tensor(n_attention_heads: int)
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def _build_slope_tensor(n_attention_heads: int):

    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = 2**(-(2**-(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    slopes = torch.tensor(get_slopes(n_attention_heads),
                          dtype=torch.float32).reshape(
                              n_attention_heads, 1, 1)
    return slopes
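
A minimal usage sketch, assuming vLLM is importable in the current environment: for a power-of-two head count the helper returns ALiBi-style geometric decay slopes.

from vllm.model_executor.models.minimax_text_01 import MiniMaxText01LinearAttention

slopes = MiniMaxText01LinearAttention._build_slope_tensor(8)
print(slopes.shape)      # torch.Size([8, 1, 1])
print(slopes.flatten())  # geometric sequence 2**-1 through 2**-8: 0.5, 0.25, 0.125, ...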

_decode_infer

_decode_infer(
    q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
Source code in vllm/model_executor/models/minimax_text_01.py
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                  attn_metadata):
    if not envs.VLLM_USE_V1:
        q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
        num_prefills = getattr(attn_metadata, "num_prefills", 0)
        slot_id = state_indices_tensor[num_prefills:]
    else:
        q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
        slot_id = state_indices_tensor[:attn_metadata.num_decodes]
    hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                          slot_id, 32)
    return hidden

_prefill_and_mix_infer

_prefill_and_mix_infer(
    q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
Source code in vllm/model_executor/models/minimax_text_01.py
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                           attn_metadata):
    hidden = []
    for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
        if _prefill_idx >= len(attn_metadata.query_start_loc):
            break
        if _prefill_idx >= len(state_indices_tensor):
            break
        # prefills are packed at end of batch in V1
        offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
        _start = attn_metadata.query_start_loc[offset + _prefill_idx]
        _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
        slot_id = state_indices_tensor[offset + _prefill_idx]
        qs = q[_start:_end].transpose(0, 1).contiguous()
        ks = k[_start:_end].transpose(0, 1).contiguous()
        vs = v[_start:_end].transpose(0, 1).contiguous()
        slice_layer_cache = kv_cache[slot_id, ...]

        out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
            qs,
            ks,
            vs,
            slice_layer_cache,
            self.tp_slope,
            self.BLOCK,
            layer_idx=self.layer_idx)
        hidden.append(out_slice.contiguous())
    if attn_metadata.num_decode_tokens > 0:
        hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                           state_indices_tensor,
                                           attn_metadata)
        if envs.VLLM_USE_V1:
            hidden.insert(0, hidden_decode)
        else:
            hidden.append(hidden_decode)

    if not hidden:
        return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

    hidden = torch.concat(hidden, dim=0).contiguous()
    return hidden

forward

forward(
    hidden_states: Tensor,
    positions: Tensor,
    kv_caches: MinimaxCacheParams,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
            kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    qkv32 = qkv.to(torch.float32)
    qkvact = torch.nn.functional.silu(qkv32)
    qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
    q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if envs.VLLM_USE_V1:
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, LinearAttentionMetadata)
            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
            state_indices_tensor = attn_metadata.state_indices_tensor

            num_prefills = getattr(attn_metadata, "num_prefills", 0)
            if num_prefills > 0:
                num_decode_tokens = getattr(attn_metadata,
                                            "num_decode_tokens", 0)
                for prefill_idx in range(num_prefills):
                    q_start = attn_metadata.query_start_loc[
                        num_decode_tokens + prefill_idx]
                    q_end = attn_metadata.query_start_loc[num_decode_tokens
                                                          + prefill_idx +
                                                          1]
                    query_len = q_end - q_start
                    context_len = attn_metadata.seq_lens[
                        num_decode_tokens + prefill_idx] - query_len
                    if context_len == 0:
                        block_to_clear = state_indices_tensor[
                            num_decode_tokens + prefill_idx]
                        kv_cache[block_to_clear, ...] = 0
    else:
        kv_cache = kv_caches.minimax_cache
        state_indices_tensor = kv_caches.state_indices_tensor

    decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
    if attn_metadata is None:
        hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
                             device=q.device,
                             dtype=q.dtype)
    else:
        if not decode_only:
            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                                 state_indices_tensor,
                                                 attn_metadata)
        else:
            hidden = self._decode_infer(q, k, v, kv_cache,
                                        state_indices_tensor,
                                        attn_metadata)

    hidden = self.norm._forward(hidden)
    gate, _ = self.output_gate(hidden_states)
    hidden = F.sigmoid(gate) * hidden
    hidden = hidden.to(hidden_states.dtype)
    hidden, _ = self.out_proj(hidden)
    return hidden
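
A shape-only sketch of the fused QKV handling in forward, with illustrative sizes (5 tokens, 4 heads per tensor-parallel rank, head_dim 32): the per-rank projection yields 3 * head_dim features per head, which are SiLU-activated in float32, viewed per head and split into q, k and v.

import torch

tokens, tp_heads, head_dim = 5, 4, 32
qkv = torch.randn(tokens, tp_heads * 3 * head_dim)   # per-rank qkv_proj output
qkvact = torch.nn.functional.silu(qkv.float())
qkvact = qkvact.view(tokens, tp_heads, -1)           # (5, 4, 96)
q, k, v = torch.split(qkvact, [head_dim] * 3, dim=-1)
print(q.shape, k.shape, v.shape)                     # each torch.Size([5, 4, 32])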

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/models/minimax_text_01.py
def get_attn_backend(self) -> type["AttentionBackend"]:
    from vllm.v1.attention.backends.linear_attn import (
        LinearAttentionBackend)
    return LinearAttentionBackend

get_state_dtype

get_state_dtype() -> tuple[dtype]
Source code in vllm/model_executor/models/minimax_text_01.py
def get_state_dtype(self) -> tuple[torch.dtype]:
    return MambaStateDtypeCalculator.linear_attention_state_dtype(
        self.model_config.dtype,
        self.cache_config.mamba_cache_dtype,
    )

get_state_shape

get_state_shape() -> tuple[
    tuple[int, ...], tuple[int, ...]
]
Source code in vllm/model_executor/models/minimax_text_01.py
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
    return MambaStateShapeCalculator.linear_attention_state_shape(
        num_heads=self.num_heads,
        tp_size=self.tp_size,
        head_dim=self.head_dim)

weight_direct_load staticmethod

weight_direct_load(
    param: Tensor, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def weight_direct_load(param: torch.Tensor,
                       loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
    return

MiniMaxText01LinearKernel

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01LinearKernel:

    @staticmethod
    def jit_linear_forward_prefix(q: torch.Tensor,
                                  k: torch.Tensor,
                                  v: torch.Tensor,
                                  kv_caches: torch.Tensor,
                                  slope_rate: torch.Tensor,
                                  block_size: int,
                                  layer_idx: int = None,
                                  **kwargs) -> torch.Tensor:

        slope_rate = slope_rate.to(torch.float32)
        should_pad_dim = q.dim() == 3
        if should_pad_dim:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        b, h, n, d = q.shape
        e = d
        kv_history = kv_caches.reshape(1, h, d, e).contiguous()
        output, kv_history = lightning_attention(q,
                                                 k,
                                                 v,
                                                 slope_rate,
                                                 block_size=block_size,
                                                 kv_history=kv_history)
        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
        assert output.shape[0] == 1, "batch size must be 1"
        return rearrange(output.squeeze(0), "h n d -> n (h d)")

jit_linear_forward_prefix staticmethod

jit_linear_forward_prefix(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kv_caches: Tensor,
    slope_rate: Tensor,
    block_size: int,
    layer_idx: int = None,
    **kwargs,
) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
                              k: torch.Tensor,
                              v: torch.Tensor,
                              kv_caches: torch.Tensor,
                              slope_rate: torch.Tensor,
                              block_size: int,
                              layer_idx: int = None,
                              **kwargs) -> torch.Tensor:

    slope_rate = slope_rate.to(torch.float32)
    should_pad_dim = q.dim() == 3
    if should_pad_dim:
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
    b, h, n, d = q.shape
    e = d
    kv_history = kv_caches.reshape(1, h, d, e).contiguous()
    output, kv_history = lightning_attention(q,
                                             k,
                                             v,
                                             slope_rate,
                                             block_size=block_size,
                                             kv_history=kv_history)
    kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
    assert output.shape[0] == 1, "batch size must be 1"
    return rearrange(output.squeeze(0), "h n d -> n (h d)")
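
A hedged shape sketch of the kernel interface above; the Triton call itself is left commented out since it needs a GPU, and the sizes are illustrative only.

import torch

h, n, d = 4, 16, 32                 # heads per rank, prefill tokens, head_dim
q = torch.randn(h, n, d)            # 3-D inputs are unsqueezed to (1, h, n, d)
k = torch.randn(h, n, d)
v = torch.randn(h, n, d)
layer_cache = torch.zeros(h, d, d)  # per-slot linear-attention state, updated in place
slope_rate = torch.full((h, 1, 1), 0.5)
# out = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
#     q, k, v, layer_cache, slope_rate, block_size=64)
# out.shape == (n, h * d)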

MiniMaxText01MLP

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = None,
        prefix: str = "mlp",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        return

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size] * 2,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

layer_idx instance-attribute

layer_idx = layer_idx

__init__

__init__(
    hidden_size: int,
    intermediate_size: int,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    prefix: str = "mlp",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    quant_config: Optional[QuantizationConfig] = None,
    layer_idx: int = None,
    prefix: str = "mlp",
) -> None:
    super().__init__()
    self.layer_idx = layer_idx

    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    self.down_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.down_proj",
    )
    self.act_fn = SiluAndMul()
    return

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, x: torch.Tensor) -> torch.Tensor:

    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x
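
A reference sketch (not vLLM's fused op) of what SiluAndMul computes on the merged gate_up projection: split it in half, apply SiLU to the gate half, multiply by the up half.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)   # halves produced by gate_proj and up_proj
    return F.silu(gate) * up

x = torch.randn(2, 8)                     # (tokens, 2 * intermediate_size)
print(silu_and_mul_reference(x).shape)    # torch.Size([2, 4])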

MiniMaxText01MoE

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01MoE(nn.Module):

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        layer_idx: int = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "moe",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=torch.float32,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader

        self.experts = FusedMoE(
            num_experts=self.num_total_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size * self.tp_size,
            params_dtype=self.params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=self.quant_config,
            tp_size=self.tp_size,
            prefix=f"{prefix}.experts",
        )
        return

    @staticmethod
    def gate_weight_loader(param: nn.Parameter,
                           loaded_weight: torch.Tensor) -> None:
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight.to(torch.float32))
        return

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32))
        final_hidden_states = self.experts(
            hidden_states, router_logits_fp32.to(hidden_states.dtype))
        final_hidden = final_hidden_states.view(num_tokens, hidden_size)
        return final_hidden

experts instance-attribute

experts = FusedMoE(
    num_experts=num_total_experts,
    top_k=top_k,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size * tp_size,
    params_dtype=params_dtype,
    reduce_results=True,
    renormalize=True,
    quant_config=quant_config,
    tp_size=tp_size,
    prefix=f"{prefix}.experts",
)

gate instance-attribute

gate = ReplicatedLinear(
    hidden_size,
    num_total_experts,
    bias=False,
    params_dtype=float32,
    quant_config=None,
    prefix=f"{prefix}.gate",
)

hidden_size instance-attribute

hidden_size = hidden_size

intermediate_size instance-attribute

intermediate_size = intermediate_size // tp_size

layer_idx instance-attribute

layer_idx = layer_idx

num_total_experts instance-attribute

num_total_experts = num_experts

params_dtype instance-attribute

params_dtype = params_dtype

quant_config instance-attribute

quant_config = quant_config

top_k instance-attribute

top_k = top_k

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

__init__

__init__(
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[dtype] = None,
    layer_idx: int = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "moe",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    num_experts: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    layer_idx: int = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "moe",
) -> None:
    super().__init__()

    self.layer_idx = layer_idx
    self.tp_size = get_tensor_model_parallel_world_size()
    self.num_total_experts = num_experts
    self.top_k = top_k
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size // self.tp_size
    self.quant_config = quant_config

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()
    self.params_dtype = params_dtype

    self.gate = ReplicatedLinear(
        self.hidden_size,
        self.num_total_experts,
        bias=False,
        params_dtype=torch.float32,
        quant_config=None,
        prefix=f"{prefix}.gate",
    )
    self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader

    self.experts = FusedMoE(
        num_experts=self.num_total_experts,
        top_k=self.top_k,
        hidden_size=self.hidden_size,
        intermediate_size=self.intermediate_size * self.tp_size,
        params_dtype=self.params_dtype,
        reduce_results=True,
        renormalize=True,
        quant_config=self.quant_config,
        tp_size=self.tp_size,
        prefix=f"{prefix}.experts",
    )
    return

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_size = hidden_states.shape
    hidden_states = hidden_states.view(-1, self.hidden_size)
    router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32))
    final_hidden_states = self.experts(
        hidden_states, router_logits_fp32.to(hidden_states.dtype))
    final_hidden = final_hidden_states.view(num_tokens, hidden_size)
    return final_hidden
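
A reference sketch of the routing math implied by top_k and renormalize=True; the fused kernel inside FusedMoE folds this into the expert computation, so the exact order of operations may differ. Softmax over expert logits, keep the top-k weights per token, renormalize them to sum to one.

import torch

def route_reference(router_logits: torch.Tensor, top_k: int):
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    weights, expert_ids = probs.topk(top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize=True
    return weights, expert_ids

logits = torch.randn(4, 8)                # (tokens, num_experts)
w, ids = route_reference(logits, top_k=2)
print(w.sum(dim=-1))                      # ~1.0 per token after renormalization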

gate_weight_loader staticmethod

gate_weight_loader(
    param: Parameter, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def gate_weight_loader(param: nn.Parameter,
                       loaded_weight: torch.Tensor) -> None:
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight.to(torch.float32))
    return

MiniMaxText01Model

Bases: Module

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01Model(nn.Module):

    def __init__(
        self,
        config: MiniMaxConfig,
        model_config: Optional[ModelConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        scheduler_config=None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.decoder_attention_types = getattr(
            config, "attn_type_list", False) or getattr(
                config, "decoder_attention_types", False)
        # The HF format uses "layer_types" instead of "attn_type_list"
        # where "linear_attention" is 0 and "full_attention" is 1
        if not self.decoder_attention_types and hasattr(config, "layer_types"):
            self.decoder_attention_types = []
            for layer_type in config.layer_types:
                if layer_type == "linear_attention":
                    self.decoder_attention_types.append(0)
                elif layer_type == "full_attention":
                    self.decoder_attention_types.append(1)
                else:
                    raise ValueError(f"Unsupported layer type: {layer_type}")
        # Default to full attention
        if not self.decoder_attention_types:
            self.decoder_attention_types = [1] * config.num_hidden_layers
        self.num_layers = config.num_hidden_layers

        self._layer_barrier = False
        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=self.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def layer_fn(prefix):
            layer_idx = int(prefix.split('.')[-1])
            layer_config = config
            layer_config.attention_type = self.decoder_attention_types[
                layer_idx]
            layer_config.layer_idx = layer_idx

            decoder_kwargs = {
                "quant_config": quant_config,
                "layer_id": layer_idx,
                "model_config": model_config,
                "cache_config": cache_config
            }

            if layer_config.attention_type == 0:
                decoder_kwargs["linear_layer_id"] = sum(
                    1 for i in range(layer_idx)
                    if self.decoder_attention_types[i] == 0)
            else:
                decoder_kwargs["linear_layer_id"] = None

            if hasattr(config, "num_local_experts") and isinstance(
                    config.num_local_experts, list):
                decoder_kwargs["expert_num"] = config.num_local_experts[
                    layer_idx]
            elif hasattr(config, "num_local_experts") and isinstance(
                    config.num_local_experts, int):
                decoder_kwargs["expert_num"] = config.num_local_experts
            else:
                decoder_kwargs["expert_num"] = 1

            return MiniMaxText01DecoderLayer(layer_config,
                                             **decoder_kwargs,
                                             prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")

        linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
                                if self.decoder_attention_types[i] == 0)
        max_slots_number = scheduler_config.max_num_seqs
        self.cache_shape = (linear_layer_nums, max_slots_number,
                            config.num_attention_heads //
                            get_tensor_model_parallel_world_size(),
                            config.head_dim, config.head_dim)
        _dummy = torch.zeros(1)
        self._dtype = _dummy.dtype
        del _dummy

        if not envs.VLLM_USE_V1:
            self.minimax_cache = MinimaxCacheManager(
                dtype=torch.float32, cache_shape=self.cache_shape)

        rope_theta = getattr(config, "rope_theta", 10000)
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        if hasattr(config, "max_model_len") and isinstance(
                config.max_model_len, int):
            max_position_embeddings = min(config.max_position_embeddings,
                                          config.max_model_len)
        self.rotary_emb = MiniMaxText01RotaryEmbedding(
            head_dim,
            rotary_dim=config.rotary_dim
            if hasattr(config, "rotary_dim") else head_dim,
            max_position=max_position_embeddings,
            base=int(rope_theta),
            is_neox_style=True,
            cache_dtype=torch.float32,
        )

        norm_kwargs = {}
        if hasattr(config, "rms_norm_eps"):
            norm_kwargs["eps"] = config.rms_norm_eps
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
        else:
            self.norm = PPMissingLayer()
        self.embed_scale = 1.0
        return

    def _clear_prefill_cache(self, attn_metadata,
                             minimax_cache_tensors: torch.Tensor, **kwargs):
        seq_to_slot_maps = {}
        seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), [])
        for _, seq_to_slot_map in (
                self.minimax_cache.cache_indices_mapping.items()):
            seq_to_slot_maps.update(seq_to_slot_map)

        slots_to_clear = []
        for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_id >= len(seq_id_map):
                break
            seq_id = seq_id_map[_prefill_id]
            if attn_metadata.context_lens_tensor[
                    _prefill_id] == 0 and seq_id in seq_to_slot_maps:
                slots_to_clear.append(seq_to_slot_maps[seq_id])

        if slots_to_clear:
            slots_tensor = torch.tensor(slots_to_clear,
                                        device=minimax_cache_tensors.device,
                                        dtype=torch.long)
            minimax_cache_tensors[:, slots_tensor, ...] = 0

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(self,
                input_ids: Optional[torch.Tensor],
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if not envs.VLLM_USE_V1 and attn_metadata is None:
            return None
        if "request_ids_to_seq_ids" not in kwargs:
            kwargs["request_ids_to_seq_ids"] = {}
        if "finished_requests_ids" not in kwargs:
            kwargs["finished_requests_ids"] = []

        if not envs.VLLM_USE_V1:
            (
                minimax_cache_tensors,
                state_indices_tensor,
            ) = self.minimax_cache.current_run_tensors(**kwargs)
            if getattr(attn_metadata, "num_prefills", 0) > 0:
                self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
                                          **kwargs)

            minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
                                                      state_indices_tensor)
        else:
            minimax_cache_params = None

        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.embed_scale * self.embed_tokens(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        minimax_cache_index = 0

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if attn_metadata is not None:
                # TODO (tdoublep): this whole thing with the rotary_emb is
                # weird. we shouldn't be passing it via attn_metadata imo.
                if envs.VLLM_USE_V1:
                    if isinstance(layer.self_attn, MiniMaxText01Attention):
                        attn_metadata[layer.prefix +
                                      ".attn"].rotary_emb = self.rotary_emb
                else:
                    attn_metadata.rotary_emb = self.rotary_emb

            _caches = None
            if not envs.VLLM_USE_V1 and isinstance(
                    layer.self_attn, MiniMaxText01LinearAttention):
                current_state_layer = minimax_cache_index
                _caches = minimax_cache_params.at_layer_idx(
                    current_state_layer)
                minimax_cache_index += 1
            hidden_states, residual = layer(
                hidden_states=hidden_states,
                positions=positions,
                kv_caches=_caches,
                attn_metadata=attn_metadata,
                residual=residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)

        return hidden_states

_dtype instance-attribute

_dtype = dtype

_layer_barrier instance-attribute

_layer_barrier = False

cache_shape instance-attribute

cache_shape = (
    linear_layer_nums,
    max_slots_number,
    num_attention_heads
    // get_tensor_model_parallel_world_size(),
    head_dim,
    head_dim,
)

decoder_attention_types instance-attribute

decoder_attention_types = getattr(
    config, "attn_type_list", False
) or getattr(config, "decoder_attention_types", False)

embed_scale instance-attribute

embed_scale = 1.0

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size, org_num_embeddings=vocab_size
)

minimax_cache instance-attribute

minimax_cache = MinimaxCacheManager(
    dtype=float32, cache_shape=cache_shape
)

norm instance-attribute

norm = RMSNorm(hidden_size, **norm_kwargs)

num_layers instance-attribute

num_layers = num_hidden_layers

padding_idx instance-attribute

padding_idx = pad_token_id

rotary_emb instance-attribute

rotary_emb = MiniMaxText01RotaryEmbedding(
    head_dim,
    rotary_dim=rotary_dim
    if hasattr(config, "rotary_dim")
    else head_dim,
    max_position=max_position_embeddings,
    base=int(rope_theta),
    is_neox_style=True,
    cache_dtype=float32,
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    config: MiniMaxConfig,
    model_config: Optional[ModelConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    scheduler_config=None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    config: MiniMaxConfig,
    model_config: Optional[ModelConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    scheduler_config=None,
    prefix: str = "",
) -> None:
    super().__init__()

    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.decoder_attention_types = getattr(
        config, "attn_type_list", False) or getattr(
            config, "decoder_attention_types", False)
    # The HF format uses "layer_types" instead of "attn_type_list"
    # where "linear_attention" is 0 and "full_attention" is 1
    if not self.decoder_attention_types and hasattr(config, "layer_types"):
        self.decoder_attention_types = []
        for layer_type in config.layer_types:
            if layer_type == "linear_attention":
                self.decoder_attention_types.append(0)
            elif layer_type == "full_attention":
                self.decoder_attention_types.append(1)
            else:
                raise ValueError(f"Unsupported layer type: {layer_type}")
    # Default to full attention
    if not self.decoder_attention_types:
        self.decoder_attention_types = [1] * config.num_hidden_layers
    self.num_layers = config.num_hidden_layers

    self._layer_barrier = False
    if get_pp_group().is_first_rank:
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size,
        )
    else:
        self.embed_tokens = PPMissingLayer()

    def layer_fn(prefix):
        layer_idx = int(prefix.split('.')[-1])
        layer_config = config
        layer_config.attention_type = self.decoder_attention_types[
            layer_idx]
        layer_config.layer_idx = layer_idx

        decoder_kwargs = {
            "quant_config": quant_config,
            "layer_id": layer_idx,
            "model_config": model_config,
            "cache_config": cache_config
        }

        if layer_config.attention_type == 0:
            decoder_kwargs["linear_layer_id"] = sum(
                1 for i in range(layer_idx)
                if self.decoder_attention_types[i] == 0)
        else:
            decoder_kwargs["linear_layer_id"] = None

        if hasattr(config, "num_local_experts") and isinstance(
                config.num_local_experts, list):
            decoder_kwargs["expert_num"] = config.num_local_experts[
                layer_idx]
        elif hasattr(config, "num_local_experts") and isinstance(
                config.num_local_experts, int):
            decoder_kwargs["expert_num"] = config.num_local_experts
        else:
            decoder_kwargs["expert_num"] = 1

        return MiniMaxText01DecoderLayer(layer_config,
                                         **decoder_kwargs,
                                         prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")

    linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
                            if self.decoder_attention_types[i] == 0)
    max_slots_number = scheduler_config.max_num_seqs
    self.cache_shape = (linear_layer_nums, max_slots_number,
                        config.num_attention_heads //
                        get_tensor_model_parallel_world_size(),
                        config.head_dim, config.head_dim)
    _dummy = torch.zeros(1)
    self._dtype = _dummy.dtype
    del _dummy

    if not envs.VLLM_USE_V1:
        self.minimax_cache = MinimaxCacheManager(
            dtype=torch.float32, cache_shape=self.cache_shape)

    rope_theta = getattr(config, "rope_theta", 10000)
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    if hasattr(config, "max_model_len") and isinstance(
            config.max_model_len, int):
        max_position_embeddings = min(config.max_position_embeddings,
                                      config.max_model_len)
    self.rotary_emb = MiniMaxText01RotaryEmbedding(
        head_dim,
        rotary_dim=config.rotary_dim
        if hasattr(config, "rotary_dim") else head_dim,
        max_position=max_position_embeddings,
        base=int(rope_theta),
        is_neox_style=True,
        cache_dtype=torch.float32,
    )

    norm_kwargs = {}
    if hasattr(config, "rms_norm_eps"):
        norm_kwargs["eps"] = config.rms_norm_eps
    if get_pp_group().is_last_rank:
        self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
    else:
        self.norm = PPMissingLayer()
    self.embed_scale = 1.0
    return
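
A small sketch of the attention-type bookkeeping above, assuming a hypothetical config whose HF-style layer_types alternates linear and full attention:

layer_types = ["linear_attention", "full_attention"] * 3   # hypothetical config value
decoder_attention_types = [0 if t == "linear_attention" else 1 for t in layer_types]
print(decoder_attention_types)        # [0, 1, 0, 1, 0, 1]

# linear_layer_id for layer_idx = 4 counts the preceding linear-attention layers:
layer_idx = 4
linear_layer_id = sum(1 for i in range(layer_idx) if decoder_attention_types[i] == 0)
print(linear_layer_id)                # 2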

_clear_prefill_cache

_clear_prefill_cache(
    attn_metadata, minimax_cache_tensors: Tensor, **kwargs
)
Source code in vllm/model_executor/models/minimax_text_01.py
def _clear_prefill_cache(self, attn_metadata,
                         minimax_cache_tensors: torch.Tensor, **kwargs):
    seq_to_slot_maps = {}
    seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), [])
    for _, seq_to_slot_map in (
            self.minimax_cache.cache_indices_mapping.items()):
        seq_to_slot_maps.update(seq_to_slot_map)

    slots_to_clear = []
    for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
        if _prefill_id >= len(seq_id_map):
            break
        seq_id = seq_id_map[_prefill_id]
        if attn_metadata.context_lens_tensor[
                _prefill_id] == 0 and seq_id in seq_to_slot_maps:
            slots_to_clear.append(seq_to_slot_maps[seq_id])

    if slots_to_clear:
        slots_tensor = torch.tensor(slots_to_clear,
                                    device=minimax_cache_tensors.device,
                                    dtype=torch.long)
        minimax_cache_tensors[:, slots_tensor, ...] = 0
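
The final step zeroes whole cache slots for sequences that are starting a fresh prefill, using advanced indexing on the slot dimension. A self-contained sketch (shapes and slot ids below are made up for illustration):

import torch

# Toy cache laid out as (linear_layers, slots, heads, head_dim, head_dim).
minimax_cache_tensors = torch.randn(2, 4, 1, 3, 3)
slots_to_clear = [1, 3]  # assumed slots whose sequences have no prior context

slots_tensor = torch.tensor(slots_to_clear,
                            device=minimax_cache_tensors.device,
                            dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0

assert minimax_cache_tensors[:, slots_tensor].abs().sum() == 0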

forward

forward(
    input_ids: Optional[Tensor],
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(self,
            input_ids: Optional[torch.Tensor],
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
    forward_context = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if not envs.VLLM_USE_V1 and attn_metadata is None:
        return None
    if "request_ids_to_seq_ids" not in kwargs:
        kwargs["request_ids_to_seq_ids"] = {}
    if "finished_requests_ids" not in kwargs:
        kwargs["finished_requests_ids"] = []

    if not envs.VLLM_USE_V1:
        (
            minimax_cache_tensors,
            state_indices_tensor,
        ) = self.minimax_cache.current_run_tensors(**kwargs)
        if getattr(attn_metadata, "num_prefills", 0) > 0:
            self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
                                      **kwargs)

        minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
                                                  state_indices_tensor)
    else:
        minimax_cache_params = None

    if get_pp_group().is_first_rank:
        if inputs_embeds is None:
            hidden_states = self.embed_scale * self.embed_tokens(input_ids)
        else:
            hidden_states = inputs_embeds
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    minimax_cache_index = 0

    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        if attn_metadata is not None:
            # TODO (tdoublep): this whole thing with the rotary_emb is
            # weird. we shouldn't be passing it via attn_metadata imo.
            if envs.VLLM_USE_V1:
                if isinstance(layer.self_attn, MiniMaxText01Attention):
                    attn_metadata[layer.prefix +
                                  ".attn"].rotary_emb = self.rotary_emb
            else:
                attn_metadata.rotary_emb = self.rotary_emb

        _caches = None
        if not envs.VLLM_USE_V1 and isinstance(
                layer.self_attn, MiniMaxText01LinearAttention):
            current_state_layer = minimax_cache_index
            _caches = minimax_cache_params.at_layer_idx(
                current_state_layer)
            minimax_cache_index += 1
        hidden_states, residual = layer(
            hidden_states=hidden_states,
            positions=positions,
            kv_caches=_caches,
            attn_metadata=attn_metadata,
            residual=residual,
        )
    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
    if residual is not None:
        hidden_states, _ = self.norm(hidden_states, residual)
    else:
        hidden_states = self.norm(hidden_states)

    return hidden_states
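
At runtime the same attention-type pattern drives which layers receive a slice of the MiniMax cache: the index advances only when a linear-attention layer is visited, while softmax-attention layers fall through to the regular KV cache. A plain-Python sketch of that bookkeeping (the attention-type pattern and cache stand-ins are assumed):

# Assumed pattern: 0 = linear attention, 1 = softmax attention.
decoder_attention_types = [0, 1, 0, 0, 1]
minimax_caches = ["cache_0", "cache_1", "cache_2"]  # one stand-in per linear layer

minimax_cache_index = 0
for layer_idx, attn_type in enumerate(decoder_attention_types):
    if attn_type == 0:
        _caches = minimax_caches[minimax_cache_index]
        minimax_cache_index += 1
    else:
        _caches = None  # softmax attention uses the standard KV cache path
    print(layer_idx, _caches)
# 0 cache_0
# 1 None
# 2 cache_1
# 3 cache_2
# 4 None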

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    return self.embed_tokens(input_ids)

MiniMaxText01RMSNormTP

Bases: CustomOp

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01RMSNormTP(CustomOp):
    name = "MiniMaxText01RMSNormTP"

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.weight = nn.Parameter(torch.ones(int(hidden_size /
                                                  self.tp_world)))

        self.weight.weight_loader = self.weight_loader
        self.variance_epsilon = eps
        return

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        tp_world = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        shard_size = loaded_weight.shape[0] // tp_world
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        param.data.copy_(loaded_weight[shard])
        return

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            variance = tensor_model_parallel_all_reduce(
                variance) / self.tp_world
        x = x * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.weight
        if x.size(-1) != self.weight.size(0):
            if self.weight.size(0) < x.size(-1):
                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
                full_weight = self.weight.repeat(repeat_count)
                weight = full_weight[:x.size(-1)]
            else:
                weight = self.weight[:x.size(-1)]

        x = x.to(orig_dtype) * weight
        return x

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)
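
The tensor-parallel twist is that each rank holds only hidden_size / tp_world weight entries, and the local mean-square is all-reduced and divided by tp_world; when every rank normalizes an equal-size shard of the hidden dimension, that average equals the mean-square over the full vector. A single-process sketch of that equivalence (tp_world and shard size are assumed, and the all-reduce is simulated with a plain sum):

import torch

torch.manual_seed(0)

tp_world, shard_size = 4, 8                     # assumed layout
x_full = torch.randn(tp_world * shard_size, dtype=torch.float32)
shards = x_full.chunk(tp_world)                 # what each rank would see

local_vars = [s.pow(2).mean() for s in shards]  # per-rank mean square
reduced = sum(local_vars) / tp_world            # all-reduce, then divide by tp_world

assert torch.allclose(reduced, x_full.pow(2).mean())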

name class-attribute instance-attribute

name = 'MiniMaxText01RMSNormTP'

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_world instance-attribute

tp_world = get_tensor_model_parallel_world_size()

variance_epsilon instance-attribute

variance_epsilon = eps

weight instance-attribute

weight = Parameter(ones(int(hidden_size / tp_world)))

__init__

__init__(hidden_size: int, eps: float = 1e-06) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
    super().__init__()
    self.tp_world = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.weight = nn.Parameter(torch.ones(int(hidden_size /
                                              self.tp_world)))

    self.weight.weight_loader = self.weight_loader
    self.variance_epsilon = eps
    return

_forward

_forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_text_01.py
def _forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
    if self.tp_world > 1:
        variance = tensor_model_parallel_all_reduce(
            variance) / self.tp_world
    x = x * torch.rsqrt(variance + self.variance_epsilon)

    weight = self.weight
    if x.size(-1) != self.weight.size(0):
        if self.weight.size(0) < x.size(-1):
            repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
            full_weight = self.weight.repeat(repeat_count)
            weight = full_weight[:x.size(-1)]
        else:
            weight = self.weight[:x.size(-1)]

    x = x.to(orig_dtype) * weight
    return x

forward

forward(
    x: Tensor, residual: Optional[Tensor] = None
) -> Union[Tensor, tuple[Tensor, Tensor]]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    assert residual is None, "RMSNorm does not support residual connection."
    return self._forward(x)

weight_loader staticmethod

weight_loader(
    param: Parameter, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
@staticmethod
def weight_loader(
    param: nn.Parameter,
    loaded_weight: torch.Tensor,
) -> None:
    tp_world = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    shard_size = loaded_weight.shape[0] // tp_world
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    param.data.copy_(loaded_weight[shard])
    return
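
Each rank therefore copies a contiguous, equal-size slice of the full weight vector. A small sketch of the resulting layout (the weight values and TP size are assumed):

import torch

loaded_weight = torch.arange(12, dtype=torch.float32)  # assumed full norm weight
tp_world = 4

for tp_rank in range(tp_world):
    shard_size = loaded_weight.shape[0] // tp_world
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    print(tp_rank, loaded_weight[shard].tolist())
# 0 [0.0, 1.0, 2.0]
# 1 [3.0, 4.0, 5.0]
# 2 [6.0, 7.0, 8.0]
# 3 [9.0, 10.0, 11.0]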

MiniMaxText01RotaryEmbedding

Bases: CustomOp

Source code in vllm/model_executor/models/minimax_text_01.py
class MiniMaxText01RotaryEmbedding(CustomOp):
    name = "MiniMaxText01RotaryEmbedding"

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position: int,
        base: float,
        is_neox_style: bool,
        cache_dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position
        self.base = base
        self.is_neox_style = is_neox_style
        self.cache_dtype = cache_dtype
        cache = self._compute_cos_sin_cache().to(cache_dtype)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        from vllm import _custom_ops as ops
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        query_cast = query.to(self.cache_dtype)
        key_cast = key.to(self.cache_dtype)
        ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
        query = query_cast.to(query.dtype)
        key = key_cast.to(key.dtype)
        return query, key
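
The cached table concatenates the cosine and sine halves along the last dimension, so it has shape (max_position, rotary_dim): each half carries rotary_dim // 2 frequencies. A minimal reproduction of that shape with assumed (small) sizes:

import torch

rotary_dim, max_position, base = 8, 16, 10000.0  # assumed illustrative sizes

inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float)
                         / rotary_dim))                # (rotary_dim // 2,)
t = torch.arange(max_position, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)         # (max_position, rotary_dim // 2)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # (max_position, rotary_dim)
print(cache.shape)  # torch.Size([16, 8])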

base instance-attribute

base = base

cache_dtype instance-attribute

cache_dtype = cache_dtype

head_size instance-attribute

head_size = head_size

is_neox_style instance-attribute

is_neox_style = is_neox_style

max_position_embeddings instance-attribute

max_position_embeddings = max_position

name class-attribute instance-attribute

name = 'MiniMaxText01RotaryEmbedding'

rotary_dim instance-attribute

rotary_dim = rotary_dim

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool,
    cache_dtype: dtype,
) -> None
Source code in vllm/model_executor/models/minimax_text_01.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool,
    cache_dtype: torch.dtype,
) -> None:
    super().__init__()
    self.head_size = head_size
    self.rotary_dim = rotary_dim
    self.max_position_embeddings = max_position
    self.base = base
    self.is_neox_style = is_neox_style
    self.cache_dtype = cache_dtype
    cache = self._compute_cos_sin_cache().to(cache_dtype)
    self.register_buffer("cos_sin_cache", cache, persistent=False)

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor

Compute the cos and sin cache.

Source code in vllm/model_executor/models/minimax_text_01.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Compute the cos and sin cache."""
    inv_freq = self._compute_inv_freq(self.base)
    t = torch.arange(self.max_position_embeddings, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cache = torch.cat((cos, sin), dim=-1)
    return cache

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor

Compute the inverse frequency.

Source code in vllm/model_executor/models/minimax_text_01.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Compute the inverse frequency."""
    inv_freq = 1.0 / (base**(torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
    return inv_freq

forward

forward(
    positions: Tensor, query: Tensor, key: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/minimax_text_01.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    from vllm import _custom_ops as ops
    self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
    query_cast = query.to(self.cache_dtype)
    key_cast = key.to(self.cache_dtype)
    ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
                         self.cos_sin_cache, self.is_neox_style)
    query = query_cast.to(query.dtype)
    key = key_cast.to(key.dtype)
    return query, key

replace_weight_name

replace_weight_name(
    name: str,
    key: str = None,
    to: str = None,
    count: int = None,
    prefix: str = None,
) -> str
Source code in vllm/model_executor/models/minimax_text_01.py
def replace_weight_name(name: str,
                        key: str = None,
                        to: str = None,
                        count: int = None,
                        prefix: str = None) -> str:
    name = name.replace(key, to) if count is None else \
        name.replace(key, to, count)
    return name
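
Note that the prefix argument is accepted but not used in the body above. A usage sketch with made-up weight names (assuming vLLM is installed so the module is importable):

from vllm.model_executor.models.minimax_text_01 import replace_weight_name

# Hypothetical checkpoint name, purely for illustration.
name = "model.layers.0.block_sparse_moe.experts.0.w1.weight"
print(replace_weight_name(name, key="w1", to="gate_proj"))
# model.layers.0.block_sparse_moe.experts.0.gate_proj.weight

# With count, only that many occurrences are replaced.
print(replace_weight_name("a.w1.b.w1", key="w1", to="x", count=1))
# a.x.b.w1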

weight_loader_with_alias

weight_loader_with_alias(alias: str)
Source code in vllm/model_executor/models/minimax_text_01.py
def weight_loader_with_alias(alias: str):

    def wrapper(func: callable):

        def inner_func(param: torch.Tensor,
                       loaded_weight: torch.Tensor,
                       *args,
                       prefix: str = None,
                       **kwargs):
            value = func(param, loaded_weight, *args, **kwargs)
            return value

        return inner_func

    return wrapper
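
The returned wrapper forwards positional and keyword arguments to the wrapped loader but swallows the prefix keyword. A usage sketch with a dummy loader (the loader, alias string, and tensors below are made up; assumes vLLM is installed so the module is importable):

import torch

from vllm.model_executor.models.minimax_text_01 import weight_loader_with_alias

def dummy_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Stand-in for a real vLLM weight loader.
    param.copy_(loaded_weight)

wrapped = weight_loader_with_alias("model.embed_tokens")(dummy_loader)

param = torch.zeros(4)
wrapped(param, torch.ones(4), prefix="model.embed_tokens.weight")
print(param)  # tensor([1., 1., 1., 1.])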