vllm.model_executor.models.gritlm
GritLM
Bases: LlamaForCausalLM
This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
It inherits from LlamaForCausalLM and provides a custom pooling layer.
The main difference between the pooling layer in GritLM and the one in
LlamaForCausalLM is that GritLM ignores the query instruction in the prompt
when pooling the hidden states.
Embedding prompts should be in the following format:
- With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT".
- Without instruction: "<|embed|>\nPROMPT".
Generation prompts should be in the following format:
- "<|user|>\nPROMPT\n<|assistant|>\n"
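A minimal sketch of string builders for these prompt formats, assuming each separator is a single newline character (\n), which is consistent with the "<0x0A>" newline token the pooler matches on. The helper names build_embed_prompt and build_generation_prompt are hypothetical and not part of vLLM.

```python
from typing import Optional


def build_embed_prompt(prompt: str, instruction: Optional[str] = None) -> str:
    """Format an embedding prompt (hypothetical helper).

    A missing or empty instruction falls back to the no-instruction form.
    """
    if instruction:
        return f"<|user|>\n{instruction}\n<|embed|>\n{prompt}"
    return f"<|embed|>\n{prompt}"


def build_generation_prompt(prompt: str) -> str:
    """Format a generation prompt; the model's reply follows <|assistant|>."""
    return f"<|user|>\n{prompt}\n<|assistant|>\n"
```

For example, build_embed_prompt("A cat sat on the mat.", "Represent this sentence") produces a string whose instruction part ("<|user|>\nRepresent this sentence\n<|embed|>\n") is what the pooler is meant to skip.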
Source code in vllm/model_executor/models/gritlm.py
pooler instance-attribute
pooler = DispatchPooler(
{
"encode": for_encode(pooler_config),
"embed": GritLMPooler(model_config),
}
)
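The pooler attribute maps each pooling task name ("encode", "embed") to its own pooler. The toy sketch below illustrates only that routing idea. MiniDispatchPooler, encode_all_tokens, and embed_mean are hypothetical names, not vLLM APIs. A pass-through stands in for the encode pooler, and a plain mean stands in for GritLMPooler, which additionally skips instruction tokens.

```python
from typing import Callable, Dict, List, Sequence, Set

# A pooler here is any callable from one prompt's hidden states to an output.
PoolerFn = Callable[[Sequence[Sequence[float]]], list]


class MiniDispatchPooler:
    """Toy task router illustrating dict-based dispatch (not vLLM's class)."""

    def __init__(self, poolers_by_task: Dict[str, PoolerFn]) -> None:
        self.poolers_by_task = dict(poolers_by_task)

    def supported_tasks(self) -> Set[str]:
        return set(self.poolers_by_task)

    def __call__(self, task: str, hidden_states: Sequence[Sequence[float]]):
        try:
            pooler = self.poolers_by_task[task]
        except KeyError:
            raise ValueError(f"Unsupported pooling task: {task!r}") from None
        return pooler(hidden_states)


def encode_all_tokens(hidden: Sequence[Sequence[float]]) -> List[List[float]]:
    # Stand-in for "encode": one output vector per token, unchanged.
    return [list(row) for row in hidden]


def embed_mean(hidden: Sequence[Sequence[float]]) -> List[float]:
    # Stand-in for "embed": a single vector, the mean over all tokens.
    n = len(hidden)
    return [sum(col) / n for col in zip(*hidden)]


pooler = MiniDispatchPooler({"encode": encode_all_tokens, "embed": embed_mean})
```

Keying poolers by task keeps each pooling strategy independent, so a model can override one task (here "embed") while reusing a stock pooler for the others.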
__init__
__init__(
vllm_config: VllmConfig, prefix: str = "", **kwargs
) -> None
GritLMMeanPool
Bases: Module
Like MeanPool, but averages only over the non-instruction tokens of each prompt.
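A single-prompt sketch of this pooling in NumPy (the real class operates on torch tensors), assuming the instruction occupies a contiguous prefix of instr_len tokens, as in the embedding prompt format. The function name mean_pool_without_instruction is hypothetical.

```python
import numpy as np


def mean_pool_without_instruction(hidden_states: np.ndarray, instr_len: int) -> np.ndarray:
    """Average token vectors of ONE prompt, skipping the instruction prefix.

    hidden_states: shape (seq_len, hidden_size).
    instr_len: number of leading tokens that belong to the instruction.
    """
    if not 0 <= instr_len < hidden_states.shape[0]:
        raise ValueError("instr_len must leave at least one token to pool")
    # Plain mean pooling would average every row; here the instruction
    # rows are sliced off first, so they cannot influence the embedding.
    return hidden_states[instr_len:].mean(axis=0, dtype=np.float32)
```

With instr_len = 0 this reduces to ordinary mean pooling, which is the behaviour for prompts without an instruction.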
embed_newline_pattern_ids instance-attribute
embed_pattern_ids instance-attribute
token_ids instance-attribute
token_ids = {
tok: (convert_tokens_to_ids([tok])[0])
for tok in [
"<s>",
"▁<",
"<",
"|",
"embed",
">",
"<0x0A>",
"user",
]
}
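token_ids resolves each marker fragment to its vocabulary ID once, so later matching can compare integers instead of strings. The toy sketch below uses an invented vocabulary in place of the real tokenizer. TOY_VOCAB, its IDs, the pattern helper, and the assumed token split of "<|embed|>" are all illustrative assumptions. "<0x0A>" is the byte-fallback piece that Llama's SentencePiece tokenizer uses for a newline, and "▁<" is presumably the variant of "<" that follows a space, since SentencePiece marks a preceding space with "▁".

```python
from typing import Dict, List

import numpy as np

# Invented IDs; a real Llama tokenizer assigns different ones.
TOY_VOCAB: Dict[str, int] = {
    "<s>": 1, "▁<": 2, "<": 3, "|": 4,
    "embed": 5, ">": 6, "<0x0A>": 7, "user": 8,
}


def convert_tokens_to_ids(tokens: List[str]) -> List[int]:
    """Stand-in for a tokenizer's token-to-ID lookup."""
    return [TOY_VOCAB[tok] for tok in tokens]


# Same shape as the token_ids attribute: token string -> single ID.
token_ids = {tok: convert_tokens_to_ids([tok])[0] for tok in TOY_VOCAB}


def pattern(tokens: List[str]) -> np.ndarray:
    """Turn a sequence of token strings into an integer array for matching."""
    return np.array([token_ids[tok] for tok in tokens], dtype=np.int64)


# Assumed split of "<|embed|>" when it follows a space; illustrative only.
embed_pattern_ids = pattern(["▁<", "|", "embed", "|", ">"])
```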
user_pattern_ids instance-attribute
__init__
__init__(model_config: ModelConfig)
_find_array
_find_array(
arr: ndarray,
target: ndarray,
start_idx: int = 0,
end_idx: Optional[int] = None,
) -> int
Find the first occurrence of target in arr, starting from start_idx.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arr | ndarray | The array to search within. | required |
target | ndarray | The consecutive subsequence to find. | required |
start_idx | int | The starting index to search from (inclusive). | 0 |
end_idx | Optional[int] | The index at which to stop searching (exclusive); None searches to the end of arr. | None |
Returns:
Type | Description |
---|---|
int | The index of the first occurrence of target in arr. |
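A straightforward sketch of the search described above, written as a standalone NumPy function. The name find_array is hypothetical. Returning -1 when target is absent, and treating end_idx as an exclusive bound on the start position of a match, are assumptions of this sketch.

```python
from typing import Optional

import numpy as np


def find_array(
    arr: np.ndarray,
    target: np.ndarray,
    start_idx: int = 0,
    end_idx: Optional[int] = None,
) -> int:
    """Return the first i in [start_idx, end_idx) with arr[i:i+len(target)] == target.

    Returns -1 if there is no such i (an assumption of this sketch).
    """
    if start_idx < 0:
        raise ValueError("start_idx must be non-negative")
    if len(arr) == 0 or len(target) == 0:
        raise ValueError("arr and target must be non-empty")
    if end_idx is None:
        end_idx = len(arr)
    # A match starting at i needs i + len(target) <= len(arr).
    last_start = min(end_idx, len(arr) - len(target) + 1)
    for i in range(start_idx, last_start):
        if np.array_equal(arr[i : i + len(target)], target):
            return i
    return -1
```

The linear scan is O(len(arr) * len(target)), which is cheap for the short marker patterns involved here.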
_get_instruction_len
Get the length of the instruction in the prompt.
We use pattern matching to locate the instruction in the prompt and then return its length.
The matching is done on integers rather than strings because the prompt is given as a list of token IDs.
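One plausible reading of this description: find an embed-marker pattern in the token IDs and count every token up to and including it as instruction. get_instruction_len, the compact find_array helper, and the marker IDs in the example are hypothetical. The real method may treat the BOS token, the <|user|> prefix, or a missing marker differently. In this sketch, a missing marker simply yields 0.

```python
import numpy as np


def find_array(arr: np.ndarray, target: np.ndarray, start_idx: int = 0) -> int:
    """First index >= start_idx where target occurs in arr, or -1."""
    n = len(target)
    for i in range(start_idx, len(arr) - n + 1):
        if np.array_equal(arr[i : i + n], target):
            return i
    return -1


def get_instruction_len(prompt_token_ids: np.ndarray, embed_pattern_ids: np.ndarray) -> int:
    """Number of leading tokens to exclude from pooling (instruction + marker)."""
    idx = find_array(prompt_token_ids, embed_pattern_ids)
    if idx == -1:
        # Sketch's choice: no marker found, so nothing counts as instruction.
        return 0
    # Everything up to and including the marker is treated as instruction.
    return idx + len(embed_pattern_ids)
```

The returned length can be passed directly as the instr_len of a mean-pooling step that skips leading tokens.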
forward
forward(
hidden_states: Union[Tensor, list[Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[Tensor], Tensor]
forward_all
forward_all(
hidden_states: Tensor,
prompt_lens: Tensor,
instr_lens: Tensor,
) -> Union[list[Tensor], Tensor]
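forward_all takes a single hidden_states tensor plus per-prompt prompt_lens and instr_lens, which suggests that the token rows of all prompts are concatenated along the first axis. The NumPy sketch below assumes that layout, with prompts in order, and uses the hypothetical name pool_concatenated.

```python
from typing import List, Sequence

import numpy as np


def pool_concatenated(
    hidden_states: np.ndarray,
    prompt_lens: Sequence[int],
    instr_lens: Sequence[int],
) -> List[np.ndarray]:
    """Mean-pool each prompt of a flat (total_tokens, hidden_size) batch,
    skipping that prompt's leading instruction tokens."""
    if len(prompt_lens) != len(instr_lens):
        raise ValueError("prompt_lens and instr_lens must have the same length")
    if sum(prompt_lens) != hidden_states.shape[0]:
        raise ValueError("prompt_lens must sum to the number of token rows")
    pooled: List[np.ndarray] = []
    offset = 0
    for prompt_len, instr_len in zip(prompt_lens, instr_lens):
        if not 0 <= instr_len < prompt_len:
            raise ValueError("each instruction must leave at least one token")
        # Rows [offset, offset + prompt_len) belong to this prompt.
        rows = hidden_states[offset + instr_len : offset + prompt_len]
        pooled.append(rows.mean(axis=0, dtype=np.float32))
        offset += prompt_len
    return pooled
```

Slicing with a running offset avoids padding the batch, at the cost of a Python loop over prompts.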
forward_one
forward_one(
hidden_states: Tensor,
prompt_len: Optional[Tensor] = None,
instr_len: Optional[Tensor] = None,
) -> Tensor
get_pooling_updates
get_pooling_updates(
task: PoolingTask,
) -> PoolingParamsUpdate
get_supported_tasks
get_supported_tasks() -> Set[PoolingTask]
GritLMPooler
Bases: Pooler
__init__
__init__(model_config: ModelConfig)
forward
forward(
hidden_states: Tensor, pooling_metadata: PoolingMetadata
) -> PoolerOutput
get_pooling_updates
get_pooling_updates(
task: PoolingTask,
) -> PoolingParamsUpdate
get_supported_tasks
get_supported_tasks() -> Set[PoolingTask]