Skip to content

vllm.model_executor.layers.rotary_embedding.llama3_rope

Llama3RotaryEmbedding

Bases: RotaryEmbedding

Source code in vllm/model_executor/layers/rotary_embedding/llama3_rope.py
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled by band.

    Each base inverse frequency ``f`` has wavelength ``2 * pi / f``.
    Relative to ``orig_max_position``, the frequencies are split into three
    bands:

    * wavelength below ``orig_max_position / high_freq_factor``: ``f`` is
      kept unchanged;
    * wavelength above ``orig_max_position / low_freq_factor``: ``f`` is
      divided by ``scaling_factor``;
    * otherwise: a linear blend ``(1 - s) * f / scaling_factor + s * f``
      with ``s = (orig_max_position / wavelength - low_freq_factor)
      / (high_freq_factor - low_freq_factor)``. When the two factors are
      equal, ``s`` is 0.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        """Store the band-rescaling parameters, then build the base embedding.

        Args:
            head_size: forwarded to the base class.
            rotary_dim: forwarded to the base class.
            max_position_embeddings: forwarded to the base class.
            base: forwarded to the base class.
            is_neox_style: forwarded to the base class.
            dtype: forwarded to the base class.
            scaling_factor: divisor applied to low-frequency components.
            low_freq_factor: sets the long-wavelength band boundary.
            high_freq_factor: sets the short-wavelength band boundary.
            orig_max_position: reference length the band boundaries scale.
        """
        # These must exist before the base constructor runs.
        # NOTE(review): presumably the base constructor ends up calling
        # ``_compute_inv_freq`` (which reads them) -- confirm against
        # RotaryEmbedding.
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the parent's inverse frequencies rescaled band by band.

        Args:
            base: forwarded unchanged to the parent implementation.

        Returns:
            A tensor broadcast from the parent's inverse frequencies, with
            each element kept, divided by ``scaling_factor``, or blended,
            depending on its wavelength band.
        """
        # NOTE(review): assumes the parent returns a floating-point tensor
        # of standard RoPE inverse frequencies -- confirm in RotaryEmbedding.
        freqs = super()._compute_inv_freq(base)
        wavelengths = 2 * math.pi / freqs

        # Wavelength cut-offs delimiting the three bands.
        short_cutoff = self.orig_max_position / self.high_freq_factor
        long_cutoff = self.orig_max_position / self.low_freq_factor

        if self.low_freq_factor == self.high_freq_factor:
            # Avoid a zero denominator; the blend then reduces to
            # freqs / scaling_factor.
            blend = 0
        else:
            blend = (self.orig_max_position / wavelengths -
                     self.low_freq_factor) / (self.high_freq_factor -
                                              self.low_freq_factor)

        keep_mask = wavelengths < short_cutoff
        stretch_mask = wavelengths > long_cutoff
        mixed = (1 - blend) * freqs / self.scaling_factor + blend * freqs
        # The short-wavelength test takes precedence, mirroring the
        # outer/inner nesting of the selection.
        not_kept = torch.where(stretch_mask, freqs / self.scaling_factor,
                               mixed)
        return torch.where(keep_mask, freqs, not_kept)

high_freq_factor instance-attribute

high_freq_factor = high_freq_factor

low_freq_factor instance-attribute

low_freq_factor = low_freq_factor

orig_max_position instance-attribute

orig_max_position = orig_max_position

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    scaling_factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/llama3_rope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    scaling_factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
) -> None:
    """Store the band-rescaling parameters, then build the base embedding.

    Args:
        head_size: forwarded to the base class.
        rotary_dim: forwarded to the base class.
        max_position_embeddings: forwarded to the base class.
        base: forwarded to the base class.
        is_neox_style: forwarded to the base class.
        dtype: forwarded to the base class.
        scaling_factor: divisor applied to low-frequency components.
        low_freq_factor: sets the long-wavelength band boundary.
        high_freq_factor: sets the short-wavelength band boundary.
        orig_max_position: reference length the band boundaries scale.
    """
    # These must exist before the base constructor runs.
    # NOTE(review): presumably the base constructor ends up calling
    # ``_compute_inv_freq`` (which reads them) -- confirm against
    # RotaryEmbedding.
    self.scaling_factor = scaling_factor
    self.low_freq_factor = low_freq_factor
    self.high_freq_factor = high_freq_factor
    self.orig_max_position = orig_max_position
    super().__init__(
        head_size,
        rotary_dim,
        max_position_embeddings,
        base,
        is_neox_style,
        dtype,
    )

_compute_inv_freq

_compute_inv_freq(base: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/llama3_rope.py
def _compute_inv_freq(self, base: float) -> torch.Tensor:
    """Return the parent's inverse frequencies rescaled band by band.

    With wavelength ``2 * pi / f`` for each base inverse frequency ``f``:
    wavelengths below ``orig_max_position / high_freq_factor`` keep ``f``.
    Those above ``orig_max_position / low_freq_factor`` use
    ``f / scaling_factor``. The rest use
    ``(1 - s) * f / scaling_factor + s * f`` with
    ``s = (orig_max_position / wavelength - low_freq_factor)
    / (high_freq_factor - low_freq_factor)``, where ``s`` is 0 when the
    two factors are equal.

    Args:
        base: forwarded unchanged to the parent implementation.

    Returns:
        A tensor broadcast from the parent's inverse frequencies.
    """
    # NOTE(review): assumes the parent returns a floating-point tensor of
    # standard RoPE inverse frequencies -- confirm in RotaryEmbedding.
    freqs = super()._compute_inv_freq(base)
    wavelengths = 2 * math.pi / freqs

    # Wavelength cut-offs delimiting the three bands.
    short_cutoff = self.orig_max_position / self.high_freq_factor
    long_cutoff = self.orig_max_position / self.low_freq_factor

    if self.low_freq_factor == self.high_freq_factor:
        # Avoid a zero denominator; the blend then reduces to
        # freqs / scaling_factor.
        blend = 0
    else:
        blend = (self.orig_max_position / wavelengths -
                 self.low_freq_factor) / (self.high_freq_factor -
                                          self.low_freq_factor)

    keep_mask = wavelengths < short_cutoff
    stretch_mask = wavelengths > long_cutoff
    mixed = (1 - blend) * freqs / self.scaling_factor + blend * freqs
    # The short-wavelength test takes precedence over the long one.
    not_kept = torch.where(stretch_mask, freqs / self.scaling_factor, mixed)
    return torch.where(keep_mask, freqs, not_kept)