Skip to content

vllm.model_executor.layers.mamba.mamba_utils

MambaStateDtypeCalculator

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
class MambaStateDtypeCalculator:
    """Resolve the torch dtypes of the cached states of mamba-like layers.

    Every method returns a tuple of ``torch.dtype`` with one entry per cached
    state (e.g. ``(conv_state_dtype, temporal_state_dtype)`` for mamba2).
    """

    @classmethod
    def linear_attention_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return the dtype of the single linear-attention state.

        The dtype is resolved by ``get_kv_cache_torch_dtype`` from the cache
        dtype setting and the model dtype.

        Raises:
            ValueError: if ``mamba_cache_dtype`` is ``"float32"``, which is
                not yet supported for this layer type.
        """
        # TODO (tdoublep) requires testing
        if mamba_cache_dtype == "float32":
            raise ValueError("fp32 state for minimax is not yet supported")
        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (state_dtype, )

    @classmethod
    def mamba1_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return ``(conv_state_dtype, temporal_state_dtype)`` for mamba1.

        Uses the same resolution rules as :meth:`mamba2_state_dtype`.

        Raises:
            ValueError: if either cache dtype is ``"float32"``; fp32 state
                is not yet supported by the mamba1 kernels.
        """
        # TODO (tdoublep) requires kernel changes
        if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
            raise ValueError("fp32 state for mamba1 is not yet supported")
        # Dispatch through ``cls`` (not the hard-coded class name) so that a
        # subclass overriding mamba2_state_dtype is honored here as well.
        return cls.mamba2_state_dtype(model_dtype, mamba_cache_dtype,
                                      mamba_ssm_cache_dtype)

    @classmethod
    def mamba2_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return ``(conv_state_dtype, temporal_state_dtype)`` for mamba2.

        The conv-state dtype comes from ``get_kv_cache_torch_dtype``. The
        temporal (SSM) state reuses it when ``mamba_ssm_cache_dtype`` is
        ``"auto"``. Otherwise the name is looked up in
        ``STR_DTYPE_TO_TORCH_DTYPE``, which raises ``KeyError`` for an
        unknown name (presumably validated upstream).
        """
        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
                                                    model_dtype)
        if mamba_ssm_cache_dtype == "auto":
            temporal_state_dtype = conv_state_dtype
        else:
            temporal_state_dtype = (
                STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])

        return (conv_state_dtype, temporal_state_dtype)

    @classmethod
    def short_conv_state_dtype(
        cls,
        model_dtype: Union[ModelDType, torch.dtype],
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return a 1-tuple holding the dtype of the short-conv state."""
        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
                                                    model_dtype)
        return (conv_state_dtype, )

linear_attention_state_dtype classmethod

linear_attention_state_dtype(
    model_dtype: Union[ModelDType, dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[dtype, ...]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def linear_attention_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
    """Return a 1-tuple with the dtype of the linear-attention state.

    Raises:
        ValueError: when ``mamba_cache_dtype`` is ``"float32"`` (fp32 state
            is not supported for this layer type yet).
    """
    # TODO (tdoublep) requires testing
    wants_fp32 = mamba_cache_dtype == "float32"
    if wants_fp32:
        raise ValueError("fp32 state for minimax is not yet supported")
    return (get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype), )

mamba1_state_dtype classmethod

mamba1_state_dtype(
    model_dtype: Union[ModelDType, dtype],
    mamba_cache_dtype: MambaDType,
    mamba_ssm_cache_dtype: MambaDType,
) -> tuple[dtype, ...]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def mamba1_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
    mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
    """Return ``(conv_state_dtype, temporal_state_dtype)`` for mamba1.

    Uses the same resolution rules as ``mamba2_state_dtype``.

    Raises:
        ValueError: if either cache dtype is ``"float32"``; fp32 state is
            not yet supported by the mamba1 kernels.
    """
    # TODO (tdoublep) requires kernel changes
    if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
        raise ValueError("fp32 state for mamba1 is not yet supported")
    # Dispatch through ``cls`` rather than a hard-coded class name so that a
    # subclass overriding mamba2_state_dtype is honored.
    return cls.mamba2_state_dtype(model_dtype, mamba_cache_dtype,
                                  mamba_ssm_cache_dtype)

mamba2_state_dtype classmethod

mamba2_state_dtype(
    model_dtype: Union[ModelDType, dtype],
    mamba_cache_dtype: MambaDType,
    mamba_ssm_cache_dtype: MambaDType,
) -> tuple[dtype, ...]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def mamba2_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
    mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
    """Return ``(conv_state_dtype, temporal_state_dtype)`` for mamba2.

    The temporal (SSM) state shares the conv-state dtype when
    ``mamba_ssm_cache_dtype`` is ``"auto"``. Otherwise the name is mapped
    through ``STR_DTYPE_TO_TORCH_DTYPE``.
    """
    conv_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
    ssm_dtype = (conv_dtype if mamba_ssm_cache_dtype == "auto" else
                 STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
    return conv_dtype, ssm_dtype

short_conv_state_dtype classmethod

short_conv_state_dtype(
    model_dtype: Union[ModelDType, dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[dtype, ...]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def short_conv_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
    """Return a 1-tuple with the dtype of the short-conv state."""
    return (get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype), )

MambaStateShapeCalculator

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
class MambaStateShapeCalculator:
    """Compute the per-rank shapes of the cached states of mamba-like layers.

    Each method returns one shape tuple per cached state, already divided
    across ``tp_world_size`` / ``tp_size`` tensor-parallel ranks.
    """

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """Return ``((num_heads // tp_size, head_dim, head_dim),)``.

        Unlike the other shapes here, the head count is floor-divided and
        divisibility by ``tp_size`` is not checked.
        """
        state_shape = (num_heads // tp_size, head_dim, head_dim)
        return (state_shape, )

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for mamba1.

        Both states are sharded over ``intermediate_size``. With ``use_v1``
        the conv state is laid out as ``(conv_kernel - 1, dim)``; otherwise
        as ``(dim, conv_kernel - 1)``.
        """
        conv_state_shape = (divide(intermediate_size,
                                   tp_world_size), conv_kernel - 1)

        temporal_state_shape = (divide(intermediate_size,
                                       tp_world_size), state_size)

        # In V0, the conv_state shape was swapped during allocation in
        # MambaCacheManager, but in V1 it needs to be determined here at the
        # calculation level
        if use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        return conv_state_shape, temporal_state_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for mamba2."""
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        n_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size)
        # heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * n_groups * state_size

        # contiguous along 'dim' axis
        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
        if not use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        # The temporal (SSM) state is sharded along the head axis: each rank
        # keeps num_heads / tp_world_size heads of (head_dim, state_size),
        #   e.g. unsharded (n_heads, head_dim, state_size) = (128, 64, 128)
        temporal_state_shape = (divide(num_heads,
                                       tp_world_size), head_dim, state_size)
        return conv_state_shape, temporal_state_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
        use_v1: bool = True,
    ) -> tuple[tuple[int, int]]:
        """Return a 1-tuple with the per-rank short-conv state shape."""
        conv_dim = divide(intermediate_size, tp_world_size)
        conv_state_shape = (conv_kernel - 1, conv_dim)
        if not use_v1:
            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
        return (conv_state_shape, )

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards.

        Returns 0 when ``ngroups`` is divisible by ``tp_size``. Otherwise
        returns ``tp_size - ngroups``, i.e. the groups are padded up to
        ``tp_size`` (for ``ngroups == 1``, one group replicated per rank).

        Raises:
            ValueError: if ``ngroups`` exceeds ``tp_size`` without being
                divisible by it. Padding would then need a negative number
                of extra groups, which would silently drop groups.
        """
        # in the case ngroups % tp_size == 0, no extra groups are needed
        if ngroups % tp_size == 0:
            return 0

        if ngroups > tp_size:
            raise ValueError(
                f"n_groups ({ngroups}) must be divisible by the tensor "
                f"parallel size ({tp_size}) when it exceeds it")

        # for n_groups == 1, this is exactly tp_size - n_groups
        return tp_size - ngroups

extra_groups_for_head_shards classmethod

extra_groups_for_head_shards(ngroups: int, tp_size: int)

Compute how many extra groups must be added (by replication) so that every head shard is accompanied by the groups it needs.

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
    """Compute the increase in group numbers to account for
    replication in order to accompany the head shards.

    Returns 0 when ``ngroups`` is divisible by ``tp_size``. Otherwise returns
    ``tp_size - ngroups``, padding the groups up to ``tp_size`` (for
    ``ngroups == 1``, one group replicated per rank).

    Raises:
        ValueError: if ``ngroups`` exceeds ``tp_size`` without being
            divisible by it. Padding would then need a negative number of
            extra groups, which would silently drop groups.
    """
    # in the case ngroups % tp_size == 0, no extra groups are needed
    if ngroups % tp_size == 0:
        return 0

    if ngroups > tp_size:
        raise ValueError(
            f"n_groups ({ngroups}) must be divisible by the tensor "
            f"parallel size ({tp_size}) when it exceeds it")

    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups

linear_attention_state_shape classmethod

linear_attention_state_shape(
    num_heads: int, tp_size: int, head_dim: int
) -> tuple[tuple[int, int, int], ...]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def linear_attention_state_shape(
    cls,
    num_heads: int,
    tp_size: int,
    head_dim: int,
) -> tuple[tuple[int, int, int], ...]:
    """Return a 1-tuple holding the per-rank linear-attention state shape.

    The head count is floor-divided by ``tp_size`` (divisibility is not
    checked); each head holds a ``head_dim x head_dim`` state.
    """
    heads_per_rank = num_heads // tp_size
    return ((heads_per_rank, head_dim, head_dim), )

mamba1_state_shape classmethod

mamba1_state_shape(
    tp_world_size: int,
    intermediate_size: int,
    state_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int]]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def mamba1_state_shape(
    cls,
    tp_world_size: int,
    intermediate_size: int,
    state_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int]]:
    """Return ``(conv_state_shape, temporal_state_shape)`` for mamba1.

    Both states are sharded over ``intermediate_size`` across TP ranks.
    """
    per_rank_dim = divide(intermediate_size, tp_world_size)
    ssm_shape = (per_rank_dim, state_size)
    # V0 swapped conv_state during allocation in MambaCacheManager; under V1
    # the final (kernel-first) layout has to be chosen here instead.
    if use_v1:
        conv_shape = (conv_kernel - 1, per_rank_dim)
    else:
        conv_shape = (per_rank_dim, conv_kernel - 1)
    return conv_shape, ssm_shape

mamba2_state_shape classmethod

mamba2_state_shape(
    tp_world_size: int,
    intermediate_size: int,
    n_groups: int,
    num_heads: int,
    head_dim: int,
    state_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def mamba2_state_shape(
    cls,
    tp_world_size: int,
    intermediate_size: int,
    n_groups: int,
    num_heads: int,
    head_dim: int,
    state_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
    """Return ``(conv_state_shape, temporal_state_shape)`` for mamba2."""
    # Pad the group count so every head shard carries the groups it uses
    # when n_groups is not divisible by the TP world size.
    padded_groups = n_groups + cls.extra_groups_for_head_shards(
        n_groups, tp_world_size)
    conv_dim = intermediate_size + 2 * padded_groups * state_size
    conv_per_rank = divide(conv_dim, tp_world_size)
    kernel_width = conv_kernel - 1
    # V1 keeps the conv state contiguous along the 'dim' axis.
    conv_state_shape = ((kernel_width, conv_per_rank) if use_v1 else
                        (conv_per_rank, kernel_width))
    # The temporal (SSM) state is sharded along heads: each rank keeps
    # num_heads / tp_world_size heads of (head_dim, state_size).
    temporal_state_shape = (divide(num_heads, tp_world_size), head_dim,
                            state_size)
    return conv_state_shape, temporal_state_shape

short_conv_state_shape classmethod

short_conv_state_shape(
    tp_world_size: int,
    intermediate_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int]]
Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def short_conv_state_shape(
    cls,
    tp_world_size: int,
    intermediate_size: int,
    conv_kernel: int,
    use_v1: bool = True,
) -> tuple[tuple[int, int]]:
    """Return a 1-tuple with the per-rank short-conv state shape.

    V1 uses ``(conv_kernel - 1, dim)``; otherwise ``(dim, conv_kernel - 1)``.
    """
    per_rank_dim = divide(intermediate_size, tp_world_size)
    kernel_width = conv_kernel - 1
    if use_v1:
        return ((kernel_width, per_rank_dim), )
    return ((per_rank_dim, kernel_width), )