
vllm.v1.sample.ops.logprobs

Some utilities for logprobs, including logits.

batched_count_greater_than

```python
batched_count_greater_than(
    x: Tensor, values: Tensor
) -> Tensor
```

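Counts, for each row of x, the elements that are greater than or equal to the corresponding value in values. The comparison is `x >= values`, so elements equal to the threshold are counted despite the function's name. The function is wrapped in torch.compile so that an optimized kernel is generated for it. Otherwise, evaluating it eagerly creates additional copies of the input tensors and can cause memory issues.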
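The semantics can be checked in plain eager PyTorch with the same expression the function body uses. This small sketch evaluates it on hand-picked tensors (no vLLM import needed) and compares `>=` with a strict `>` to show how ties are handled.

```python
import torch

# Two rows of scores; values holds one threshold per row, shape (batch_size, 1).
x = torch.tensor([[0.1, 0.5, 0.5, 0.9],
                  [2.0, 1.0, 3.0, 0.0]])
values = torch.tensor([[0.5], [3.0]])

# Same expression as the function: broadcast (2, 4) against (2, 1),
# compare element-wise, then sum the True entries along the last dimension.
counts = (x >= values).sum(-1)
print(counts)  # tensor([3, 1])

# A strict comparison would drop the elements equal to the threshold.
strict = (x > values).sum(-1)
print(strict)  # tensor([1, 0])
```

Row 0 contains two elements equal to its threshold of 0.5, which is where the two comparisons differ.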

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | A 2D tensor of shape (batch_size, n_elements). | required |
| values | Tensor | A 2D tensor of shape (batch_size, 1). | required |
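The trailing dimension of size 1 on values matters. A 1D values of length batch_size would be broadcast against the last dimension of x instead of against its rows, which raises an error when n_elements differs from batch_size and silently gives a wrong answer when the two happen to match. The sketch below uses a hypothetical eager wrapper, `count_at_least` (not part of vLLM), that validates the shapes. Its usage assumes one plausible application not confirmed by this page: ranking a chosen token within each row of logprobs, where the `>=` count gives a 1-based rank.

```python
import torch


def count_at_least(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Hypothetical eager helper: returns (x >= values).sum(-1) after
    checking shapes; also accepts values of shape (batch_size,)."""
    if x.dim() != 2:
        raise ValueError(f"x must be 2D, got shape {tuple(x.shape)}")
    if values.dim() == 1:
        # (batch_size,) -> (batch_size, 1) so values broadcast per row.
        values = values.unsqueeze(-1)
    if values.shape != (x.shape[0], 1):
        raise ValueError(
            f"values must have shape ({x.shape[0]}, 1), got {tuple(values.shape)}")
    return (x >= values).sum(-1)


# Assumed usage: logprobs of shape (batch_size, vocab_size) and one token id per row.
logprobs = torch.log_softmax(
    torch.tensor([[1.0, 3.0, 2.0], [0.5, 0.1, 0.9]]), dim=-1)
token_ids = torch.tensor([1, 0])

# gather with an index of shape (batch_size, 1) yields values of shape (batch_size, 1).
token_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))

ranks = count_at_least(logprobs, token_logprobs)
print(ranks)  # tensor([1, 2])
```

Token 1 holds the highest logprob in row 0, so its rank is 1. In row 0 of the second example, token 0 is beaten only by the entry with logit 0.9, giving rank 2.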

Returns:

| Type | Description |
|---|---|
| Tensor | A 1D tensor of shape (batch_size,) with the counts. |

Source code in vllm/v1/sample/ops/logprobs.py
```python
import torch

from vllm.platforms import current_platform


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """
    Counts elements in each row of x that are greater than or equal to the
    corresponding value in values. Use torch.compile to generate an optimized
    kernel for this function. Otherwise, it will create additional copies of
    the input tensors and cause memory issues.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    # Broadcast (batch_size, n_elements) >= (batch_size, 1), then sum per row.
    return (x >= values).sum(-1)
```
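In eager mode, `(x >= values)` materializes a full boolean tensor of shape (batch_size, n_elements) before the sum runs, which is presumably the memory overhead the docstring warns about. Compilation can fuse the comparison and the reduction so that the mask need not be stored. Where compilation is unavailable, one workaround is to bound the temporary by processing column blocks. The sketch below is a hypothetical alternative (`count_ge_chunked` and its `chunk_size` parameter are not part of vLLM) that returns the same counts.

```python
import torch


def count_ge_chunked(x: torch.Tensor, values: torch.Tensor,
                     chunk_size: int = 8192) -> torch.Tensor:
    """Hypothetical eager alternative: equals (x >= values).sum(-1), but the
    largest temporary mask has shape (batch_size, chunk_size)."""
    counts = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
    for start in range(0, x.shape[1], chunk_size):
        block = x[:, start:start + chunk_size]  # slicing returns a view, not a copy
        counts += (block >= values).sum(-1)
    return counts


torch.manual_seed(0)
x = torch.randn(4, 1000)
values = x[:, :1]  # first element of each row as the threshold, shape (4, 1)

chunked = count_ge_chunked(x, values, chunk_size=128)
reference = (x >= values).sum(-1)
print(torch.equal(chunked, reference))  # True
```

The trade-off is a Python loop that launches several small kernels in place of a single one, so a fused compiled kernel is likely preferable whenever torch.compile is available.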