
vllm.v1.sample.logits_processor

Modules:

Name Description
builtin
interface
state

BUILTIN_LOGITS_PROCESSORS module-attribute

LOGITSPROCS_GROUP module-attribute

LOGITSPROCS_GROUP = 'vllm.logits_processors'

STR_POOLING_REJECTS_LOGITSPROCS module-attribute

STR_POOLING_REJECTS_LOGITSPROCS = "Pooling models do not support custom logits processors."

__all__ module-attribute

__all__ = [
    "LogitsProcessor",
    "LogitBiasLogitsProcessor",
    "MinPLogitsProcessor",
    "MinTokensLogitsProcessor",
    "BatchUpdate",
    "BatchUpdateBuilder",
    "MoveDirectionality",
    "LogitsProcessors",
    "build_logitsprocs",
    "STR_POOLING_REJECTS_LOGITSPROCS",
    "LOGITSPROCS_GROUP",
]

logger module-attribute

logger = init_logger(__name__)

BatchUpdate dataclass

Persistent batch state change info for logitsprocs

Source code in vllm/v1/sample/logits_processor/interface.py
@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]

added instance-attribute

added: Sequence[AddedRequest]

batch_size instance-attribute

batch_size: int

moved instance-attribute

moved: Sequence[MovedRequest]

removed instance-attribute

removed: Sequence[RemovedRequest]

__init__

__init__(
    batch_size: int,
    removed: Sequence[RemovedRequest],
    moved: Sequence[MovedRequest],
    added: Sequence[AddedRequest],
) -> None
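
For illustration, a minimal sketch (not taken from the vLLM source) of assembling a frozen BatchUpdate. The params slot is filled with None and the third tuple element is left unused, mirroring how the builtin processors documented on this page ignore that slot:

from vllm.v1.sample.logits_processor import BatchUpdate

output_tok_ids: list[int] = []       # the request's running output-token list
update = BatchUpdate(
    batch_size=1,
    removed=[],                      # batch indices vacated this step
    moved=[],                        # (from_idx, to_idx, MoveDirectionality) tuples
    added=[(0, None, None, output_tok_ids)],
)

# Because `output_tok_ids` is stored by reference, a logits processor that
# kept it during update_state() sees tokens appended later by the sampler.
output_tok_ids.append(42)
assert update.added[0][3] == [42]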

BatchUpdateBuilder

Helps track persistent batch state changes and build a batch update data structure for logitsprocs.

Assumptions:

  • All information about requests removed from the persistent batch during a step is aggregated in self._removed through calls to self.removed_append() at the beginning of a step. This must happen before the first time that self.removed, self.pop_removed() or self.peek_removed() are invoked in a given step.
  • After the first time that self.removed, self.pop_removed() or self.peek_removed() are read in a step, no new removals are registered using self.removed_append().
  • Elements of self._removed are never directly modified, added or removed (i.e. modification is only via self.removed_append() and self.pop_removed()).

Guarantees under the above assumptions:

  • self.removed is always sorted in descending order.
  • self.pop_removed() and self.peek_removed() both return the lowest removed request index in the current step.
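
A minimal usage sketch (not from the vLLM source) illustrating those guarantees, with arbitrary example indices:

from vllm.v1.sample.logits_processor import BatchUpdateBuilder

builder = BatchUpdateBuilder()
builder.removed_append(7)
builder.removed_append(2)
builder.removed_append(5)

assert builder.peek_removed() == 2      # lowest removed index
assert builder.pop_removed() == 2       # pops the lowest index
assert builder.removed == [7, 5]        # remaining indices, descending order

# removed_append() would now raise, because self.removed has been read:
# builder.removed_append(9)             # -> RuntimeError

update = builder.get_and_reset(batch_size=4)   # frozen BatchUpdate, or None
assert update is not None and list(update.removed) == [7, 5]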

Source code in vllm/v1/sample/logits_processor/state.py
class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs
    Assumptions:
    * All information about requests removed from persistent batch
      during a step is aggregated in self._removed through calls to
      self.removed_append() at the beginning of a step. This must happen
      before the first time that self.removed, self.pop_removed()
      or self.peek_removed() are invoked in a given step
    * After the first time that self.removed, self.pop_removed()
      or self.peek_removed() are read in a step, no new removals
      are registered using self.removed_append()
    * Elements of self._removed are never directly modified, added or
      removed (i.e. modification is only via self.removed_append() and
      self.pop_removed())
    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    moved: list[MovedRequest]
    added: list[AddedRequest]

    def __init__(
        self,
        removed: Optional[list[RemovedRequest]] = None,
        moved: Optional[list[MovedRequest]] = None,
        added: Optional[list[AddedRequest]] = None,
    ) -> None:
        self._removed = removed or []
        self.moved = moved or []
        self.added = added or []
        self._is_removed_sorted = False

        # Used to track changes in the pooling case
        # where we don't populate the added list.
        self.batch_changed = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in
        descending order.
        Idempotent after first call in a
        given step, until reset.
        """
        if not self._is_removed_sorted:
            self._removed.sort(reverse=True)
            self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from the persistent batch.

        Must not be called after the first time self.removed,
        self.pop_removed() or self.peek_removed() are invoked.

        Args:
          index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError("Cannot register new removed request after"
                               " self.removed has been read.")
        self._removed.append(index)
        self.batch_changed = True

    def has_removed(self) -> bool:
        return bool(self._removed)

    def peek_removed(self) -> Optional[int]:
        """Return lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed[-1]
        return None

    def pop_removed(self) -> Optional[int]:
        """Pop lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed.pop()
        return None

    def reset(self) -> bool:
        """Returns True if there were any changes to the batch."""
        self._is_removed_sorted = False
        self._removed.clear()
        self.moved.clear()
        self.added.clear()
        batch_changed = self.batch_changed
        self.batch_changed = False
        return batch_changed

    def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
        """Generate a logitsprocs batch update data structure and reset
        internal batch update builder state.

        Args:
          batch_size: current persistent batch size

        Returns:
          Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        self.batch_changed = False
        if not any((self._removed, self.moved, self.added)):
            # No update; short-circuit
            return None
        # Build batch state update
        batch_update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        self._removed = []
        self.moved = []
        self.added = []
        return batch_update

_is_removed_sorted instance-attribute

_is_removed_sorted: bool = False

_removed instance-attribute

_removed: list[RemovedRequest] = removed or []

added instance-attribute

added: list[AddedRequest] = added or []

batch_changed instance-attribute

batch_changed = False

moved instance-attribute

moved: list[MovedRequest] = moved or []

removed property

removed: list[RemovedRequest]

Removed request indices sorted in descending order

__init__

__init__(
    removed: Optional[list[RemovedRequest]] = None,
    moved: Optional[list[MovedRequest]] = None,
    added: Optional[list[AddedRequest]] = None,
) -> None
Source code in vllm/v1/sample/logits_processor/state.py
def __init__(
    self,
    removed: Optional[list[RemovedRequest]] = None,
    moved: Optional[list[MovedRequest]] = None,
    added: Optional[list[AddedRequest]] = None,
) -> None:
    self._removed = removed or []
    self.moved = moved or []
    self.added = added or []
    self._is_removed_sorted = False

    # Used to track changes in the pooling case
    # where we don't populate the added list.
    self.batch_changed = False

_ensure_removed_sorted

_ensure_removed_sorted() -> None

Sort removed request indices in descending order. Idempotent after first call in a given step, until reset.

Source code in vllm/v1/sample/logits_processor/state.py
def _ensure_removed_sorted(self) -> None:
    """Sort removed request indices in
    descending order.
    Idempotent after first call in a
    given step, until reset.
    """
    if not self._is_removed_sorted:
        self._removed.sort(reverse=True)
        self._is_removed_sorted = True

get_and_reset

get_and_reset(batch_size: int) -> Optional[BatchUpdate]

Generate a logitsprocs batch update data structure and reset internal batch update builder state.

Parameters:

Name Type Description Default
batch_size int

current persistent batch size

required

Returns:

Type Description
Optional[BatchUpdate]

Frozen logitsprocs batch update instance; None if no updates

Source code in vllm/v1/sample/logits_processor/state.py
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
    """Generate a logitsprocs batch update data structure and reset
    internal batch update builder state.

    Args:
      batch_size: current persistent batch size

    Returns:
      Frozen logitsprocs batch update instance; `None` if no updates
    """
    # Reset removal-sorting logic
    self._is_removed_sorted = False
    self.batch_changed = False
    if not any((self._removed, self.moved, self.added)):
        # No update; short-circuit
        return None
    # Build batch state update
    batch_update = BatchUpdate(
        batch_size=batch_size,
        removed=self._removed,
        moved=self.moved,
        added=self.added,
    )
    self._removed = []
    self.moved = []
    self.added = []
    return batch_update

has_removed

has_removed() -> bool
Source code in vllm/v1/sample/logits_processor/state.py
def has_removed(self) -> bool:
    return bool(self._removed)

peek_removed

peek_removed() -> Optional[int]

Return lowest removed request index

Source code in vllm/v1/sample/logits_processor/state.py
def peek_removed(self) -> Optional[int]:
    """Return lowest removed request index"""
    if self.has_removed():
        self._ensure_removed_sorted()
        return self._removed[-1]
    return None

pop_removed

pop_removed() -> Optional[int]

Pop lowest removed request index

Source code in vllm/v1/sample/logits_processor/state.py
def pop_removed(self) -> Optional[int]:
    """Pop lowest removed request index"""
    if self.has_removed():
        self._ensure_removed_sorted()
        return self._removed.pop()
    return None

removed_append

removed_append(index: int) -> None

Register the removal of a request from the persistent batch.

Must not be called after the first time self.removed, self.pop_removed() or self.peek_removed() are invoked.

Parameters:

Name Type Description Default
index int

request index

required
Source code in vllm/v1/sample/logits_processor/state.py
def removed_append(self, index: int) -> None:
    """Register the removal of a request from the persistent batch.

    Must not be called after the first time self.removed,
    self.pop_removed() or self.peek_removed() are invoked.

    Args:
      index: request index
    """
    if self._is_removed_sorted:
        raise RuntimeError("Cannot register new removed request after"
                           " self.removed has been read.")
    self._removed.append(index)
    self.batch_changed = True

reset

reset() -> bool

Returns True if there were any changes to the batch.

Source code in vllm/v1/sample/logits_processor/state.py
def reset(self) -> bool:
    """Returns True if there were any changes to the batch."""
    self._is_removed_sorted = False
    self._removed.clear()
    self.moved.clear()
    self.added.clear()
    batch_changed = self.batch_changed
    self.batch_changed = False
    return batch_changed

LogitBiasLogitsProcessor

Bases: LogitsProcessor

Source code in vllm/v1/sample/logits_processor/builtin.py
class LogitBiasLogitsProcessor(LogitsProcessor):

    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        self.biases: dict[int, dict[int, float]] = {}

        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice = (self._device_tensor([], torch.int32),
                             self._device_tensor([], torch.int32))

    def is_argmax_invariant(self) -> bool:
        """Logit bias can rebalance token probabilities and change the
        outcome of argmax in greedy sampling."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        needs_update: bool = False
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            if lb := params.logit_bias:
                self.biases[index] = lb
                needs_update = True
            else:
                # Drop biases metadata at batch index
                if self.biases.pop(index, None) is not None:
                    # If a new request replaces an old request which
                    # specified biases, we should update processor tensors
                    needs_update = True

        if self.biases:
            # Process removed requests.
            for index in batch_update.removed:
                if self.biases.pop(index, None):
                    needs_update = True

            # Process moved requests, unidirectional (a->b) and swap (a<->b)
            for a_index, b_index, direct in batch_update.moved:
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    if (a_entry := self.biases.pop(a_index, None)) is None:
                        if self.biases.pop(b_index, None) is not None:
                            needs_update = True
                    else:
                        self.biases[b_index] = a_entry
                        needs_update = True
                else:
                    a_entry = self.biases.pop(a_index, None)
                    if (b_entry := self.biases.pop(b_index, None)) is not None:
                        self.biases[a_index] = b_entry
                        needs_update = True
                    if a_entry is not None:
                        self.biases[b_index] = a_entry
                        needs_update = True

        # Update tensors if needed.
        if needs_update:
            reqs, tok_ids, biases = [], [], []
            for req, lb in self.biases.items():
                reqs.extend([req] * len(lb))
                tok_ids.extend(lb.keys())
                biases.extend(lb.values())

            self.bias_tensor = self._device_tensor(biases, torch.float32)
            self.logits_slice = (self._device_tensor(reqs, torch.int32),
                                 self._device_tensor(tok_ids, torch.int32))

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        return (torch.tensor(data,
                             device="cpu",
                             dtype=dtype,
                             pin_memory=self.pin_memory).to(device=self.device,
                                                            non_blocking=True))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.biases:
            logits[self.logits_slice] += self.bias_tensor
        return logits

bias_tensor instance-attribute

bias_tensor: Tensor = tensor(())

biases instance-attribute

biases: dict[int, dict[int, float]] = {}

device instance-attribute

device = device

logits_slice instance-attribute

logits_slice = (
    _device_tensor([], int32),
    _device_tensor([], int32),
)

pin_memory instance-attribute

pin_memory = is_pin_memory

__init__

__init__(_, device: device, is_pin_memory: bool)
Source code in vllm/v1/sample/logits_processor/builtin.py
def __init__(self, _, device: torch.device, is_pin_memory: bool):
    self.device = device
    self.pin_memory = is_pin_memory
    self.biases: dict[int, dict[int, float]] = {}

    self.bias_tensor: torch.Tensor = torch.tensor(())
    self.logits_slice = (self._device_tensor([], torch.int32),
                         self._device_tensor([], torch.int32))

_device_tensor

_device_tensor(data: list, dtype: dtype) -> Tensor
Source code in vllm/v1/sample/logits_processor/builtin.py
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
    return (torch.tensor(data,
                         device="cpu",
                         dtype=dtype,
                         pin_memory=self.pin_memory).to(device=self.device,
                                                        non_blocking=True))

apply

apply(logits: Tensor) -> Tensor
Source code in vllm/v1/sample/logits_processor/builtin.py
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if self.biases:
        logits[self.logits_slice] += self.bias_tensor
    return logits

is_argmax_invariant

is_argmax_invariant() -> bool

Logit bias can rebalance token probabilities and change the outcome of argmax in greedy sampling.

Source code in vllm/v1/sample/logits_processor/builtin.py
def is_argmax_invariant(self) -> bool:
    """Logit bias can rebalance token probabilities and change the
    outcome of argmax in greedy sampling."""
    return False

update_state

update_state(batch_update: Optional[BatchUpdate])
Source code in vllm/v1/sample/logits_processor/builtin.py
def update_state(self, batch_update: Optional[BatchUpdate]):
    if not batch_update:
        return

    needs_update: bool = False
    # Process added requests.
    for index, params, _, _ in batch_update.added:
        if lb := params.logit_bias:
            self.biases[index] = lb
            needs_update = True
        else:
            # Drop biases metadata at batch index
            if self.biases.pop(index, None) is not None:
                # If a new request replaces an old request which
                # specified biases, we should update processor tensors
                needs_update = True

    if self.biases:
        # Process removed requests.
        for index in batch_update.removed:
            if self.biases.pop(index, None):
                needs_update = True

        # Process moved requests, unidirectional (a->b) and swap (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            if direct == MoveDirectionality.UNIDIRECTIONAL:
                if (a_entry := self.biases.pop(a_index, None)) is None:
                    if self.biases.pop(b_index, None) is not None:
                        needs_update = True
                else:
                    self.biases[b_index] = a_entry
                    needs_update = True
            else:
                a_entry = self.biases.pop(a_index, None)
                if (b_entry := self.biases.pop(b_index, None)) is not None:
                    self.biases[a_index] = b_entry
                    needs_update = True
                if a_entry is not None:
                    self.biases[b_index] = a_entry
                    needs_update = True

    # Update tensors if needed.
    if needs_update:
        reqs, tok_ids, biases = [], [], []
        for req, lb in self.biases.items():
            reqs.extend([req] * len(lb))
            tok_ids.extend(lb.keys())
            biases.extend(lb.values())

        self.bias_tensor = self._device_tensor(biases, torch.float32)
        self.logits_slice = (self._device_tensor(reqs, torch.int32),
                             self._device_tensor(tok_ids, torch.int32))
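
A hedged usage sketch on CPU. SamplingParams is replaced by a types.SimpleNamespace stand-in exposing only the logit_bias attribute this processor reads, and the third element of the added tuple is unused here:

from types import SimpleNamespace

import torch

from vllm.v1.sample.logits_processor import BatchUpdate, LogitBiasLogitsProcessor

# The first constructor argument is ignored by this processor.
proc = LogitBiasLogitsProcessor(None, device=torch.device("cpu"), is_pin_memory=False)

params = SimpleNamespace(logit_bias={3: 5.0, 7: -5.0})   # token id -> bias
proc.update_state(BatchUpdate(
    batch_size=1,
    removed=[], moved=[],
    added=[(0, params, None, [])],
))

logits = torch.zeros(1, 16)
out = proc.apply(logits)
assert out[0, 3] == 5.0 and out[0, 7] == -5.0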

LogitsProcessor

Bases: ABC

Source code in vllm/v1/sample/logits_processor/interface.py
class LogitsProcessor(ABC):

    @abstractmethod
    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional["BatchUpdate"],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError

__init__ abstractmethod

__init__(
    vllm_config: VllmConfig,
    device: device,
    is_pin_memory: bool,
) -> None
Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
             is_pin_memory: bool) -> None:
    raise NotImplementedError

apply abstractmethod

apply(logits: Tensor) -> Tensor
Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError

is_argmax_invariant abstractmethod

is_argmax_invariant() -> bool

True if logits processor has no impact on the argmax computation in greedy sampling. NOTE: may or may not have the same value for all instances of a given LogitsProcessor subclass, depending on subclass implementation.

Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def is_argmax_invariant(self) -> bool:
    """True if logits processor has no impact on the
    argmax computation in greedy sampling.
    NOTE: may or may not have the same value for all
    instances of a given LogitsProcessor subclass,
    depending on subclass implementation.
    """
    raise NotImplementedError

update_state abstractmethod

update_state(batch_update: Optional[BatchUpdate]) -> None

Called when there are new output tokens, prior to each forward pass. The batch_update argument is non-None iff there have been changes to the batch makeup.

Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def update_state(
    self,
    batch_update: Optional["BatchUpdate"],
) -> None:
    """Called when there are new output tokens, prior
    to each forward pass.

    Args:
        batch_update is non-None iff there have been
        changes to the batch makeup.
    """
    raise NotImplementedError
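
As an illustrative sketch (not part of vLLM), a minimal stateless subclass that satisfies the interface; the class name and behavior are hypothetical:

from typing import Optional

import torch

from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor

class BanTokenZeroLogitsProcessor(LogitsProcessor):
    """Illustrative processor that always masks token id 0."""

    def __init__(self, vllm_config, device: torch.device, is_pin_memory: bool) -> None:
        self.device = device

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change the argmax, so report False.
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        # Stateless: nothing to track per request.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits[:, 0] = -float("inf")
        return logits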

LogitsProcessors

Encapsulates initialized logitsproc objects.

Source code in vllm/v1/sample/logits_processor/state.py
class LogitsProcessors:
    """Encapsulates initialized logitsproc objects."""

    def __init__(
            self,
            logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None:
        self.argmax_invariant: list[LogitsProcessor] = []
        self.non_argmax_invariant: list[LogitsProcessor] = []
        if logitsprocs:
            for logitproc in logitsprocs:
                (self.argmax_invariant if logitproc.is_argmax_invariant() else
                 self.non_argmax_invariant).append(logitproc)

    @property
    def all(self) -> Iterator["LogitsProcessor"]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)

all property

all: Iterator[LogitsProcessor]

Iterator over all logits processors.

argmax_invariant instance-attribute

argmax_invariant: list[LogitsProcessor] = []

non_argmax_invariant instance-attribute

non_argmax_invariant: list[LogitsProcessor] = []

__init__

__init__(
    logitsprocs: Optional[Iterator[LogitsProcessor]] = None,
) -> None
Source code in vllm/v1/sample/logits_processor/state.py
def __init__(
        self,
        logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None:
    self.argmax_invariant: list[LogitsProcessor] = []
    self.non_argmax_invariant: list[LogitsProcessor] = []
    if logitsprocs:
        for logitproc in logitsprocs:
            (self.argmax_invariant if logitproc.is_argmax_invariant() else
             self.non_argmax_invariant).append(logitproc)
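
A small sketch of how the container partitions processors by is_argmax_invariant(); LogitBiasLogitsProcessor ignores its config argument, so None is passed purely for illustration:

import torch

from vllm.v1.sample.logits_processor import LogitBiasLogitsProcessor, LogitsProcessors

bias_proc = LogitBiasLogitsProcessor(None, torch.device("cpu"), is_pin_memory=False)

procs = LogitsProcessors(iter([bias_proc]))
assert procs.non_argmax_invariant == [bias_proc]   # logit bias can change the argmax
assert procs.argmax_invariant == []
assert list(procs.all) == [bias_proc]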

MinPLogitsProcessor

Bases: LogitsProcessor

Source code in vllm/v1/sample/logits_processor/builtin.py
class MinPLogitsProcessor(LogitsProcessor):

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.min_p_count: int = 0

        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=is_pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        self.use_double_tensor = torch.device(device).type != "cpu"

        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ),
                                                          dtype=torch.float32,
                                                          device=device)
        else:
            self.min_p_device = self.min_p_cpu_tensor
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        needs_update = False
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            min_p = params.min_p
            min_p_before = self.min_p_cpu[index]
            if min_p_before != min_p:
                needs_update = True
                self.min_p_cpu[index] = min_p
                if min_p and not min_p_before:
                    self.min_p_count += 1
                elif not min_p and min_p_before:
                    self.min_p_count -= 1

        if self.min_p_count:
            # Process removed requests.
            if batch_update.removed:
                needs_update = True
                for index in batch_update.removed:
                    if self.min_p_cpu[index]:
                        self.min_p_cpu[index] = 0
                        self.min_p_count -= 1

            # Process moved requests, unidirectional (a->b) and swap (a<->b).
            for adx, bdx, direct in batch_update.moved:
                min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
                if min_p_a != min_p_b:
                    needs_update = True
                    self.min_p_cpu[bdx] = min_p_a
                    if direct == MoveDirectionality.SWAP:
                        self.min_p_cpu[adx] = min_p_b
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    if min_p_a:
                        self.min_p_cpu[adx] = 0
                    if min_p_b:
                        self.min_p_count -= 1

        # Update tensors if needed.
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_device[:size]
            if self.use_double_tensor:
                self.min_p.copy_(self.min_p_cpu_tensor[:size],
                                 non_blocking=True)
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Adjust min_p
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify valid tokens using threshold comparison
        invalid_token_mask = probability_values < adjusted_min_p
        # Apply mask using boolean indexing
        logits[invalid_token_mask] = -float('inf')
        return logits

min_p instance-attribute

min_p: Tensor = min_p_device[:0]

min_p_count instance-attribute

min_p_count: int = 0

min_p_cpu instance-attribute

min_p_cpu = min_p_cpu_tensor.numpy()

min_p_cpu_tensor instance-attribute

min_p_cpu_tensor = zeros(
    (max_num_reqs,),
    dtype=float32,
    device="cpu",
    pin_memory=is_pin_memory,
)

min_p_device instance-attribute

min_p_device: Tensor = empty(
    (max_num_reqs,), dtype=float32, device=device
)

use_double_tensor instance-attribute

use_double_tensor = device(device).type != 'cpu'

__init__

__init__(
    vllm_config: VllmConfig,
    device: device,
    is_pin_memory: bool,
)
Source code in vllm/v1/sample/logits_processor/builtin.py
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
             is_pin_memory: bool):
    max_num_reqs = vllm_config.scheduler_config.max_num_seqs
    self.min_p_count: int = 0

    self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                        dtype=torch.float32,
                                        device="cpu",
                                        pin_memory=is_pin_memory)
    self.min_p_cpu = self.min_p_cpu_tensor.numpy()

    self.use_double_tensor = torch.device(device).type != "cpu"

    if self.use_double_tensor:
        # Pre-allocated device tensor
        self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ),
                                                      dtype=torch.float32,
                                                      device=device)
    else:
        self.min_p_device = self.min_p_cpu_tensor
    # Current slice of the device tensor
    self.min_p: torch.Tensor = self.min_p_device[:0]

apply

apply(logits: Tensor) -> Tensor
Source code in vllm/v1/sample/logits_processor/builtin.py
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if not self.min_p_count:
        return logits

    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Calculate maximum probabilities per sequence
    max_probabilities = torch.amax(probability_values,
                                   dim=-1,
                                   keepdim=True)
    # Adjust min_p
    adjusted_min_p = max_probabilities.mul_(self.min_p)
    # Identify valid tokens using threshold comparison
    invalid_token_mask = probability_values < adjusted_min_p
    # Apply mask using boolean indexing
    logits[invalid_token_mask] = -float('inf')
    return logits

get_min_p_by_index

get_min_p_by_index(index: int) -> float
Source code in vllm/v1/sample/logits_processor/builtin.py
def get_min_p_by_index(self, index: int) -> float:
    return float(self.min_p_cpu[index])

is_argmax_invariant

is_argmax_invariant() -> bool

Min-p never impacts greedy sampling

Source code in vllm/v1/sample/logits_processor/builtin.py
def is_argmax_invariant(self) -> bool:
    """Min-p never impacts greedy sampling"""
    return True

update_state

update_state(batch_update: Optional[BatchUpdate])
Source code in vllm/v1/sample/logits_processor/builtin.py
def update_state(self, batch_update: Optional[BatchUpdate]):
    if not batch_update:
        return

    needs_update = False
    # Process added requests.
    for index, params, _, _ in batch_update.added:
        min_p = params.min_p
        min_p_before = self.min_p_cpu[index]
        if min_p_before != min_p:
            needs_update = True
            self.min_p_cpu[index] = min_p
            if min_p and not min_p_before:
                self.min_p_count += 1
            elif not min_p and min_p_before:
                self.min_p_count -= 1

    if self.min_p_count:
        # Process removed requests.
        if batch_update.removed:
            needs_update = True
            for index in batch_update.removed:
                if self.min_p_cpu[index]:
                    self.min_p_cpu[index] = 0
                    self.min_p_count -= 1

        # Process moved requests, unidirectional (a->b) and swap (a<->b).
        for adx, bdx, direct in batch_update.moved:
            min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
            if min_p_a != min_p_b:
                needs_update = True
                self.min_p_cpu[bdx] = min_p_a
                if direct == MoveDirectionality.SWAP:
                    self.min_p_cpu[adx] = min_p_b
            if direct == MoveDirectionality.UNIDIRECTIONAL:
                if min_p_a:
                    self.min_p_cpu[adx] = 0
                if min_p_b:
                    self.min_p_count -= 1

    # Update tensors if needed.
    size = batch_update.batch_size
    if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
        self.min_p = self.min_p_device[:size]
        if self.use_double_tensor:
            self.min_p.copy_(self.min_p_cpu_tensor[:size],
                             non_blocking=True)
        self.min_p.unsqueeze_(1)
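
A hedged sketch of min-p thresholding on CPU; the config and params objects are SimpleNamespace stand-ins exposing only the attributes read here (scheduler_config.max_num_seqs and min_p):

from types import SimpleNamespace

import torch

from vllm.v1.sample.logits_processor import BatchUpdate, MinPLogitsProcessor

cfg = SimpleNamespace(scheduler_config=SimpleNamespace(max_num_seqs=4))
proc = MinPLogitsProcessor(cfg, device=torch.device("cpu"), is_pin_memory=False)

params = SimpleNamespace(min_p=0.5)       # keep tokens with p >= 0.5 * max(p)
proc.update_state(BatchUpdate(
    batch_size=1, removed=[], moved=[],
    added=[(0, params, None, [])],
))

logits = torch.tensor([[0.0, 1.0, 4.0]])
out = proc.apply(logits)
# Tokens whose probability falls below 0.5 * the max probability become -inf.
assert torch.isinf(out[0, 0]) and not torch.isinf(out[0, 2])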

MinTokensLogitsProcessor

Bases: LogitsProcessor

Source code in vllm/v1/sample/logits_processor/builtin.py
class MinTokensLogitsProcessor(LogitsProcessor):

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.device = device
        self.pin_memory = is_pin_memory
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # (req_idx_tensor,eos_tok_id_tensor)
        self.logits_slice: tuple[torch.Tensor,
                                 torch.Tensor] = (self._device_tensor(
                                     [], torch.int32),
                                                  self._device_tensor(
                                                      [], torch.int32))

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        needs_update = False

        if batch_update:
            # Process added requests.
            for index, params, _, output_tok_ids in batch_update.added:
                if ((min_tokens := params.min_tokens)
                        and len(output_tok_ids) < min_tokens):
                    # Replace request metadata at batch index
                    self.min_toks[index] = (min_tokens, output_tok_ids,
                                            params.all_stop_token_ids)
                    needs_update = True
                else:
                    # Drop min_toks metadata at batch index
                    if self.min_toks.pop(index, None) is not None:
                        # If a new request replaces an old request which
                        # specified min_toks, we should update processor tensors
                        needs_update = True

            if self.min_toks:
                # Process removed requests.
                for index in batch_update.removed:
                    if self.min_toks.pop(index, None):
                        needs_update = True

                # Process moved requests, unidirectional (a->b) and
                # swapped (a<->b)
                for a_index, b_index, direct in batch_update.moved:
                    if direct == MoveDirectionality.UNIDIRECTIONAL:
                        if (a_entry := self.min_toks.pop(a_index,
                                                         None)) is None:
                            if self.min_toks.pop(b_index, None) is not None:
                                needs_update = True
                        else:
                            self.min_toks[b_index] = a_entry
                            needs_update = True
                    else:
                        a_entry = self.min_toks.pop(a_index, None)
                        if (b_entry := self.min_toks.pop(b_index,
                                                         None)) is not None:
                            self.min_toks[a_index] = b_entry
                            needs_update = True
                        if a_entry is not None:
                            self.min_toks[b_index] = a_entry
                            needs_update = True

        if self.min_toks:
            # Check for any requests that have attained their min tokens.
            to_remove = tuple(index for index, (min_toks, out_tok_ids,
                                                _) in self.min_toks.items()
                              if len(out_tok_ids) >= min_toks)
            if to_remove:
                needs_update = True
                for index in to_remove:
                    del self.min_toks[index]

        # Update tensors if needed.
        if needs_update:
            reqs: list[int] = []
            tok_ids: list[int] = []
            for req, (_, _, stop_tok_ids) in self.min_toks.items():
                reqs.extend([req] * len(stop_tok_ids))
                tok_ids.extend(stop_tok_ids)

            self.logits_slice = (self._device_tensor(reqs, torch.int32),
                                 self._device_tensor(tok_ids, torch.int32))

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        return (torch.tensor(data,
                             device="cpu",
                             dtype=dtype,
                             pin_memory=self.pin_memory).to(device=self.device,
                                                            non_blocking=True))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.min_toks:
            # Inhibit EOS token for requests which have not reached min length
            logits[self.logits_slice] = -float("inf")
        return logits

device instance-attribute

device = device

logits_slice instance-attribute

logits_slice: tuple[Tensor, Tensor] = (
    _device_tensor([], int32),
    _device_tensor([], int32),
)

min_toks instance-attribute

min_toks: dict[
    int, tuple[int, Sequence[int], set[int]]
] = {}

pin_memory instance-attribute

pin_memory = is_pin_memory

__init__

__init__(
    vllm_config: VllmConfig,
    device: device,
    is_pin_memory: bool,
)
Source code in vllm/v1/sample/logits_processor/builtin.py
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
             is_pin_memory: bool):
    # index -> (min_toks, output_token_ids, stop_token_ids)
    self.device = device
    self.pin_memory = is_pin_memory
    self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

    # (req_idx_tensor,eos_tok_id_tensor)
    self.logits_slice: tuple[torch.Tensor,
                             torch.Tensor] = (self._device_tensor(
                                 [], torch.int32),
                                              self._device_tensor(
                                                  [], torch.int32))

_device_tensor

_device_tensor(data: list, dtype: dtype) -> Tensor
Source code in vllm/v1/sample/logits_processor/builtin.py
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
    return (torch.tensor(data,
                         device="cpu",
                         dtype=dtype,
                         pin_memory=self.pin_memory).to(device=self.device,
                                                        non_blocking=True))

apply

apply(logits: Tensor) -> Tensor
Source code in vllm/v1/sample/logits_processor/builtin.py
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    if self.min_toks:
        # Inhibit EOS token for requests which have not reached min length
        logits[self.logits_slice] = -float("inf")
    return logits

is_argmax_invariant

is_argmax_invariant() -> bool

By censoring stop tokens, min-tokens can change the outcome of the argmax operation in greedy sampling.

Source code in vllm/v1/sample/logits_processor/builtin.py
def is_argmax_invariant(self) -> bool:
    """By censoring stop tokens, min-tokens can change the outcome
    of the argmax operation in greedy sampling."""
    return False

update_state

update_state(batch_update: Optional[BatchUpdate])
Source code in vllm/v1/sample/logits_processor/builtin.py
def update_state(self, batch_update: Optional[BatchUpdate]):
    needs_update = False

    if batch_update:
        # Process added requests.
        for index, params, _, output_tok_ids in batch_update.added:
            if ((min_tokens := params.min_tokens)
                    and len(output_tok_ids) < min_tokens):
                # Replace request metadata at batch index
                self.min_toks[index] = (min_tokens, output_tok_ids,
                                        params.all_stop_token_ids)
                needs_update = True
            else:
                # Drop min_toks metadata at batch index
                if self.min_toks.pop(index, None) is not None:
                    # If a new request replaces an old request which
                    # specified min_toks, we should update processor tensors
                    needs_update = True

        if self.min_toks:
            # Process removed requests.
            for index in batch_update.removed:
                if self.min_toks.pop(index, None):
                    needs_update = True

            # Process moved requests, unidirectional (a->b) and
            # swapped (a<->b)
            for a_index, b_index, direct in batch_update.moved:
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    if (a_entry := self.min_toks.pop(a_index,
                                                     None)) is None:
                        if self.min_toks.pop(b_index, None) is not None:
                            needs_update = True
                    else:
                        self.min_toks[b_index] = a_entry
                        needs_update = True
                else:
                    a_entry = self.min_toks.pop(a_index, None)
                    if (b_entry := self.min_toks.pop(b_index,
                                                     None)) is not None:
                        self.min_toks[a_index] = b_entry
                        needs_update = True
                    if a_entry is not None:
                        self.min_toks[b_index] = a_entry
                        needs_update = True

    if self.min_toks:
        # Check for any requests that have attained their min tokens.
        to_remove = tuple(index for index, (min_toks, out_tok_ids,
                                            _) in self.min_toks.items()
                          if len(out_tok_ids) >= min_toks)
        if to_remove:
            needs_update = True
            for index in to_remove:
                del self.min_toks[index]

    # Update tensors if needed.
    if needs_update:
        reqs: list[int] = []
        tok_ids: list[int] = []
        for req, (_, _, stop_tok_ids) in self.min_toks.items():
            reqs.extend([req] * len(stop_tok_ids))
            tok_ids.extend(stop_tok_ids)

        self.logits_slice = (self._device_tensor(reqs, torch.int32),
                             self._device_tensor(tok_ids, torch.int32))
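
A hedged sketch showing the stop-token mask being applied and then dropped once the request reaches min_tokens. The constructor does not use vllm_config, so None is passed, and params is a SimpleNamespace stand-in exposing only min_tokens and all_stop_token_ids:

from types import SimpleNamespace

import torch

from vllm.v1.sample.logits_processor import BatchUpdate, MinTokensLogitsProcessor

proc = MinTokensLogitsProcessor(None, device=torch.device("cpu"), is_pin_memory=False)

output_tok_ids: list[int] = []            # live reference to generated tokens
params = SimpleNamespace(min_tokens=2, all_stop_token_ids={9})
proc.update_state(BatchUpdate(
    batch_size=1, removed=[], moved=[],
    added=[(0, params, None, output_tok_ids)],
))

logits = torch.zeros(1, 16)
assert torch.isinf(proc.apply(logits.clone())[0, 9])    # stop token masked

# Once the request reaches min_tokens, the next update_state() drops the mask.
output_tok_ids.extend([5, 6])
proc.update_state(None)
assert proc.apply(logits.clone())[0, 9] == 0.0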

MoveDirectionality

Bases: Enum

Source code in vllm/v1/sample/logits_processor/interface.py
class MoveDirectionality(Enum):
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()

SWAP class-attribute instance-attribute

SWAP = auto()

UNIDIRECTIONAL class-attribute instance-attribute

UNIDIRECTIONAL = auto()

_load_custom_logitsprocs

_load_custom_logitsprocs(
    logits_processors: Optional[
        Sequence[Union[str, type[LogitsProcessor]]]
    ],
) -> list[type[LogitsProcessor]]

Load all custom logits processors.

  • First, load all installed logitproc plugins
  • Second, load custom logitsprocs passed by the user at initialization time

Parameters:

Name Type Description Default
logits_processors Optional[Sequence[Union[str, type[LogitsProcessor]]]]

potentially mixed list of logitproc types and logitproc type fully-qualified names (FQCNs) which need to be loaded

required

Returns:

Type Description
list[type[LogitsProcessor]]

A list of all loaded logitproc types

Source code in vllm/v1/sample/logits_processor/__init__.py
def _load_custom_logitsprocs(
    logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]],
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    * First load all installed logitproc plugins
    * Second load custom logitsprocs pass by the user at initialization time

    Args:
      logits_processors: potentially mixed list of logitproc types and
                         logitproc type fully-qualified names (FQCNs)
                         which need to be loaded

    Returns:
      A list of all loaded logitproc types
    """
    from vllm.platforms import current_platform
    if current_platform.is_tpu():
        # No logitsprocs specified by caller
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
        return []

    return (_load_logitsprocs_plugins() +
            _load_logitsprocs_by_fqcns(logits_processors))

_load_logitsprocs_by_fqcns

_load_logitsprocs_by_fqcns(
    logits_processors: Optional[
        Sequence[Union[str, type[LogitsProcessor]]]
    ],
) -> list[type[LogitsProcessor]]

Load logit processor types, identifying them by fully-qualified class names (FQCNs).

Effectively, a mixed list of logitproc types and FQCN strings is converted into a list containing only logitproc types, by loading classes from the FQCNs.

FQCN syntax is <module>:<type>, e.g. x.y.z:CustomLogitProc

Already-loaded logitproc types must be subclasses of LogitsProcessor

Parameters:

Name Type Description Default
logits_processors Optional[Sequence[Union[str, type[LogitsProcessor]]]]

Potentially mixed list of logitsprocs types and FQCN strings for logitproc types

required

Returns:

Type Description
list[type[LogitsProcessor]]

List of logitproc types

Source code in vllm/v1/sample/logits_processor/__init__.py
def _load_logitsprocs_by_fqcns(
    logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]]
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
      logits_processors: Potentially mixed list of logitsprocs types and FQCN
                         strings for logitproc types

    Returns:
      List of logitproc types

    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.", len(logits_processors))

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            logger.debug(" - Already-loaded logit processor: %s",
                         logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        logger.debug("- Loading logits processor %s", logitproc)
        module_path, qualname = logitproc.split(":")

        try:
            # Load module
            module = importlib.import_module(module_path)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class
        obj = module
        for attr in qualname.split("."):
            obj = getattr(obj, attr)
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(
                f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes

_load_logitsprocs_plugins

_load_logitsprocs_plugins() -> list[type[LogitsProcessor]]

Load all installed logit processor plugins

Source code in vllm/v1/sample/logits_processor/__init__.py
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Load all installed logit processor plugins"""

    import sys

    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
    else:
        from importlib.metadata import entry_points

    installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
    if len(installed_logitsprocs_plugins) == 0:
        logger.debug("No logitsprocs plugins installed (group %s).",
                     LOGITSPROCS_GROUP)
        return []

    # Load logitsprocs plugins
    logger.debug("Loading installed logitsprocs plugins (group %s):",
                 LOGITSPROCS_GROUP)
    classes: list[type[LogitsProcessor]] = []
    for entrypoint in installed_logitsprocs_plugins:
        try:
            logger.debug("- Loading logitproc plugin entrypoint=%s target=%s",
                         entrypoint.name, entrypoint.value)
            classes.append(entrypoint.load())
        except Exception as e:
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {entrypoint}") from e
    return classes
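
Plugins are discovered through Python entry points in the 'vllm.logits_processors' group (a package would typically register one under that group in its packaging metadata). A sketch of the discovery step, with hypothetical names, mirroring what this function does:

# Python >= 3.10
from importlib.metadata import entry_points

for ep in entry_points(group="vllm.logits_processors"):
    # e.g. name="my_procs", value="my_pkg.module:MyProc" (hypothetical)
    print(ep.name, "->", ep.value)
    logitproc_cls = ep.load()   # the advertised LogitsProcessor subclass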

build_logitsprocs

build_logitsprocs(
    vllm_config: VllmConfig,
    device: device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[
        Union[str, type[LogitsProcessor]]
    ] = (),
) -> LogitsProcessors
Source code in vllm/v1/sample/logits_processor/__init__.py
def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug("Skipping logits processor loading because pooling models"
                     " do not support logits processors.")
        return LogitsProcessors()
    custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
    return LogitsProcessors(
        ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
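
A hedged usage sketch on CPU with the builtin processors only; the config object is a SimpleNamespace stand-in exposing just scheduler_config.max_num_seqs, which is all the builtin constructors documented on this page read:

from types import SimpleNamespace

import torch

from vllm.v1.sample.logits_processor import build_logitsprocs

cfg = SimpleNamespace(scheduler_config=SimpleNamespace(max_num_seqs=8))
procs = build_logitsprocs(
    cfg,
    device=torch.device("cpu"),
    is_pin_memory=False,
    is_pooling_model=False,   # True would reject custom logitsprocs
)
print([type(p).__name__ for p in procs.all])

Custom processors would be supplied through custom_logitsprocs, either as already-imported LogitsProcessor subclasses or as <module>:<ClassName> FQCN strings.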