Skip to content

vllm.v1.spec_decode.ngram_proposer

NgramProposer

Source code in vllm/v1/spec_decode/ngram_proposer.py
class NgramProposer:
    """Speculative-decoding proposer based on n-gram (prompt lookup) matching.

    Draft tokens are taken directly from the context: the proposer looks for
    an earlier occurrence of the context's trailing n-gram and proposes the
    tokens that followed it. There is no draft model.
    """

    def __init__(self, vllm_config: VllmConfig):
        spec_config = vllm_config.speculative_config
        assert spec_config is not None
        assert spec_config.prompt_lookup_min is not None
        assert spec_config.prompt_lookup_max is not None

        # Shortest n-gram length that may be matched.
        self.min_n = spec_config.prompt_lookup_min
        # Longest n-gram length that may be matched.
        self.max_n = spec_config.prompt_lookup_max
        # Upper bound on the number of draft tokens per proposal. When fewer
        # than k tokens follow the match, everything up to the end of the
        # context is proposed instead.
        self.k = spec_config.num_speculative_tokens
        # The model's maximum sequence length; proposals are truncated so
        # they never run past it.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Call propose() once on a dummy context so the Numba-compiled
        # matcher is JIT-compiled here instead of on the first real request.
        # This usually takes less than 1 second.
        warmup_context = np.zeros(1024, dtype=np.int32)
        self.propose(warmup_context)

    def propose(
        self,
        context_token_ids: np.ndarray,
    ) -> Optional[np.ndarray]:
        """Propose draft tokens by n-gram lookup in the context.

        Searches the context for an earlier occurrence of its trailing
        n tokens, using the largest n in [min_n, max_n] that has one, and
        returns up to k tokens that came right after that occurrence.

        Args:
            context_token_ids: Numpy array of token IDs representing the
                context sequence.

        Returns:
            The tokens that followed the matched n-gram: at most k of them,
            and never so many that context + proposal exceeds
            max_model_len. None if no n-gram of length in [min_n, max_n]
            matches, or if the context already reaches max_model_len.

        Example:
            With context_token_ids = [1, 2, 3, 4, 2, 3], min_n = 2,
            max_n = 3 and k = 4:
            - The trailing 3 (= max_n) tokens [4, 2, 3] occur nowhere else.
            - The trailing 2 tokens [2, 3] also occur at the start of the
              context.
            - Only three tokens, [4, 2, 3], follow that occurrence, so
              [4, 2, 3] is returned.
        """
        # TODO(woosuk): Optimize this.
        proposal = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=self.min_n,
            max_ngram=self.max_n,
            max_model_len=self.max_model_len,
            k=self.k,
        )
        return proposal

    def load_model(self, *args, **kwargs):
        """No-op: the n-gram proposer has no model weights to load."""

k instance-attribute

k = num_speculative_tokens

max_model_len instance-attribute

max_model_len = max_model_len

max_n instance-attribute

max_n = prompt_lookup_max

min_n instance-attribute

min_n = prompt_lookup_min

__init__

__init__(vllm_config: VllmConfig)
Source code in vllm/v1/spec_decode/ngram_proposer.py
def __init__(self, vllm_config: VllmConfig):
    spec_config = vllm_config.speculative_config
    assert spec_config is not None
    assert spec_config.prompt_lookup_min is not None
    assert spec_config.prompt_lookup_max is not None

    # Shortest n-gram length that may be matched.
    self.min_n = spec_config.prompt_lookup_min
    # Longest n-gram length that may be matched.
    self.max_n = spec_config.prompt_lookup_max
    # Upper bound on the number of draft tokens per proposal. When fewer
    # than k tokens follow the match, everything up to the end of the
    # context is proposed instead.
    self.k = spec_config.num_speculative_tokens
    # The model's maximum sequence length; proposals are truncated so
    # they never run past it.
    self.max_model_len = vllm_config.model_config.max_model_len

    # Call propose() once on a dummy context so the Numba-compiled
    # matcher is JIT-compiled here instead of on the first real request.
    # This usually takes less than 1 second.
    warmup_context = np.zeros(1024, dtype=np.int32)
    self.propose(warmup_context)

load_model

load_model(*args, **kwargs)
Source code in vllm/v1/spec_decode/ngram_proposer.py
def load_model(self, *args, **kwargs):
    """No-op: the n-gram proposer has no model weights to load."""

propose

propose(context_token_ids: ndarray) -> Optional[ndarray]

Proposes the next sequence of tokens based on n-gram pattern matching in the context. The function finds an earlier occurrence of the last n tokens of the context and returns up to k tokens that followed that match.

Parameters:

Name Type Description Default
context_token_ids ndarray

Numpy array of token IDs representing the context sequence.

required

Returns:

Name Type Description
Optional[ndarray]

np.ndarray: The sequence of tokens that followed the matched n-gram in the context.

None Optional[ndarray]

If no matching n-gram pattern is found.

Example

If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and k = 4:

- The last 3 (= max_n) tokens [4,2,3] cannot find a match.
- The last 2 tokens [2,3] will be matched against the previous 4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match.

Source code in vllm/v1/spec_decode/ngram_proposer.py
def propose(
    self,
    context_token_ids: np.ndarray,
) -> Optional[np.ndarray]:
    """Propose draft tokens by n-gram lookup in the context.

    Searches the context for an earlier occurrence of its trailing
    n tokens, using the largest n in [min_n, max_n] that has one, and
    returns up to k tokens that came right after that occurrence.

    Args:
        context_token_ids: Numpy array of token IDs representing the
            context sequence.

    Returns:
        The tokens that followed the matched n-gram: at most k of them,
        and never so many that context + proposal exceeds
        max_model_len. None if no n-gram of length in [min_n, max_n]
        matches, or if the context already reaches max_model_len.

    Example:
        With context_token_ids = [1, 2, 3, 4, 2, 3], min_n = 2,
        max_n = 3 and k = 4:
        - The trailing 3 (= max_n) tokens [4, 2, 3] occur nowhere else.
        - The trailing 2 tokens [2, 3] also occur at the start of the
          context.
        - Only three tokens, [4, 2, 3], follow that occurrence, so
          [4, 2, 3] is returned.
    """
    # TODO(woosuk): Optimize this.
    proposal = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=context_token_ids,
        min_ngram=self.min_n,
        max_ngram=self.max_n,
        max_model_len=self.max_model_len,
        k=self.k,
    )
    return proposal

_find_longest_matched_ngram_and_propose_tokens

_find_longest_matched_ngram_and_propose_tokens(
    origin_tokens: ndarray,
    min_ngram: int,
    max_ngram: int,
    max_model_len: int,
    k: int,
) -> Optional[ndarray]

Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive).

If found, we extract up to k tokens immediately following the matched n-gram.

Source code in vllm/v1/spec_decode/ngram_proposer.py
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens: np.ndarray, min_ngram: int, max_ngram: int,
        max_model_len: int, k: int) -> Optional[np.ndarray]:
    """
    Find the longest n-gram that matches a suffix of ``origin_tokens`` and
    propose the tokens that followed its earlier occurrence.

    Only n-gram lengths within [min_ngram, max_ngram] (inclusive) are
    considered. The search is a KMP-style prefix-function (``lps``) scan
    over the reversed token array, so a suffix of ``origin_tokens`` becomes
    a prefix of the reversed array. Among equally long matches, the one that
    occurs earliest in ``origin_tokens`` is chosen.

    Args:
        origin_tokens: Array of context token IDs, indexed element-wise as a
            1-D sequence.
        min_ngram: Minimum n-gram length accepted as a match.
        max_ngram: Maximum n-gram length to match; also the size of the
            ``lps`` buffer. Presumably >= 1 (``lps`` would be empty
            otherwise) — TODO confirm callers guarantee this.
        max_model_len: Maximum sequence length; the proposal is truncated so
            that context + proposal never exceeds it.
        k: Maximum number of tokens to propose.

    Returns:
        A slice of ``origin_tokens`` holding between 1 and k tokens that
        immediately follow the matched n-gram. Returns None if the context
        is shorter than ``min_ngram``, if no room is left under
        ``max_model_len``, or if no n-gram of length >= ``min_ngram`` matches.
    """
    # Do not generate draft tokens if the context is shorter than the
    # minimum n-gram.
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return None

    # Do not generate draft tokens beyond the max model length: shrink k so
    # that context + draft stays within max_model_len.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return None

    # Reverse the tokens: a suffix of origin_tokens becomes a prefix of
    # `tokens`. The goal then becomes finding the longest prefix, with
    # length in [min_ngram, max_ngram] (inclusive), that reappears later in
    # `tokens`, preferring the rightmost such position.
    tokens = origin_tokens[::-1]

    # Longest prefix (not including itself) which is a suffix of
    # the current position.
    #   lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
    #
    # As ngram is capped by max_ngram to save memory, we only need to
    # store lps for the first max_ngram prefix.
    lps = np.zeros(max_ngram, dtype=np.int32)

    # Length of the best match found so far, and the index in `tokens`
    # at which that match ends.
    longest_ngram = 0
    position = 0

    # lps[0] is always 0, so we start with index 1.
    prev_lps = 0
    i = 1
    while i < total_token:
        # tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i]
        if tokens[prev_lps] == tokens[i]:
            # Token match: tokens[:prev_lps+1] is the longest prefix as
            # a suffix of tokens[:i+1]
            prev_lps += 1
            # Check if we found a longer valid ngram.
            #
            # Update position when longest_ngram matched prev_lps,
            # as we want to get the target n-gram of the earliest position
            # in the original tokens (i.e.
            # latest position in the reversed tokens)
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                # Store LPS for the first max_ngram prefix
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                # When prev_lps reached max_ngram, update prev_lps
                # to lps[max_ngram-1] to avoid matching ngram
                # longer than max_ngram
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            # Token mismatch: try the second longest prefix
            # among all suffix of tokens[:i],
            # which is the longest prefix of tokens[:prev_lps]
            prev_lps = lps[prev_lps - 1]
        else:
            # Token mismatch, and no more prefix (except empty string)
            # as a suffix of tokens[:i]
            i += 1

    if longest_ngram < min_ngram:
        # No valid ngram is found
        return None

    # Flip the position back, so in origin_tokens,
    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
    # is the matched ngram, so we should start drafting tokens from
    # total_token-1-position+longest_ngram
    start_position = total_token - 1 - position + longest_ngram
    # Fewer than k tokens may follow the match; propose only what exists.
    k = min(k, total_token - start_position)
    return origin_tokens[start_position:start_position + k]