
Pooling Models

vLLM also supports pooling models, such as embedding, classification and reward models.

In vLLM, pooling models implement the VllmModelForPooling interface. These models use a Pooler to extract the final hidden states of the input before returning them.

Note

We currently support pooling models primarily as a matter of convenience. This is not guaranteed to have any performance improvement over using HF Transformers / Sentence Transformers directly.

We are now planning to optimize pooling models in vLLM. Please comment on Issue #21796 if you have any suggestions!

Configuration

Model Runner

Run a model in pooling mode via the option --runner pooling.

Tip

There is no need to set this option in the vast majority of cases as vLLM can automatically detect the model runner to use via --runner auto.

Model Conversion

vLLM can adapt models for various pooling tasks via the option --convert <type>.

If --runner pooling has been set (manually or automatically) but the model does not implement the VllmModelForPooling interface, vLLM will attempt to automatically convert the model according to the architecture names shown in the table below.

Architecture                              | --convert | Supported pooling tasks
----------------------------------------- | --------- | -----------------------
*ForTextEncoding, *EmbeddingModel, *Model | embed     | encode, embed
*For*Classification, *ClassificationModel | classify  | encode, classify, score
*ForRewardModeling, *RewardModel          | reward    | encode
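To make the table concrete, here is a minimal sketch of how architecture names could be mapped to a --convert type using shell-style wildcards. The function guess_convert_type and its rule order are hypothetical illustrations of the table above, not vLLM's actual detection code.

```python
from fnmatch import fnmatchcase
from typing import Optional

# Hypothetical rules transcribed from the conversion table above.
# More specific patterns come first: the catch-all "*Model" would
# otherwise also match "*ClassificationModel" and "*RewardModel".
CONVERT_RULES = [
    (("*ForRewardModeling", "*RewardModel"), "reward"),
    (("*For*Classification", "*ClassificationModel"), "classify"),
    (("*ForTextEncoding", "*EmbeddingModel", "*Model"), "embed"),
]


def guess_convert_type(architecture: str) -> Optional[str]:
    """Return the --convert type implied by an architecture name, or None."""
    for patterns, convert_type in CONVERT_RULES:
        if any(fnmatchcase(architecture, pattern) for pattern in patterns):
            return convert_type
    return None  # no pattern matched, e.g. a plain *ForCausalLM model


print(guess_convert_type("Qwen2ForSequenceClassification"))  # classify
print(guess_convert_type("LlamaModel"))                      # embed
print(guess_convert_type("LlamaForCausalLM"))                # None
```

Checking the reward and classification patterns before the embed patterns is the key design choice: "*Model" matches almost any name ending in "Model", so it only works as a last resort.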

Tip

You can explicitly set --convert <type> to specify how to convert the model.

Pooling Tasks

Each pooling model in vLLM supports one or more of these tasks according to Pooler.get_supported_tasks, enabling the corresponding APIs:

Task     | APIs
-------- | -------------------------------
encode   | LLM.reward(...)
embed    | LLM.embed(...), LLM.score(...)*
classify | LLM.classify(...)
score    | LLM.score(...)

* The LLM.score(...) API falls back to the embed task if the model does not support the score task.
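The task table and its fallback rule can be read as a small dispatch function. The sketch below is hypothetical (resolve_task and API_TO_TASK are not part of vLLM); it only restates the table: each API needs one task, and LLM.score(...) uses embed when score is unavailable.

```python
from typing import AbstractSet

# Hypothetical mapping from LLM methods to the pooling task each one needs,
# transcribed from the table above; not vLLM's actual dispatch code.
API_TO_TASK = {
    "reward": "encode",
    "embed": "embed",
    "classify": "classify",
    "score": "score",
}


def resolve_task(api: str, supported_tasks: AbstractSet[str]) -> str:
    """Pick the pooling task that LLM.<api>(...) would run on a model."""
    task = API_TO_TASK[api]
    if task in supported_tasks:
        return task
    # LLM.score(...) falls back to the embed task (cosine similarity).
    if api == "score" and "embed" in supported_tasks:
        return "embed"
    raise ValueError(f"LLM.{api}(...) is not supported by this model")


# Task sets taken from the conversion table: an embedding model vs. a classifier.
print(resolve_task("score", {"encode", "embed"}))              # embed
print(resolve_task("score", {"encode", "classify", "score"}))  # score
```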

Pooler Configuration

Predefined models

If the Pooler defined by the model accepts pooler_config, you can override some of its attributes via the --override-pooler-config option.

Converted models

If the model has been converted via --convert (see above), the pooler assigned to each task has the following attributes by default:

Task     | Pooling Type | Normalization | Softmax
-------- | ------------ | ------------- | -------
reward   | ALL          | ❌            | ❌
embed    | LAST         | ✅︎            | ❌
classify | LAST         | ❌            | ✅︎
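As a rough illustration of these defaults, the sketch below applies ALL and LAST pooling to toy per-token hidden states, followed by L2 normalization (embed) or softmax (classify). It is plain Python written for this page, not vLLM's Pooler, and it assumes a classification head has already turned the pooled vector into logits.

```python
import math
from typing import List

Vector = List[float]


def pool(hidden_states: List[Vector], pooling_type: str):
    """Pool the per-token hidden states of one prompt (one row per token)."""
    if pooling_type == "LAST":
        return hidden_states[-1]  # a single vector: the last token's state
    if pooling_type == "ALL":
        return hidden_states  # every token's state, as in the reward default
    raise ValueError(f"unsupported pooling type: {pooling_type}")


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit length (the embed default)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def softmax(logits: Vector) -> Vector:
    """Turn logits into probabilities (the classify default)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


hidden = [[1.0, 2.0], [3.0, 4.0]]  # toy hidden states for a 2-token prompt
print(l2_normalize(pool(hidden, "LAST")))  # [0.6, 0.8]
print(softmax([2.0, 1.0, 0.1]))  # hypothetical classifier logits -> probabilities
```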

When loading Sentence Transformers models, their Sentence Transformers configuration file (modules.json) takes priority over the model's defaults.

You can further customize this via the --override-pooler-config option, which takes priority over both the model's and Sentence Transformers' defaults.
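The precedence rules can be pictured as layered dictionary updates in which later layers win. This is a hypothetical sketch: resolve_pooler_config is not a vLLM function, and the keys pooling_type and normalize are used only as illustrative attribute names.

```python
from typing import Any, Dict, Optional

Config = Dict[str, Any]


def resolve_pooler_config(model_defaults: Config,
                          st_config: Optional[Config] = None,
                          overrides: Optional[Config] = None) -> Config:
    """Merge model defaults < Sentence Transformers config < overrides."""
    resolved = dict(model_defaults)
    for layer in (st_config, overrides):  # later layers take priority
        if layer:
            # Treat None as "not set" so it never clobbers a lower layer.
            resolved.update({k: v for k, v in layer.items() if v is not None})
    return resolved


defaults = {"pooling_type": "LAST", "normalize": True}
st = {"pooling_type": "MEAN"}  # e.g. read from modules.json
override = {"normalize": False}  # e.g. from --override-pooler-config
print(resolve_pooler_config(defaults, st, override))
# {'pooling_type': 'MEAN', 'normalize': False}
```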

Offline Inference

The LLM class provides various methods for offline inference. See configuration for a list of options when initializing the model.

LLM.embed

The embed method outputs an embedding vector for each prompt. It is primarily designed for embedding models.

from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")

A code example can be found here: examples/offline_inference/basic/embed.py

LLM.classify

The classify method outputs a probability vector for each prompt. It is primarily designed for classification models.

from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
(output,) = llm.classify("Hello, my name is")

probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")

A code example can be found here: examples/offline_inference/basic/classify.py

LLM.score

The score method outputs similarity scores between sentence pairs. It is designed for embedding models and cross-encoder models. Embedding models score a pair by the cosine similarity of the two embeddings, while cross-encoder models process both texts of a pair together and serve as rerankers of candidate query-document pairs in RAG systems.

Note

vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. To handle RAG at a higher level, you should use integration frameworks such as LangChain.

from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", runner="pooling")
(output,) = llm.score("What is the capital of France?",
                      "The capital of Brazil is Brasilia.")

score = output.outputs.score
print(f"Score: {score}")

A code example can be found here: examples/offline_inference/basic/score.py
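For embedding models, the score is the cosine similarity between the two embeddings. The sketch below computes it in plain Python, assuming you already have two vectors, for example output.outputs.embedding from LLM.embed as shown earlier. cosine_similarity is a helper written for this page, not a vLLM API.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two embedding vectors, in [-1, 1]."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


print(cosine_similarity([1.0, 0.0], [1.0, 1.0]))  # about 0.7071
```

If both embeddings are already L2-normalized (the embed default), the denominator is 1 and the score reduces to a dot product.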

LLM.reward

The reward method is available to all reward models in vLLM. It returns the extracted hidden states directly.

from vllm import LLM

llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
(output,) = llm.reward("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")

A code example can be found here: examples/offline_inference/basic/reward.py

LLM.encode

The encode method is available to all pooling models in vLLM. It returns the extracted hidden states directly.

Note

Please use one of the more specific methods or set the task directly when using LLM.encode:

  • For embeddings, use LLM.embed(...) or pooling_task="embed".
  • For classification logits, use LLM.classify(...) or pooling_task="classify".
  • For rewards, use LLM.reward(...) or pooling_task="reward".
  • For similarity scores, use LLM.score(...).

from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="embed")

data = output.outputs.data
print(f"Data: {data!r}")

Online Serving

Our OpenAI-Compatible Server provides endpoints that correspond to the offline APIs:

  • Pooling API is similar to LLM.encode, being applicable to all types of pooling models.
  • Embeddings API is similar to LLM.embed, accepting both text and multi-modal inputs for embedding models.
  • Classification API is similar to LLM.classify and is applicable to sequence classification models.
  • Score API is similar to LLM.score for cross-encoder models.

Matryoshka Embeddings

Matryoshka Embeddings or Matryoshka Representation Learning (MRL) is a technique used in training embedding models. It allows users to trade off between performance and cost.
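The idea behind MRL is that the leading dimensions of an embedding carry most of its information, so a shorter embedding can be taken as a prefix of the full one. Below is a minimal sketch of that general technique, assuming the truncated vector is re-normalized to unit length so cosine scores remain comparable. truncate_embedding is hypothetical and not necessarily how vLLM applies the dimensions parameter internally.

```python
import math
from typing import List


def truncate_embedding(embedding: List[float], dimensions: int) -> List[float]:
    """Keep the first `dimensions` values of an MRL embedding, then L2-normalize."""
    if not 0 < dimensions <= len(embedding):
        raise ValueError("dimensions must be between 1 and the full embedding size")
    prefix = embedding[:dimensions]  # MRL puts the most information up front
    norm = math.sqrt(sum(x * x for x in prefix))
    return [x / norm for x in prefix] if norm > 0 else prefix


print(truncate_embedding([3.0, 4.0, 12.0], 2))  # [0.6, 0.8]
```

Truncating a model that was not trained with MRL follows the same arithmetic but discards information the model spread across all dimensions, which is why vLLM rejects such requests.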

Warning

Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the dimensions parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings.

For example, setting the dimensions parameter while using the BAAI/bge-m3 model will result in the following error.

{"object":"error","message":"Model \"BAAI/bge-m3\" does not support matryoshka representation, changing output dimensions will lead to poor results.","type":"BadRequestError","param":null,"code":400}

Manually enable Matryoshka Embeddings

There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if is_matryoshka is True in config.json, the output can be changed to arbitrary dimensions. Setting matryoshka_dimensions restricts the output to the listed dimensions.

For models that support Matryoshka Embeddings but are not recognized by vLLM, manually override the config using hf_overrides={"is_matryoshka": True} or hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]} (offline), or --hf-overrides '{"is_matryoshka": true}' or --hf-overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}' (online).
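The rules above can be summarized as a validation step. This hypothetical check_dimensions function is not vLLM's code; it assumes, as the serve example below suggests, that setting matryoshka_dimensions alone also marks a model as supporting Matryoshka Embeddings.

```python
from typing import Optional, Sequence


def check_dimensions(requested: Optional[int],
                     is_matryoshka: bool = False,
                     matryoshka_dimensions: Optional[Sequence[int]] = None) -> None:
    """Hypothetical check of a `dimensions` request; raises ValueError if invalid."""
    if requested is None:
        return  # no dimensions parameter: the full embedding is returned
    # Assumption: either config key enables Matryoshka support.
    supports_mrl = is_matryoshka or bool(matryoshka_dimensions)
    if not supports_mrl:
        raise ValueError("model does not support matryoshka representation")
    if matryoshka_dimensions and requested not in matryoshka_dimensions:
        raise ValueError(
            f"dimensions={requested} is not one of {list(matryoshka_dimensions)}")


check_dimensions(256, matryoshka_dimensions=[256])  # allowed: listed dimension
check_dimensions(32, is_matryoshka=True)  # allowed: any dimension
```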

Here is an example to serve a model with Matryoshka Embeddings enabled.

vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}'

Offline Inference

You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in PoolingParams.

from vllm import LLM, PoolingParams

llm = LLM(model="jinaai/jina-embeddings-v3",
          runner="pooling",
          trust_remote_code=True)
outputs = llm.embed(["Follow the white rabbit."],
                    pooling_params=PoolingParams(dimensions=32))
print(outputs[0].outputs)

A code example can be found here: examples/offline_inference/embed_matryoshka_fy.py

Online Inference

Use the following command to start the vLLM server.

vllm serve jinaai/jina-embeddings-v3 --trust-remote-code

You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter.

curl http://127.0.0.1:8000/v1/embeddings \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "input": "Follow the white rabbit.",
    "model": "jinaai/jina-embeddings-v3",
    "encoding_format": "float",
    "dimensions": 32
  }'

Expected output:

{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}

An OpenAI client example can be found here: examples/online_serving/openai_embedding_matryoshka_fy.py
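If you prefer not to install the OpenAI client, the curl request above can also be reproduced with the Python standard library. The sketch below only builds the request (build_embeddings_request is a hypothetical helper); actually sending it assumes the jinaai/jina-embeddings-v3 server started above is listening on 127.0.0.1:8000.

```python
import json
import urllib.request


def build_embeddings_request(text: str, dimensions: int,
                             model: str = "jinaai/jina-embeddings-v3",
                             base_url: str = "http://127.0.0.1:8000",
                             ) -> urllib.request.Request:
    """Build (but do not send) the same POST request as the curl command above."""
    payload = {
        "input": text,
        "model": model,
        "encoding_format": "float",
        "dimensions": dimensions,  # requested Matryoshka output size
    }
    return urllib.request.Request(
        f"{base_url}/v1/embeddings",
        data=json.dumps(payload).encode("utf-8"),
        headers={"accept": "application/json", "Content-Type": "application/json"},
        method="POST",
    )


req = build_embeddings_request("Follow the white rabbit.", dimensions=32)

# With the server running, send the request and read the vector:
# with urllib.request.urlopen(req) as resp:
#     embedding = json.load(resp)["data"][0]["embedding"]
#     print(len(embedding))
```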