# Long Text Embedding with Chunked Processing

Source: `examples/online_serving/openai_embedding_long_text`.

This directory contains examples of using vLLM's chunked processing feature to embed long texts that exceed the model's maximum context length.
## 🚀 Quick Start

### Start the Server
Use the provided script to start a vLLM server with chunked processing enabled:

```bash
# Basic usage (supports very long texts up to ~3M tokens)
./service.sh

# Custom configuration with different models
MODEL_NAME="jinaai/jina-embeddings-v3" \
MAX_EMBED_LEN=1048576 \
./service.sh

# For extremely long documents
MODEL_NAME="intfloat/multilingual-e5-large" \
MAX_EMBED_LEN=3072000 \
./service.sh
```
### Test Long Text Embedding

Run the comprehensive test client, `client.py` (shown in full under Example materials below), against the running server.
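
If you just want a quick manual check first, the snippet below sends a single long input through the OpenAI-compatible endpoint. It assumes the defaults used by `service.sh` and `client.py` in this directory (port `31090`, the `your-api-key` placeholder, and the `multilingual-e5-large` served model name); adjust them if you changed the configuration.

```python
from openai import OpenAI

client = OpenAI(api_key="your-api-key", base_url="http://localhost:31090/v1")

# An input far longer than a typical context window; chunking happens transparently on the server.
long_text = "vLLM chunked processing keeps embeddings usable for very long documents. " * 2000

response = client.embeddings.create(model="multilingual-e5-large", input=long_text)
print(len(response.data[0].embedding))  # dimensionality matches short-text embeddings
```
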
## 📁 Files

| File | Description |
|------|-------------|
| `service.sh` | Server startup script with chunked processing enabled |
| `client.py` | Comprehensive test client for long text embedding |
## ⚙️ Configuration

### Server Configuration

The key parameters for chunked processing are set via `--override-pooler-config`:

```json
{
  "pooling_type": "auto",
  "normalize": true,
  "enable_chunked_processing": true,
  "max_embed_len": 3072000
}
```
> **Note:** `pooling_type` sets the model's own pooling strategy for processing within each chunk. Cross-chunk aggregation automatically uses the MEAN strategy when the input exceeds the model's native maximum length.
### Chunked Processing Behavior

Chunked processing uses MEAN aggregation for cross-chunk combination when input exceeds the model's native maximum length:

| Component | Behavior | Description |
|-----------|----------|-------------|
| Within chunks | Model's native pooling | Uses the model's configured pooling strategy |
| Cross-chunk aggregation | Always MEAN | Weighted averaging based on chunk token counts |
| Performance | Optimal | All chunks are processed for complete semantic coverage |
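
To make the cross-chunk step concrete, here is a minimal sketch of token-count-weighted MEAN aggregation followed by normalization, matching the `"normalize": true` setting above. The function and variable names are illustrative only and are not part of vLLM's API.

```python
import numpy as np


def aggregate_chunk_embeddings(
    chunk_embeddings: list[np.ndarray], chunk_token_counts: list[int]
) -> np.ndarray:
    """Combine per-chunk embeddings with a token-count-weighted mean, then normalize."""
    weights = np.asarray(chunk_token_counts, dtype=np.float64)
    stacked = np.stack(chunk_embeddings)  # shape: (num_chunks, embedding_dim)
    pooled = (weights[:, None] * stacked).sum(axis=0) / weights.sum()
    return pooled / np.linalg.norm(pooled)


# Example: two full 4096-token chunks plus a shorter 1200-token tail chunk
chunks = [np.random.rand(1024) for _ in range(3)]
print(aggregate_chunk_embeddings(chunks, [4096, 4096, 1200])[:5])
```
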
### Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
| `PORT` | `31090` | Server port |
| `GPU_COUNT` | `1` | Number of GPUs to use |
| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) |
| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (affects only within-chunk pooling, not cross-chunk aggregation) |
| `API_KEY` | `your-api-key` | API key for authentication (matches the default in `service.sh`) |
## 🔧 How It Works

- **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without setting environment variables
- **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity (see the sketch after this list)
- **Unified Processing**: All chunks are processed separately through the model using its configured pooling strategy
- **MEAN Aggregation**: When the input exceeds the model's native length, results are combined by token-count-weighted averaging across all chunks
- **Consistent Output**: Final embeddings keep the same dimensionality as standard processing
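
As a rough illustration of the chunking step referenced above (not vLLM's internal implementation), splitting a tokenized input into context-window-sized pieces can be sketched as follows. With a 4096-token window, a 150,000-token input yields the 37 chunks shown in the sample server log under Debug Information below.

```python
def split_into_chunks(
    token_ids: list[int], max_position_embeddings: int
) -> list[list[int]]:
    """Split a token sequence into chunks no longer than the model's context window."""
    return [
        token_ids[i : i + max_position_embeddings]
        for i in range(0, len(token_ids), max_position_embeddings)
    ]


chunks = split_into_chunks(list(range(150_000)), 4096)
print(len(chunks))  # 37 chunks
print(len(chunks[-1]))  # 2544 tokens in the final, shorter chunk
```
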
### Input Length Handling

- **Within `max_embed_len`**: Input is accepted and processed (up to 3M+ tokens)
- **Exceeds `max_position_embeddings`**: Chunked processing is triggered automatically
- **Exceeds `max_embed_len`**: Input is rejected with a clear error message (the decision logic is sketched after this list)
- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN`
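
A minimal sketch of this decision logic, assuming the input has already been tokenized; the function name and the exact error text are illustrative, not vLLM internals.

```python
def route_embedding_input(
    num_tokens: int, max_position_embeddings: int, max_embed_len: int
) -> str:
    """Summarize how an input of num_tokens tokens is handled."""
    if num_tokens > max_embed_len:
        raise ValueError(
            f"Input of {num_tokens} tokens exceeds max_embed_len={max_embed_len}"
        )
    if num_tokens > max_position_embeddings:
        return "chunked processing"
    return "standard single-pass processing"


print(route_embedding_input(150_000, 4096, 3_072_000))  # chunked processing
```
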
### Extreme Long Text Support

With `MAX_EMBED_LEN=3072000`, you can process:
- Academic papers: Full research papers with references
- Legal documents: Complete contracts and legal texts
- Books: Entire chapters or small books
- Code repositories: Large codebases and documentation
## 📊 Performance Characteristics

### Chunked Processing Performance

| Aspect | Behavior | Performance |
|--------|----------|-------------|
| Chunk processing | All chunks processed with native pooling | Scales with input length |
| Cross-chunk aggregation | MEAN weighted averaging | Minimal overhead |
| Memory usage | Proportional to the number of chunks | Moderate, scalable |
| Semantic quality | Complete text coverage | Optimal for long documents |
## 🧪 Test Cases
The test client demonstrates:
- ✅ **Short text**: Normal processing (baseline)
- ✅ **Medium text**: Single-chunk processing
- ✅ **Long text**: Multi-chunk processing with aggregation
- ✅ **Very long text**: Many-chunk processing
- ✅ **Extreme long text**: Document-level processing (100K+ tokens)
- ✅ **Batch processing**: Mixed-length inputs in one request
- ✅ **Consistency**: Reproducible results across runs
## 🔍 Troubleshooting

### Common Issues

- **Chunked processing not enabled**
  - Solution: Ensure `enable_chunked_processing: true` is set in the pooler config
- **Input exceeds `max_embed_len`**
  - Solution: Increase `max_embed_len` in the pooler config or reduce the input length
- **Memory errors**
  - Solution: Reduce chunk size by adjusting the model's `max_position_embeddings`, or use fewer GPUs
- **Slow processing**
  - Expected: Long text takes more time because multiple inference calls are made (one per chunk)
### Debug Information

Server logs show chunked processing activity:

```text
INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing
INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096)
```
## 🤝 Contributing
To extend chunked processing support to other embedding models:
- Check model compatibility with the pooling architecture
- Test with various text lengths
- Validate embedding quality compared to single-chunk processing
- Submit PR with test cases and documentation updates
## 🆕 Enhanced Features

### `max_embed_len` Parameter

The new `max_embed_len` parameter provides:

- **Simplified Configuration**: No need for the `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable
- **Flexible Input Validation**: Accepts inputs longer than `max_model_len`, up to `max_embed_len`
- **Extreme Length Support**: Processes documents with millions of tokens
- **Clear Error Messages**: Better feedback when inputs exceed limits
- **Backward Compatibility**: Existing configurations continue to work
## Example materials

### client.py

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example script demonstrating long text embedding with chunked processing in vLLM.
This example shows how to use vLLM's chunked processing feature to handle text
inputs that exceed the model's maximum token length. The feature automatically
splits long text into chunks and handles different pooling types optimally.
Prerequisites:
1. Start vLLM server with chunked processing enabled:
# MEAN pooling (processes all chunks, recommended for complete coverage)
vllm serve intfloat/multilingual-e5-large \
--override-pooler-config \
'{"pooling_type": "MEAN", "normalize": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \
--served-model-name multilingual-e5-large \
--trust-remote-code \
--port 31090 \
--api-key your-api-key
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
vllm serve BAAI/bge-large-en-v1.5 \
--override-pooler-config \
'{"pooling_type": "CLS", "normalize": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \
--served-model-name bge-large-en-v1.5 \
--trust-remote-code \
--port 31090 \
--api-key your-api-key
2. Install required dependencies:
pip install openai requests
"""
import time
import numpy as np
from openai import OpenAI
# Configuration
API_KEY = "your-api-key" # Replace with your actual API key
BASE_URL = "http://localhost:31090/v1"
MODEL_NAME = "multilingual-e5-large"
def generate_long_text(base_text: str, repeat_count: int) -> str:
"""Generate long text by repeating base text."""
return base_text * repeat_count
def test_embedding_with_different_lengths():
"""Test embedding generation with different text lengths."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
# Test cases with different text lengths
test_cases = [
{
"name": "Short Text",
"text": "Hello, this is a short text for embedding.",
"expected_chunks": 1,
},
{
"name": "Medium Text",
"text": generate_long_text(
"This is a medium-length text that should fit within the "
"model's context window. " * 20,
2,
),
"expected_chunks": 1,
},
{
"name": "Long Text (2 chunks)",
"text": generate_long_text(
"This is a very long text that will exceed the model's "
"maximum context length and trigger chunked processing. " * 50,
5,
),
"expected_chunks": 2,
},
{
"name": "Very Long Text (3+ chunks)",
"text": generate_long_text(
"This text is extremely long and will definitely "
"require multiple chunks for processing. " * 100,
10,
),
"expected_chunks": 3,
},
]
print("๐งช Testing vLLM Long Text Embedding with Chunked Processing")
print("=" * 70)
for i, test_case in enumerate(test_cases, 1):
print(f"\n๐ Test {i}: {test_case['name']}")
print(f"Text length: {len(test_case['text'])} characters")
try:
start_time = time.time()
response = client.embeddings.create(
input=test_case["text"], model=MODEL_NAME, encoding_format="float"
)
end_time = time.time()
processing_time = end_time - start_time
# Extract embedding data
embedding = response.data[0].embedding
embedding_dim = len(embedding)
print("โ
Success!")
print(f" - Embedding dimension: {embedding_dim}")
print(f" - Processing time: {processing_time:.2f}s")
print(f" - Expected chunks: ~{test_case['expected_chunks']}")
print(f" - First 5 values: {embedding[:5]}")
except Exception as e:
print(f"โ Failed: {str(e)}")
def test_batch_embedding():
"""Test batch embedding with mixed-length inputs."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
print("\n๐ Testing Batch Embedding with Mixed Lengths")
print("=" * 50)
# Mix of short and long texts
batch_inputs = [
"Short text 1",
generate_long_text("Medium length text that fits in one chunk. " * 20, 1),
"Another short text",
generate_long_text("Long text requiring chunked processing. " * 100, 5),
]
try:
start_time = time.time()
response = client.embeddings.create(
input=batch_inputs, model=MODEL_NAME, encoding_format="float"
)
end_time = time.time()
processing_time = end_time - start_time
print("โ
Batch processing successful!")
print(f" - Number of inputs: {len(batch_inputs)}")
print(f" - Number of embeddings: {len(response.data)}")
print(f" - Total processing time: {processing_time:.2f}s")
print(
f" - Average time per input: {processing_time / len(batch_inputs):.2f}s"
)
for i, data in enumerate(response.data):
input_length = len(batch_inputs[i])
embedding_dim = len(data.embedding)
print(
f" - Input {i + 1}: {input_length} chars โ {embedding_dim}D embedding"
)
except Exception as e:
print(f"โ Batch processing failed: {str(e)}")
def test_multiple_long_texts_batch():
"""Test batch processing with multiple long texts to verify chunk ID uniqueness."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
print("\n๐ง Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)")
print("=" * 70)
# Create multiple distinct long texts that will all require chunking
# Note: All pooling types now use MEAN aggregation across chunks:
# - Native pooling (MEAN/CLS/LAST) is used within each chunk
# - MEAN aggregation combines results across all chunks
# - Full semantic coverage for all pooling types
long_texts = [
generate_long_text(
"First long document about artificial intelligence and machine learning. "
* 80,
6,
),
generate_long_text(
"Second long document about natural language processing and transformers. "
* 80,
6,
),
generate_long_text(
"Third long document about computer vision and neural networks. " * 80, 6
),
]
# Add some short texts to mix things up
batch_inputs = [
"Short text before long texts",
long_texts[0],
"Short text between long texts",
long_texts[1],
long_texts[2],
"Short text after long texts",
]
print("๐ Batch composition:")
for i, text in enumerate(batch_inputs):
length = len(text)
text_type = "Long (will be chunked)" if length > 5000 else "Short"
print(f" - Input {i + 1}: {length} chars ({text_type})")
try:
start_time = time.time()
response = client.embeddings.create(
input=batch_inputs, model=MODEL_NAME, encoding_format="float"
)
end_time = time.time()
processing_time = end_time - start_time
print("\nโ
Multiple long texts batch processing successful!")
print(f" - Number of inputs: {len(batch_inputs)}")
print(f" - Number of embeddings returned: {len(response.data)}")
print(f" - Total processing time: {processing_time:.2f}s")
# Verify each embedding is different (no incorrect aggregation)
embeddings = [data.embedding for data in response.data]
if len(embeddings) >= 3:
import numpy as np
# Compare embeddings of the long texts (indices 1, 3, 4)
long_embeddings = [
np.array(embeddings[1]), # First long text
np.array(embeddings[3]), # Second long text
np.array(embeddings[4]), # Third long text
]
print("\n๐ Verifying embedding uniqueness:")
for i in range(len(long_embeddings)):
for j in range(i + 1, len(long_embeddings)):
cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / (
np.linalg.norm(long_embeddings[i])
* np.linalg.norm(long_embeddings[j])
)
print(
f" - Similarity between long text {i + 1} and {j + 1}: "
f"{cosine_sim:.4f}"
)
if (
cosine_sim < 0.9
): # Different content should have lower similarity
print(" โ
Good: Embeddings are appropriately different")
else:
print(
" โ ๏ธ High similarity - may indicate chunk "
"aggregation issue"
)
print("\n๐ Per-input results:")
for i, data in enumerate(response.data):
input_length = len(batch_inputs[i])
embedding_dim = len(data.embedding)
embedding_norm = np.linalg.norm(data.embedding)
print(
f" - Input {i + 1}: {input_length} chars โ {embedding_dim}D "
f"embedding (norm: {embedding_norm:.4f})"
)
        print(
            "\n✅ This test verifies the fix for chunk ID collisions in "
            "batch processing"
        )
print(" - Before fix: Multiple long texts would have conflicting chunk IDs")
print(" - After fix: Each prompt's chunks have unique IDs with prompt index")
except Exception as e:
print(f"โ Multiple long texts batch test failed: {str(e)}")
print(" This might indicate the chunk ID collision bug is present!")
def test_embedding_consistency():
"""Test that chunked processing produces consistent results."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
print("\n๐ Testing Embedding Consistency")
print("=" * 40)
# Use the same long text multiple times
long_text = generate_long_text(
"Consistency test text for chunked processing validation. " * 50, 3
)
embeddings = []
try:
for i in range(3):
response = client.embeddings.create(
input=long_text, model=MODEL_NAME, encoding_format="float"
)
embeddings.append(response.data[0].embedding)
print(f" - Generated embedding {i + 1}")
# Check consistency (embeddings should be identical)
if len(embeddings) >= 2:
# Calculate similarity between first two embeddings
emb1 = np.array(embeddings[0])
emb2 = np.array(embeddings[1])
# Cosine similarity
cosine_sim = np.dot(emb1, emb2) / (
np.linalg.norm(emb1) * np.linalg.norm(emb2)
)
print("โ
Consistency test completed!")
print(f" - Cosine similarity between runs: {cosine_sim:.6f}")
print(" - Expected: ~1.0 (identical embeddings)")
if cosine_sim > 0.999:
print(" - โ
High consistency achieved!")
else:
print(" - โ ๏ธ Consistency may vary due to numerical precision")
except Exception as e:
print(f"โ Consistency test failed: {str(e)}")
def main():
"""Main function to run all tests."""
print("๐ vLLM Long Text Embedding Client")
print(f"๐ก Connecting to: {BASE_URL}")
print(f"๐ค Model: {MODEL_NAME}")
masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****"
print(f"๐ API Key: {masked_key}")
# Run all test cases
test_embedding_with_different_lengths()
test_batch_embedding()
test_multiple_long_texts_batch()
test_embedding_consistency()
print("\n" + "=" * 70)
print("๐ All tests completed!")
print("\n๐ก Key Features Demonstrated:")
print(" - โ
Automatic chunked processing for long text")
print(" - โ
Seamless handling of mixed-length batches")
print(" - โ
Multiple long texts in single batch (chunk ID fix)")
print(" - โ
Unified chunked processing:")
print(" โข Native pooling used within each chunk")
print(" โข MEAN aggregation across all chunks")
print(" โข Complete semantic coverage for all pooling types")
print(" - โ
Consistent embedding generation")
print(" - โ
Backward compatibility with short text")
print("\n๐ For more information, see:")
print(
" - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html"
)
print(" - Chunked Processing Guide: openai_embedding_long_text.md")
if __name__ == "__main__":
    main()
```

### service.sh

```bash
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vLLM Embedding Server with Enhanced Chunked Processing
# This script starts a vLLM server with chunked processing enabled for long text embedding.
# Now supports proper pooling type validation and model-specific configurations.
set -euo pipefail
# Configuration
MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"}
MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"}
PORT=${PORT:-31090}
GPU_COUNT=${GPU_COUNT:-1}
MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000}
API_KEY=${API_KEY:-"your-api-key"}
# Enhanced pooling configuration with model-specific defaults
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
export CUDA_VISIBLE_DEVICES=2,3,4,5
# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "๐ Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="
# Environment variables for optimization
export VLLM_WORKER_MULTIPROC_METHOD=spawn
# Function to determine optimal pooling type for known models
get_optimal_pooling_type() {
local model="$1"
case "$model" in
*"e5-"* | *"multilingual-e5"*)
echo "MEAN" # E5 series native pooling
;;
*"bge-"*)
echo "CLS" # BGE series native pooling
;;
*"gte-"*)
echo "LAST" # GTE series native pooling
;;
*"sentence-t5"* | *"st5"*)
echo "MEAN" # Sentence-T5 native pooling
;;
*"jina-embeddings"*)
echo "MEAN" # Jina embeddings native pooling
;;
*"Qwen"*"Embedding"*)
echo "LAST" # Qwen embeddings native pooling
;;
*)
echo "MEAN" # Default native pooling for unknown models
;;
esac
}
# Auto-detect pooling type if not explicitly set
if [ "$POOLING_TYPE" = "auto" ]; then
POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME")
echo "๐ Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME"
fi
# Display configuration
echo "๐ Configuration:"
echo " - Model: $MODEL_NAME"
echo " - Port: $PORT"
echo " - GPU Count: $GPU_COUNT"
echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}"
echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens"
echo " - Native Pooling Type: $POOLING_TYPE + Normalization"
echo " - Cross-chunk Aggregation: MEAN (automatic)"
echo ""
# Validate GPU availability
if command -v nvidia-smi &> /dev/null; then
gpu_count=$(nvidia-smi --list-gpus | wc -l)
echo "๐ฅ๏ธ Available GPUs: $gpu_count"
if [ "$GPU_COUNT" -gt "$gpu_count" ]; then
echo "โ ๏ธ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available"
echo " Adjusting to use $gpu_count GPUs"
GPU_COUNT=$gpu_count
fi
else
echo "โ ๏ธ Warning: nvidia-smi not found. GPU detection skipped."
fi
# Chunked processing uses unified MEAN aggregation
echo "โน๏ธ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks"
echo " - All chunks processed for complete semantic coverage"
echo " - Weighted averaging based on chunk token counts"
echo ""
echo "๐ง Starting server with enhanced chunked processing configuration..."
# Build pooler config JSON
POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}"
# Start vLLM server with enhanced chunked processing
vllm serve "$MODEL_NAME" \
--tensor-parallel-size "$GPU_COUNT" \
--enforce-eager \
--override-pooler-config "$POOLER_CONFIG" \
--served-model-name ${MODEL_CODE} \
--api-key "$API_KEY" \
--trust-remote-code \
--port "$PORT" \
--host 0.0.0.0
echo ""
echo "โ
vLLM Embedding Server started successfully!"
echo ""
echo "๐ก Server Information:"
echo " - Base URL: http://localhost:$PORT"
echo " - Model Code: ${MODEL_CODE}"
echo " - API Key: $API_KEY"
echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN"
echo ""
echo "๐งช Test the server with:"
echo " python examples/online_serving/openai_embedding_long_text_client.py"
echo ""
echo "๐ Enhanced features enabled:"
echo " โ
Intelligent native pooling type detection"
echo " โ
Unified MEAN aggregation for chunked processing"
echo " โ
Model-specific native pooling optimization"
echo " โ
Enhanced max embedding length (${MAX_EMBED_LEN} tokens)"
echo " โ
Complete semantic coverage for all pooling types"
echo " โ
OpenAI-compatible API"
echo " โ
GPU acceleration"
echo ""
echo "๐ง Advanced usage:"
echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection"
echo " - Set MAX_EMBED_LEN to adjust maximum input length"
echo " - All pooling types use MEAN aggregation across chunks"