Encoder Decoder Multimodal
Source: examples/offline_inference/encoder_decoder_multimodal.py.
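Run the script directly and pick a model family with --model-type (one of donut, florence2, mllama, whisper; default mllama); an optional --seed sets the seed passed to vllm.LLM. For example, from the vLLM repository root:

    python examples/offline_inference/encoder_decoder_multimodal.py --model-type whisper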
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
import time
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
from vllm import LLM, EngineArgs, PromptType, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompts: Sequence[PromptType]


def run_donut():
    engine_args = EngineArgs(
        model="naver-clova-ix/donut-base-finetuned-docvqa",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="float16",
        hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
    )

    # The input image size for donut-base-finetuned-docvqa is 2560 x 1920,
    # and the patch_size is 4 x 4.
    # Therefore, the initial number of patches is:
    #   Height: 1920 / 4 = 480 patches
    #   Width: 2560 / 4 = 640 patches
    # The Swin model uses a staged downsampling approach,
    # defined by the "depths": [2, 2, 14, 2] configuration.
    # Before entering stages 2, 3, and 4, a "Patch Merging" operation is
    # performed, which halves the feature map's dimensions
    # (dividing both height and width by 2).
    #   Before Stage 2: the size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320.
    #   Before Stage 3: the size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160.
    #   Before Stage 4: the size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80.
    # Because vLLM needs to fill the image features with an encoder_prompt,
    # and the encoder_prompt will have `<pad>` tokens added when tokenized,
    # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799.
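    # Putting the same arithmetic on one line:
    #   (1920 // 4 // 2**3) * (2560 // 4 // 2**3) - 1 = 60 * 80 - 1 = 4799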
    prompts = [
        {
            "encoder_prompt": {
                "prompt": "".join(["$"] * 4799),
                "multi_modal_data": {
                    "image": fetch_image(
                        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
                    )  # noqa: E501
                },
            },
            "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_florence2():
    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="Isotr0py/Florence-2-tokenizer",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # implicit prompt with task token
            "prompt": "<DETAILED_CAPTION>",
            "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
        },
        {  # explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "Describe in detail what is shown in the image.",
                "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
            },
            "decoder_prompt": "",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_mllama():
    engine_args = EngineArgs(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt
            "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
            "multi_modal_data": {
                "image": ImageAsset("stop_sign").pil_image,
            },
        },
        {  # Explicit prompt
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {
                    "image": ImageAsset("stop_sign").pil_image,
                },
            },
            "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_whisper():
    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    prompts = [
        {  # Test implicit prompt
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        {  # Test explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset("winning_call").audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<|startoftranscript|>",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


model_example_map = {
    "donut": run_donut,
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="mllama",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )
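    # e.g. for "whisper" the helper sets {"audio": 1}, so the merged limit
    # becomes {"image": 0, "video": 0, "audio": 1}.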

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    start = time.time()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    duration = time.time() - start

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


if __name__ == "__main__":
    args = parse_args()
    main(args)