Skip to content

vllm.model_executor.models.phi

Inference-only Phi-1.5 model compatible with HuggingFace weights.

PhiAttention

Bases: Module

Source code in vllm/model_executor/models/phi.py
class PhiAttention(nn.Module):
    """Phi self-attention block.

    Applies a fused QKV projection, rotary position embedding on the query
    and key, the attention op, and a final dense output projection. The
    attention heads are divided evenly across tensor-parallel ranks.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projection, rotary-embedding and attention sub-layers.

        Args:
            config: Model config. Reads ``num_attention_heads``,
                ``hidden_size`` and ``partial_rotary_factor``, plus
                ``rope_theta`` / ``max_position_embeddings`` when present.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted name of this module; each sub-layer receives
                ``prefix`` extended with its own attribute name.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Each tensor-parallel rank owns an equal share of the heads.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # Pass the qualified layer name to the linear layers as well, the
        # same way it is already passed to ``Attention`` below. Without it
        # these layers are created with an empty name, so any quantization
        # rule keyed on module names cannot identify them.
        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        scaling = self.head_size**-0.5
        # Only a fraction of each head's dimensions is rotated.
        rotary_dim = int(config.partial_rotary_factor *
                         (config.hidden_size // config.num_attention_heads))
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the dense projection applied to the attention
            result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output

attn instance-attribute

attn = Attention(
    num_heads,
    head_size,
    scaling,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

dense instance-attribute

dense = RowParallelLinear(
    hidden_size, hidden_size, quant_config=quant_config
)

head_size instance-attribute

head_size = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

num_heads instance-attribute

num_heads = (
    total_num_heads // tensor_model_parallel_world_size
)

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_size,
    total_num_heads,
    bias=True,
    quant_config=quant_config,
)

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_size,
    rotary_dim=rotary_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
)

total_num_heads instance-attribute

total_num_heads = num_attention_heads

__init__

__init__(
    config: PhiConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/phi.py
def __init__(self,
             config: PhiConfig,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = ""):
    """Build the projection, rotary-embedding and attention sub-layers.

    Args:
        config: Model config. Reads ``num_attention_heads``,
            ``hidden_size`` and ``partial_rotary_factor``, plus
            ``rope_theta`` / ``max_position_embeddings`` when present.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and ``Attention``.
        prefix: Dotted name of this module; each sub-layer receives
            ``prefix`` extended with its own attribute name.
    """
    super().__init__()
    self.total_num_heads = config.num_attention_heads
    self.hidden_size = config.hidden_size
    self.head_size = self.hidden_size // self.total_num_heads

    # Each tensor-parallel rank owns an equal share of the heads.
    tensor_model_parallel_world_size = (
        get_tensor_model_parallel_world_size())
    assert self.total_num_heads % tensor_model_parallel_world_size == 0
    self.num_heads = (self.total_num_heads //
                      tensor_model_parallel_world_size)

    # Pass the qualified layer name to the linear layers as well, the same
    # way it is already passed to ``Attention`` below. Without it these
    # layers are created with an empty name, so any quantization rule keyed
    # on module names cannot identify them.
    # pylint: disable=C0103
    self.qkv_proj = QKVParallelLinear(
        self.hidden_size,
        self.head_size,
        self.total_num_heads,
        bias=True,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.dense = RowParallelLinear(
        self.hidden_size,
        self.hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.dense",
    )

    scaling = self.head_size**-0.5
    # Only a fraction of each head's dimensions is rotated.
    rotary_dim = int(config.partial_rotary_factor *
                     (config.hidden_size // config.num_attention_heads))
    assert rotary_dim % 2 == 0

    # pylint: disable=C0301
    # Refer to:
    # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
    rope_theta = getattr(config, "rope_theta", 10000.0)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      2048)
    self.rotary_emb = get_rope(
        self.head_size,
        rotary_dim=rotary_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
    )
    self.attn = Attention(self.num_heads,
                          self.head_size,
                          scaling,
                          cache_config=cache_config,
                          quant_config=quant_config,
                          prefix=f"{prefix}.attn")

forward

forward(
    position_ids: Tensor, hidden_states: Tensor
) -> Tensor
Source code in vllm/model_executor/models/phi.py
def forward(
    self,
    position_ids: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Project to Q/K/V, rotate Q and K, attend, then project back.

    Args:
        position_ids: Positions handed to the rotary embedding.
        hidden_states: Input to the fused QKV projection.

    Returns:
        The dense projection of the attention output.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    # Three equal slices along the last dimension: query, key, value.
    query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
    query, key = self.rotary_emb(position_ids, query, key)
    context = self.attn(query, key, value)
    projected, _ = self.dense(context)
    return projected

PhiForCausalLM

Bases: Module, SupportsLoRA, SupportsPP

Source code in vllm/model_executor/models/phi.py
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """``PhiModel`` backbone combined with a biased language-model head."""

    # Fused module name -> the per-projection names packed into it.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the backbone, LM head and logits processor.

        Args:
            vllm_config: Source of the HF model config, quantization config
                and LoRA config.
            prefix: Dotted name prefix used when naming the backbone.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config
        lora = vllm_config.lora_config
        self.config = hf_config
        # lm_head use bias, cannot share word embeddings
        assert not hf_config.tie_word_embeddings
        self.lora_config = lora

        self.quant_config = quantization

        self.model = PhiModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            bias=True,
            quant_config=quantization,
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the backbone's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        embeddings = self.model.get_input_embeddings(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return whatever it returns."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits via the logits processor, LM head and its bias."""
        head = self.lm_head
        return self.logits_processor(head, hidden_states, sampling_metadata,
                                     head.bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` with ``AutoWeightsLoader`` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)

config instance-attribute

config = config

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size,
    hidden_size,
    bias=True,
    quant_config=quant_config,
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

lora_config instance-attribute

lora_config = lora_config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = PhiModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"]
}

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/phi.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Create the backbone, LM head and logits processor.

    Args:
        vllm_config: Source of the HF model config, quantization config and
            LoRA config.
        prefix: Dotted name prefix used when naming the backbone.
    """
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    quantization = vllm_config.quant_config
    lora = vllm_config.lora_config
    self.config = hf_config
    # lm_head use bias, cannot share word embeddings
    assert not hf_config.tie_word_embeddings
    self.lora_config = lora

    self.quant_config = quantization

    self.model = PhiModel(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "model"),
    )

    self.lm_head = ParallelLMHead(
        hf_config.vocab_size,
        hf_config.hidden_size,
        bias=True,
        quant_config=quantization,
    )
    self.logits_processor = LogitsProcessor(hf_config.vocab_size)
    # Re-export the backbone's factory under the same attribute name.
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]
Source code in vllm/model_executor/models/phi.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Compute logits via the logits processor, LM head and its bias.

    Args:
        hidden_states: Passed through to the logits processor.
        sampling_metadata: Passed through to the logits processor.

    Returns:
        Whatever ``self.logits_processor`` returns.
    """
    head = self.lm_head
    return self.logits_processor(head, hidden_states, sampling_metadata,
                                 head.bias)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/phi.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Delegate to the backbone and return its result unchanged.

    All four arguments are forwarded positionally to ``self.model``.
    """
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds)

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/phi.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the backbone's embeddings for ``input_ids``."""
    embeddings = self.model.get_input_embeddings(input_ids)
    return embeddings

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/phi.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load ``weights`` with ``AutoWeightsLoader`` and return its result.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.
    """
    return AutoWeightsLoader(self).load_weights(weights)

PhiLayer

Bases: Module

Source code in vllm/model_executor/models/phi.py
class PhiLayer(nn.Module):
    """One Phi decoder layer.

    Attention and the MLP both read the same layer-normed input; their
    outputs are added together with the un-normalised layer input.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the layer norm, attention and MLP sub-modules.

        Args:
            config: Model config; ``hidden_size`` and ``layer_norm_eps`` are
                read here, the rest is consumed by the sub-modules.
            cache_config: Forwarded to ``PhiAttention``.
            quant_config: Forwarded to ``PhiAttention`` and ``PhiMLP``.
            prefix: Dotted name of this layer.
        """
        super().__init__()
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )
        self.self_attn = PhiAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``attention + mlp + residual`` for ``hidden_states``."""
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
        )
        # The MLP consumes the normed input, not the attention output.
        mlp_branch = self.mlp(normed)
        return attn_branch + mlp_branch + skip

input_layernorm instance-attribute

input_layernorm = LayerNorm(hidden_size, eps=layer_norm_eps)

mlp instance-attribute

mlp = PhiMLP(config, quant_config)

self_attn instance-attribute

self_attn = PhiAttention(
    config,
    cache_config,
    quant_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: PhiConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/phi.py
def __init__(self,
             config: PhiConfig,
             cache_config: Optional[CacheConfig] = None,
             quant_config: Optional[QuantizationConfig] = None,
             prefix: str = ""):
    """Create the layer norm, attention and MLP sub-modules.

    Args:
        config: Model config; ``hidden_size`` and ``layer_norm_eps`` are
            read here, the rest is consumed by the sub-modules.
        cache_config: Forwarded to ``PhiAttention``.
        quant_config: Forwarded to ``PhiAttention`` and ``PhiMLP``.
        prefix: Dotted name of this layer.
    """
    super().__init__()
    self.input_layernorm = nn.LayerNorm(
        config.hidden_size,
        eps=config.layer_norm_eps,
    )
    self.self_attn = PhiAttention(
        config,
        cache_config,
        quant_config,
        prefix=f"{prefix}.self_attn",
    )
    self.mlp = PhiMLP(config, quant_config)

forward

forward(
    position_ids: Tensor, hidden_states: Tensor
) -> Tensor
Source code in vllm/model_executor/models/phi.py
def forward(
    self,
    position_ids: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Return ``attention + mlp + residual`` for ``hidden_states``.

    Args:
        position_ids: Forwarded to the attention sub-module.
        hidden_states: Layer input; also used as the residual.
    """
    skip = hidden_states
    normed = self.input_layernorm(hidden_states)
    attn_branch = self.self_attn(
        position_ids=position_ids,
        hidden_states=normed,
    )
    # The MLP consumes the normed input, not the attention output.
    mlp_branch = self.mlp(normed)
    return attn_branch + mlp_branch + skip

PhiMLP

Bases: Module

Source code in vllm/model_executor/models/phi.py
class PhiMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create both linear layers and the activation.

        Args:
            config: Model config; reads ``hidden_size``, ``hidden_act`` and
                the optional ``n_inner``.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()

        # Use 4x the hidden size when ``n_inner`` is absent or None.
        inner_dim = getattr(config, "n_inner", None)
        if inner_dim is None:
            inner_dim = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        inner_dim,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(inner_dim,
                                     config.hidden_size,
                                     quant_config=quant_config)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply ``fc1``, the activation, then ``fc2``."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.fc2(activated)
        return contracted

act instance-attribute

act = get_act_fn(hidden_act)

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    hidden_size, n_inner, quant_config=quant_config
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    n_inner, hidden_size, quant_config=quant_config
)

__init__

__init__(
    config: PhiConfig,
    quant_config: Optional[QuantizationConfig] = None,
)
Source code in vllm/model_executor/models/phi.py
def __init__(self,
             config: PhiConfig,
             quant_config: Optional[QuantizationConfig] = None):
    """Create both linear layers and the activation.

    Args:
        config: Model config; reads ``hidden_size``, ``hidden_act`` and the
            optional ``n_inner``.
        quant_config: Forwarded to both linear layers.
    """
    super().__init__()

    # Use 4x the hidden size when ``n_inner`` is absent or None.
    inner_dim = getattr(config, "n_inner", None)
    if inner_dim is None:
        inner_dim = 4 * config.hidden_size

    self.fc1 = ColumnParallelLinear(config.hidden_size,
                                    inner_dim,
                                    quant_config=quant_config)
    self.fc2 = RowParallelLinear(inner_dim,
                                 config.hidden_size,
                                 quant_config=quant_config)
    self.act = get_act_fn(config.hidden_act)

forward

forward(hidden_states)
Source code in vllm/model_executor/models/phi.py
def forward(self, hidden_states):
    """Apply ``fc1``, the activation, then ``fc2``.

    The second element returned by each linear layer is discarded.
    """
    expanded, _ = self.fc1(hidden_states)
    activated = self.act(expanded)
    contracted, _ = self.fc2(activated)
    return contracted

PhiModel

Bases: Module

Source code in vllm/model_executor/models/phi.py
@support_torch_compile
class PhiModel(nn.Module):
    """Phi backbone: token embedding, a stack of ``PhiLayer`` modules and a
    final layer norm.

    ``forward`` runs only the ``[start_layer, end_layer)`` slice of the layer
    stack and branches on the pipeline-parallel group's first/last rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final layer norm.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Dotted name of this module; layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # make_layers returns the layer range for this module together with
        # the layer container; the lambda builds one layer for a given name.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded only on the first pipeline rank
                and only when ``inputs_embeds`` is None.
            positions: Passed to every layer.
            intermediate_tensors: Must be provided on non-first ranks; its
                ``"hidden_states"`` entry is the starting activation there.
            inputs_embeds: Optional embeddings used instead of ``input_ids``
                on the first rank.

        Returns:
            On the last pipeline rank, the hidden states after the final
            layer norm; otherwise an ``IntermediateTensors`` holding the
            un-normalised hidden states under ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are loaded as
        shards of the fused ``qkv_proj`` parameter; every other tensor is
        loaded into the parameter of the same name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The (possibly rewritten) names of the tensors that were loaded.

        Raises:
            KeyError: If a name that is not skipped has no entry in
                ``named_parameters()``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Rotary inverse-frequency entries are never loaded.
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the two `continue`s below advance this inner mapping
                # loop, not the outer weights loop. `name` has already been
                # rewritten at that point, and "qkv_proj" itself contains
                # "v_proj", so a later mapping can match and rewrite it
                # again. If nothing breaks, the `else` branch repeats the
                # same skip checks on the rewritten name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # presumably: the parameter lives on another pipeline rank
                # (is_pp_missing_parameter is defined elsewhere).
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when the mapping loop finished without `break`.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                # Parameters without a custom loader use the default one.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size
)

final_layernorm instance-attribute

final_layernorm = LayerNorm(hidden_size, eps=layer_norm_eps)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], hidden_size
    )
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/phi.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the embedding, decoder layers and final layer norm.

    Args:
        vllm_config: Source of the HF model config, cache config and
            quantization config.
        prefix: Dotted name of this module; layers are created under
            ``f"{prefix}.layers"``.
    """
    super().__init__()

    hf_config = vllm_config.model_config.hf_config
    kv_cache_config = vllm_config.cache_config
    quantization = vllm_config.quant_config

    self.config = hf_config
    self.quant_config = quantization
    self.embed_tokens = VocabParallelEmbedding(
        hf_config.vocab_size,
        hf_config.hidden_size,
    )

    def build_layer(prefix):
        # The parameter keeps the name ``prefix`` in case the factory is
        # invoked with a keyword argument.
        return PhiLayer(hf_config,
                        kv_cache_config,
                        quantization,
                        prefix=prefix)

    self.start_layer, self.end_layer, self.layers = make_layers(
        hf_config.num_hidden_layers,
        build_layer,
        prefix=f"{prefix}.layers",
    )
    self.final_layernorm = nn.LayerNorm(
        hf_config.hidden_size,
        eps=hf_config.layer_norm_eps,
    )
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(
            ["hidden_states"],
            hf_config.hidden_size,
        ))

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/phi.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this rank's slice of the decoder stack.

    Args:
        input_ids: Token ids; embedded only on the first pipeline rank and
            only when ``inputs_embeds`` is None.
        positions: Passed to every layer.
        intermediate_tensors: Must be provided on non-first ranks; its
            ``"hidden_states"`` entry is the starting activation there.
        inputs_embeds: Optional embeddings used instead of ``input_ids`` on
            the first rank.

    Returns:
        On the last pipeline rank, the hidden states after the final layer
        norm; otherwise an ``IntermediateTensors`` holding the un-normalised
        hidden states under ``"hidden_states"``.
    """
    # Pick the starting activation for this rank.
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        activations = intermediate_tensors["hidden_states"]
    elif inputs_embeds is not None:
        activations = inputs_embeds
    else:
        activations = self.get_input_embeddings(input_ids)

    for decoder_layer in self.layers[self.start_layer:self.end_layer]:
        activations = decoder_layer(positions, activations)

    # Only the last rank applies the final layer norm.
    if get_pp_group().is_last_rank:
        return self.final_layernorm(activations)
    return IntermediateTensors({"hidden_states": activations})

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/phi.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the token embeddings for ``input_ids``."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/phi.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Copy checkpoint tensors into this module's parameters.

    Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are loaded as
    shards of the fused ``qkv_proj`` parameter; every other tensor is loaded
    into the parameter of the same name.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The (possibly rewritten) names of the tensors that were loaded.

    Raises:
        KeyError: If a name that is not skipped has no entry in
            ``named_parameters()``.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v")
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        # Rotary inverse-frequency entries are never loaded.
        if "rotary_emb.inv_freq" in name:
            continue

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # NOTE: the two `continue`s below advance this inner mapping
            # loop, not the outer weights loop. `name` has already been
            # rewritten at that point, and "qkv_proj" itself contains
            # "v_proj", so a later mapping can match and rewrite it again.
            # If nothing breaks, the `else` branch repeats the same skip
            # checks on the rewritten name.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # presumably: the parameter lives on another pipeline rank
            # (is_pp_missing_parameter is defined elsewhere).
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Reached when the mapping loop finished without `break`.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # pylint: disable=E1136

            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            # Parameters without a custom loader use the default one.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params