
vllm.model_executor.layers.rotary_embedding.common

apply_rotary_emb_dispatch

apply_rotary_emb_dispatch(
    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
) -> Tensor

Parameters:

x (Tensor): [num_tokens, num_heads, head_size]. Required.
cos (Tensor): [num_tokens, head_size // 2]. Required.
sin (Tensor): [num_tokens, head_size // 2]. Required.
is_neox_style (bool): Whether to use the Neox-style or GPT-J-style rotary positional embeddings. Required.
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
                              sin: torch.Tensor,
                              is_neox_style: bool) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
                                not is_neox_style).squeeze(0)
    else:
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
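
A minimal usage sketch, assuming vLLM is installed and the function is imported from the module path shown above. The cos/sin caches are built here from the standard RoPE inverse-frequency formula with base 10000; that construction is illustrative and not part of this module. On CUDA platforms the call dispatches to the fused kernel, so the tensors are placed on the GPU there; otherwise the pure-PyTorch fallback apply_rotary_emb_torch (documented next) handles the same signature.

import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    apply_rotary_emb_dispatch)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

num_tokens, num_heads, head_size = 4, 8, 64
base = 10000.0  # illustrative RoPE base, not fixed by this module

# One inverse frequency per rotated pair of channels.
inv_freq = 1.0 / (base**(torch.arange(0, head_size, 2, dtype=torch.float32)
                         / head_size))
positions = torch.arange(num_tokens, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)   # [num_tokens, head_size // 2]
cos = freqs.cos().to(device=device, dtype=dtype)
sin = freqs.sin().to(device=device, dtype=dtype)

q = torch.randn(num_tokens, num_heads, head_size, device=device, dtype=dtype)
q_rot = apply_rotary_emb_dispatch(q, cos, sin, is_neox_style=True)
assert q_rot.shape == q.shape              # rotation preserves shape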

apply_rotary_emb_torch

apply_rotary_emb_torch(
    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)

rotate_gptj

rotate_gptj(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)

rotate_neox

rotate_neox(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)
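
A toy comparison of the two helpers (input values chosen only for illustration): rotate_neox negates and swaps the two halves of the last dimension, while rotate_gptj applies the same rotation to interleaved even/odd pairs.

import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    rotate_gptj, rotate_neox)

x = torch.tensor([1., 2., 3., 4.])
print(rotate_neox(x))  # tensor([-3., -4.,  1.,  2.])  halves: (x1, x2) -> (-x2, x1)
print(rotate_gptj(x))  # tensor([-2.,  1., -4.,  3.])  even/odd pairs: (x1, x2) -> (-x2, x1)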

yarn_find_correction_dim

yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_find_correction_dim(num_rotations: int,
                             dim: int,
                             base: float = 10000,
                             max_position_embeddings: int = 2048) -> float:
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))

yarn_find_correction_range

yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> tuple[int, int]
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> tuple[int, int]:
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
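
A worked example with illustrative parameters (dim, base, and max_position_embeddings are not prescribed by these helpers; the rotation counts 32 and 1 loosely follow common YaRN beta_fast/beta_slow settings). yarn_find_correction_dim inverts the RoPE wavelength formula to find the dimension index whose period completes the given number of rotations over the original context length, and yarn_find_correction_range takes the floor/ceil of two such indices and clamps them to [0, dim - 1].

from vllm.model_executor.layers.rotary_embedding.common import (
    yarn_find_correction_dim, yarn_find_correction_range)

dim, base, max_pos = 128, 10000.0, 2048  # illustrative values

print(yarn_find_correction_dim(32, dim, base, max_pos))      # ~16.1
print(yarn_find_correction_dim(1, dim, base, max_pos))       # ~40.2
print(yarn_find_correction_range(32, 1, dim, base, max_pos)) # (16, 41)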

yarn_get_mscale

yarn_get_mscale(scale: float = 1) -> float
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
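
A quick numeric check of the formula above (values follow directly from 0.1 * ln(scale) + 1.0):

from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale

print(yarn_get_mscale(1.0))   # 1.0  (scales <= 1 are left unscaled)
print(yarn_get_mscale(4.0))   # ~1.139
print(yarn_get_mscale(16.0))  # ~1.277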

yarn_linear_ramp_mask

yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: dtype
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_linear_ramp_mask(low: float, high: float, dim: int,
                          dtype: torch.dtype) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
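
A short usage sketch (dimension and bounds chosen only for illustration): the mask is 0 below low, 1 above high, and ramps linearly in between; YaRN-style schemes typically use it to blend interpolated and extrapolated rotary frequencies.

import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    yarn_linear_ramp_mask)

mask = yarn_linear_ramp_mask(low=2.0, high=6.0, dim=8, dtype=torch.float32)
print(mask)  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000])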