vllm.v1.structured_output.backend_lm_format_enforcer

LMFormatEnforcerBackend dataclass

Bases: StructuredOutputBackend

Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):

    def __post_init__(self):
        self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
            self.tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        character_level_parser: lmformatenforcer.CharacterLevelParser
        if request_type == StructuredOutputOptions.JSON:
            spec_dict = json.loads(grammar_spec)
            character_level_parser = lmformatenforcer.JsonSchemaParser(
                spec_dict)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            character_level_parser = lmformatenforcer.JsonSchemaParser(None)
        elif request_type == StructuredOutputOptions.REGEX:
            character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
        elif request_type == StructuredOutputOptions.CHOICE:
            choices = ast.literal_eval(grammar_spec)
            character_level_parser = lmformatenforcer.UnionParser(
                [lmformatenforcer.StringParser(choice) for choice in choices])
        else:
            raise ValueError(
                "Invalid request type for LM Format Enforcer backend"
                f"({request_type!s})")
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None else 0)

        if max_rollback_tokens > 0:
            raise ValueError(
                "LM Format Enforcer backend does not support speculative tokens"
            )

        token_enforcer = lmformatenforcer.TokenEnforcer(
            tokenizer_data=self.tokenizer_data,
            parser=character_level_parser,
        )
        return LMFormatEnforcerGrammar(token_enforcer)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        pass

__init__

__init__(
    vllm_config: VllmConfig,
    tokenizer: AnyTokenizer,
    vocab_size: int,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def __post_init__(self):
    self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
        self.tokenizer, self.vocab_size)

allocate_token_bitmask

allocate_token_bitmask(max_num_seqs: int) -> Tensor
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
    return torch.full(
        (max_num_seqs, (self.vocab_size + 31) // 32),
        -1,
        dtype=torch.int32,
        pin_memory=torch.cuda.is_available(),
    )
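
The bitmask packs one bit per vocabulary token into 32-bit words, so each row holds (vocab_size + 31) // 32 int32 entries, and filling the tensor with -1 (all bits set) marks every token as allowed until a grammar overwrites the row. An illustrative, self-contained sketch of the layout (not part of the module source; the bit-to-token mapping shown is the usual convention and is an assumption here):

import torch

vocab_size = 50_000                      # example vocabulary size
words_per_row = (vocab_size + 31) // 32  # 32 tokens packed per int32 word

bitmask = torch.full((2, words_per_row), -1, dtype=torch.int32)
print(bitmask.shape)                     # torch.Size([2, 1563])

# Assumed mapping: bit (i % 32) of word (i // 32) corresponds to token id i.
# -1 has every bit set, so all tokens start out allowed.
token_id = 12_345
word, bit = token_id // 32, token_id % 32
print(bool((int(bitmask[0, word]) >> bit) & 1))  # True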

compile_grammar

compile_grammar(
    request_type: StructuredOutputOptions, grammar_spec: str
) -> StructuredOutputGrammar
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def compile_grammar(self, request_type: StructuredOutputOptions,
                    grammar_spec: str) -> StructuredOutputGrammar:
    character_level_parser: lmformatenforcer.CharacterLevelParser
    if request_type == StructuredOutputOptions.JSON:
        spec_dict = json.loads(grammar_spec)
        character_level_parser = lmformatenforcer.JsonSchemaParser(
            spec_dict)
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        character_level_parser = lmformatenforcer.JsonSchemaParser(None)
    elif request_type == StructuredOutputOptions.REGEX:
        character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
    elif request_type == StructuredOutputOptions.CHOICE:
        choices = ast.literal_eval(grammar_spec)
        character_level_parser = lmformatenforcer.UnionParser(
            [lmformatenforcer.StringParser(choice) for choice in choices])
    else:
        raise ValueError(
            "Invalid request type for LM Format Enforcer backend"
            f"({request_type!s})")
    max_rollback_tokens = (
        self.vllm_config.speculative_config.num_speculative_tokens
        if self.vllm_config.speculative_config is not None else 0)

    if max_rollback_tokens > 0:
        raise ValueError(
            "LM Format Enforcer backend does not support speculative tokens"
        )

    token_enforcer = lmformatenforcer.TokenEnforcer(
        tokenizer_data=self.tokenizer_data,
        parser=character_level_parser,
    )
    return LMFormatEnforcerGrammar(token_enforcer)
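
The grammar_spec string is interpreted per request type: for JSON it must be a JSON-encoded schema (parsed with json.loads), for CHOICE it is a Python literal list of strings (parsed with ast.literal_eval), and for REGEX it is passed straight to RegexParser. An illustrative example of specs that parse cleanly (not part of the module source):

import ast
import json

# JSON request type: grammar_spec is a JSON schema serialized as a string.
json_spec = '{"type": "object", "properties": {"name": {"type": "string"}}}'
schema = json.loads(json_spec)           # -> dict handed to JsonSchemaParser

# CHOICE request type: grammar_spec is a Python literal list of options.
choice_spec = "['yes', 'no', 'maybe']"
choices = ast.literal_eval(choice_spec)  # -> ['yes', 'no', 'maybe']

print(schema["type"], choices)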

destroy

destroy()
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def destroy(self):
    pass

LMFormatEnforcerGrammar dataclass

Bases: StructuredOutputGrammar

Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
    token_enforcer: lmformatenforcer.TokenEnforcer
    current_tokens_prefix: list[int] = field(default_factory=list)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        original_len = len(self.current_tokens_prefix)
        for token in tokens:
            if not self.token_enforcer.get_allowed_tokens(
                    self.current_tokens_prefix).is_token_allowed(token):
                # Rollback partial updates to ensure atomicity.
                del self.current_tokens_prefix[original_len:]
                return False
            self.current_tokens_prefix.append(token)
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        for prefix_length in range(len(tokens)):
            prefix = tokens[:prefix_length]
            next_token = tokens[prefix_length]
            if not self.token_enforcer.get_allowed_tokens(
                    self.current_tokens_prefix +
                    prefix).is_token_allowed(next_token):
                break
        else:
            return tokens

        return tokens[:prefix_length]

    def rollback(self, num_tokens: int) -> None:
        self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        allowed_tokens = self.token_enforcer.get_allowed_tokens(
            self.current_tokens_prefix)
        bitmask[batch_index] = allowed_tokens.allowed_tokens

    def is_terminated(self) -> bool:
        # We are considered terminated if the prefix ends with eos_token_id
        return_value = len(
            self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
                -1] == self.token_enforcer.eos_token_id
        return return_value

    def reset(self):
        self.current_tokens_prefix = []

current_tokens_prefix class-attribute instance-attribute

current_tokens_prefix: list[int] = field(
    default_factory=list
)

token_enforcer instance-attribute

token_enforcer: TokenEnforcer

__init__

__init__(
    token_enforcer: TokenEnforcer,
    current_tokens_prefix: list[int] = list(),
) -> None

accept_tokens

accept_tokens(request_id: str, tokens: list[int]) -> bool
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
    original_len = len(self.current_tokens_prefix)
    for token in tokens:
        if not self.token_enforcer.get_allowed_tokens(
                self.current_tokens_prefix).is_token_allowed(token):
            # Rollback partial updates to ensure atomicity.
            del self.current_tokens_prefix[original_len:]
            return False
        self.current_tokens_prefix.append(token)
    return True
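
accept_tokens is all-or-nothing: if any token in the batch is rejected, the tokens appended so far are removed again and the grammar state is left unchanged. An illustrative sketch (not part of the module source) using a hypothetical duck-typed stand-in for the token enforcer that only allows even token ids; it assumes the optional lm-format-enforcer dependency is installed so the module imports cleanly:

from vllm.v1.structured_output.backend_lm_format_enforcer import (
    LMFormatEnforcerGrammar)


class _Allowed:
    def is_token_allowed(self, token: int) -> bool:
        return token % 2 == 0


class _StubEnforcer:
    """Hypothetical stand-in for lmformatenforcer.TokenEnforcer."""
    eos_token_id = 0

    def get_allowed_tokens(self, prefix):
        return _Allowed()


grammar = LMFormatEnforcerGrammar(_StubEnforcer())
print(grammar.accept_tokens("req", [2, 4, 6]))   # True: prefix grows to [2, 4, 6]
print(grammar.accept_tokens("req", [8, 5, 10]))  # False: 5 rejected, 8 rolled back
print(grammar.current_tokens_prefix)             # [2, 4, 6]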

fill_bitmask

fill_bitmask(bitmask: Tensor, batch_index: int) -> None
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
    allowed_tokens = self.token_enforcer.get_allowed_tokens(
        self.current_tokens_prefix)
    bitmask[batch_index] = allowed_tokens.allowed_tokens

is_terminated

is_terminated() -> bool
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def is_terminated(self) -> bool:
    # We are considered terminated if the prefix ends with eos_token_id
    return_value = len(
        self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
            -1] == self.token_enforcer.eos_token_id
    return return_value

reset

reset()
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def reset(self):
    self.current_tokens_prefix = []

rollback

rollback(num_tokens: int) -> None
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def rollback(self, num_tokens: int) -> None:
    self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]
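
Note that Python slicing makes prefix[:-0] equivalent to prefix[:0], i.e. an empty list, so a rollback of zero tokens would clear the whole prefix rather than being a no-op; callers are presumably expected to skip zero-length rollbacks (and compile_grammar rejects speculative decoding outright). A two-line illustration of the slicing behaviour:

prefix = [101, 102, 103]
print(prefix[:-2])  # [101]  -> rollback(2) keeps only the first token
print(prefix[:-0])  # []     -> a rollback of 0 would empty the prefix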

validate_tokens

validate_tokens(tokens: list[int]) -> list[int]
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def validate_tokens(self, tokens: list[int]) -> list[int]:
    for prefix_length in range(len(tokens)):
        prefix = tokens[:prefix_length]
        next_token = tokens[prefix_length]
        if not self.token_enforcer.get_allowed_tokens(
                self.current_tokens_prefix +
                prefix).is_token_allowed(next_token):
            break
    else:
        return tokens

    return tokens[:prefix_length]
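
Unlike accept_tokens, validate_tokens never mutates the grammar state; it returns the longest leading slice of tokens that would be accepted from the current position. An illustrative sketch (not part of the module source), reusing the same kind of hypothetical even-ids-only stand-in enforcer as above:

from vllm.v1.structured_output.backend_lm_format_enforcer import (
    LMFormatEnforcerGrammar)


class _Allowed:
    def is_token_allowed(self, token: int) -> bool:
        return token % 2 == 0


class _StubEnforcer:
    """Hypothetical stand-in for lmformatenforcer.TokenEnforcer."""
    eos_token_id = 0

    def get_allowed_tokens(self, prefix):
        return _Allowed()


grammar = LMFormatEnforcerGrammar(_StubEnforcer())
print(grammar.validate_tokens([2, 4, 7, 8]))  # [2, 4]: stops before the first rejected token
print(grammar.validate_tokens([2, 4, 6]))     # [2, 4, 6]: fully valid, returned unchanged
print(grammar.current_tokens_prefix)          # []: validation never advances the state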

_cached_build_vllm_token_enforcer_tokenizer_data cached

_cached_build_vllm_token_enforcer_tokenizer_data(
    tokenizer: PreTrainedTokenizerBase, vocab_size: int
) -> TokenEnforcerTokenizerData
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase,
        vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
    return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
        tokenizer, use_bitmask=True, vocab_size=vocab_size)
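
Building the token-enforcer tokenizer data walks the whole vocabulary, so the result is memoized with functools.lru_cache, keyed on the (tokenizer, vocab_size) arguments; backends that share a tokenizer reuse the same data. A generic illustration of the caching behaviour (not part of the module source; the expensive build step is replaced by a counter):

from functools import lru_cache

calls = 0


@lru_cache
def build_data(tokenizer_name: str, vocab_size: int) -> str:
    global calls
    calls += 1                   # stands in for the expensive vocabulary walk
    return f"data({tokenizer_name}, {vocab_size})"


build_data("llama", 32_000)
build_data("llama", 32_000)      # cache hit: same arguments, no rebuild
print(calls)                     # 1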

validate_structured_output_request_lm_format_enforcer

validate_structured_output_request_lm_format_enforcer(
    params: SamplingParams,
)
Source code in vllm/v1/structured_output/backend_lm_format_enforcer.py
def validate_structured_output_request_lm_format_enforcer(
        params: SamplingParams):
    if params.guided_decoding is None:
        return

    gd_params = params.guided_decoding

    if gd_params.regex:
        return
    elif gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                # make sure schema is valid json
                json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            try:
                json.dumps(gd_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing guided decoding jsonschema: {e}"
                ) from e
        return
    elif gd_params.choice:
        return
    elif gd_params.grammar:
        raise ValueError("LM Format Enforcer guided decoding backend "
                         "does not support grammar specifications")