
vllm.model_executor.layers.rotary_embedding.linear_scaling_rope

LinearScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with linear scaling.

It supports multiple scaling factors. Since multiple LoRA adapters may use different scaling factors, we need multiple cos/sin caches. This way, instead of running the rotary embedding kernel once per LoRA, we can serve multiple LoRAs in a single batched call.

In addition, we always keep the cos/sin cache for the default scaling factor of 1.

For example, given three scaling factors x=1, y, and z, with cos/sin embeddings [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]], [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

we construct the cos/sin cache as follows: [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], ... [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

We then use offsets to index into the cos/sin cache for the respective scaling factors.

The offset into the cache for each scaling factor can be accessed via the `scaling_factor_to_offset` property.
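
For concreteness, here is a minimal standalone sketch (toy parameters, not vLLM code) that mirrors this cache layout and the resulting offsets for two assumed scaling factors, 1.0 and 2.0:

import torch

# Assumed toy parameters, for illustration only.
rotary_dim = 8
base = 10000.0
max_position_embeddings = 16
scaling_factors = [1.0, 2.0]

inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

caches, offsets = [], []
for factor in scaling_factors:
    max_len = int(max_position_embeddings * factor)
    t = torch.arange(max_len, dtype=torch.float) / factor  # positions compressed by the factor
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    caches.append(torch.cat((freqs.cos(), freqs.sin()), dim=-1))
    offsets.append(sum(c.shape[0] for c in caches[:-1]))    # row where this factor's sub-cache starts

cache = torch.cat(caches, dim=0)
print(cache.shape)  # torch.Size([48, 8]): 16 rows for factor 1.0 + 32 rows for factor 2.0
print(offsets)      # [0, 16]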

Credits to the Reddit user /u/kaiokendev

Source code in vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may have
    different scaling factors, we need multiple cos/sin caches. In this way,
    instead of running rotary embedding kernel per lora, we can run multiple
    lora in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    Exemplary for three scaling factors x=1, y and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
        ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset to cache can be accessed via `scaling_factor_to_offset` API.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: Union[list[float], float],
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: list[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: list[int] = []
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> dict[float, int]:
        return self._scaling_factor_to_offset
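
As a usage sketch (illustrative parameters only; real values come from the model config and LoRA setup, and this assumes the module path shown above is importable), the per-factor offsets can be queried like this:

import torch
from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import (
    LinearScalingRotaryEmbedding)

rope = LinearScalingRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=4096,
    base=10000.0,
    is_neox_style=True,
    scaling_factors=[1.0, 2.0],  # e.g. base model plus one LoRA with 2x linear scaling
    dtype=torch.float16,
)

# Row where each scaling factor's cos/sin sub-cache starts in the
# concatenated cache: factor 1.0 occupies rows [0, 4096), and the
# sub-cache for factor 2.0 starts right after it.
print(rope.scaling_factor_to_offset)  # {1.0: 0, 2.0: 4096}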

_scaling_factor_to_offset instance-attribute

_scaling_factor_to_offset: dict[float, int]

scaling_factor_to_offset property

scaling_factor_to_offset: dict[float, int]

scaling_factors instance-attribute

scaling_factors: list[float] = scaling_factors

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factors: Union[list[float], float],
    dtype: dtype,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factors: Union[list[float], float],
    dtype: torch.dtype,
) -> None:
    if isinstance(scaling_factors, float):
        scaling_factors = [scaling_factors]
    self.scaling_factors: list[float] = scaling_factors  # noqa
    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style, dtype)
    # Lazy initialized.
    self._scaling_factor_to_offset: dict[float, int]

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    inv_freq = self._compute_inv_freq(self.base)
    cache_list: list[torch.Tensor] = []
    # offsets to the next cache in a tensor.
    # Each offset corresponds to the same index in scaling_factors.
    offsets: list[int] = []
    for scaling_factor in self.scaling_factors:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * scaling_factor
        t = torch.arange(max_len, dtype=torch.float)
        t = t / scaling_factor

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        if not cache_list:
            offset = 0
        else:
            last_offset = offsets[-1]
            next_max_len = cache_list[-1].shape[0]
            offset = last_offset + next_max_len
        offsets.append(offset)
        cache_list.append(cache)
    self._scaling_factor_to_offset = {
        float(scaling_factor): offsets[i]
        for i, scaling_factor in enumerate(self.scaling_factors)
    }
    assert len(self.scaling_factors) == len(offsets)
    return torch.cat(cache_list, dim=0)
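
To make the offset arithmetic concrete, here is a small trace of the loop above under assumed values (original context length 2048; scaling factors 1.0, 2.0, and 4.0):

# Assumed values for illustration only.
max_position_embeddings = 2048
scaling_factors = [1.0, 2.0, 4.0]

offsets, rows = [], 0
for factor in scaling_factors:
    offsets.append(rows)                           # this factor's sub-cache starts here
    rows += int(max_position_embeddings * factor)  # its sub-cache has max_len rows

print(offsets)  # [0, 2048, 6144]
print(rows)     # 14336 total rows in the concatenated cache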