@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """Tool parser for the "step3" XML-like tool-call format.

    The parser looks for a single tool block delimited by
    ``<|tool_calls_begin|>`` / ``<|tool_calls_end|>``.  Inside the block each
    call is wrapped in ``<|tool_call_begin|>`` / ``<|tool_call_end|>`` and has
    the shape::

        function<|tool_sep|><steptml:invoke name="FUNC">
        <steptml:parameter name="ARG">VALUE</steptml:parameter>...

    Streaming is handled by a stateful cursor (``self.position``) into the
    accumulated text.  For every tool call the streaming path emits one delta
    carrying the function name and, once the call is closed, one delta
    carrying the complete JSON-encoded arguments.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [
        TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
    ]

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Cursor into ``current_text``: everything before it has already
        # been consumed by ``extract_tool_calls_streaming``.
        self.position = 0
        # Explicit state flags for robust streaming.
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Disable special-token skipping when tool calling is enabled.

        NOTE(review): presumably required because the tool markers are
        special tokens that would otherwise be removed from the decoded
        text -- confirm against the tokenizer.
        """
        if request.tools and request.tool_choice != 'none':
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str
    ) -> tuple[Optional[str], Optional[dict[str, str]]]:
        """Extract the function name and raw parameters from an invoke body.

        Args:
            action_text: Text that may contain a ``<steptml:invoke ...>``
                tag followed by ``<steptml:parameter ...>`` elements.

        Returns:
            ``(name, params)`` where ``params`` maps parameter names to their
            whitespace-stripped string values, or ``(None, None)`` when no
            invoke tag is present.  Only parameters whose closing tag has
            been seen are included.
        """
        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">',
                                    action_text)
        if not func_name_match:
            return None, None
        func_name = func_name_match.group(1)

        # Non-greedy match up to the first closing tag (DOTALL so values may
        # span lines).  A ``[^<]*`` value pattern would silently drop any
        # parameter whose value contains ``<`` (e.g. code or HTML snippets).
        param_matches = re.findall(
            r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
            action_text,
            flags=re.DOTALL)
        params: dict[str, str] = {
            name: value.strip()
            for name, value in param_matches
        }
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Convert string parameter values using the tool's JSON schema.

        The first tool in ``request.tools`` whose function name equals
        ``func_name`` supplies the schema.  String values are converted
        according to the declared ``type`` of the matching property
        (``string``, ``integer``, ``number``, ``boolean``, ``null``); values
        that cannot be converted, or whose type is not one of these, are
        left as they are.

        Args:
            func_name: Name of the function being called.
            params: Parsed parameters; updated in place.
            request: Request whose ``tools`` provide the schemas.

        Returns:
            The same ``params`` dict, after in-place conversion.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties", {})
            # Only existing keys are reassigned, so mutating ``params``
            # while iterating over it is safe.
            for key, value in params.items():
                if not isinstance(value, str):
                    continue
                typ = properties.get(key, {}).get("type")
                if typ == "string":
                    params[key] = value.strip()
                elif typ == "integer":
                    # Keep the raw string if it is not a valid integer.
                    with contextlib.suppress(ValueError):
                        params[key] = int(value)
                elif typ == "number":
                    # Keep the raw string if it is not a valid float.
                    with contextlib.suppress(ValueError):
                        params[key] = float(value)
                elif typ == "boolean":
                    lower_val = value.lower()
                    if lower_val in ("true", "false"):
                        params[key] = lower_val == "true"
                elif typ == "null":
                    if value.lower() == "null":
                        params[key] = None
            # Only the first tool with a matching name is consulted.
            break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse ``current_text`` starting at ``self.position``.

        Only ``current_text`` and ``request`` are read; the remaining
        arguments are accepted but not used.

        Returns:
            A ``DeltaMessage`` carrying plain content, a tool-call name, or a
            tool call's complete JSON arguments; ``None`` when nothing can be
            emitted yet (all text consumed, or a marker / tool body is still
            incomplete).
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position:]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    if self.TOOL_CALLS_BEGIN.startswith(
                            unprocessed_text.strip()) and unprocessed_text:
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)

                # Emit the content preceding the marker; the marker itself
                # is consumed on the next iteration / call.
                content = unprocessed_text[:start_pos]
                self.position += len(content)
                return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Skip whitespace between markers and keep the cursor in sync.
            stripped_text = unprocessed_text.lstrip()
            self.position += len(unprocessed_text) - len(stripped_text)
            unprocessed_text = stripped_text

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = (
                self.current_tool_id != -1 and
                self.prev_tool_call_arr[self.current_tool_id].get("finished"))
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    # Make sure a state slot exists for the new tool call.
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = False
                    continue

                # Possibly a partial begin marker (or nothing left after
                # stripping whitespace): wait for more text.
                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                    self.current_tool_id].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                # The body so far may just be a partial end marker.
                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
                        tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(
                    tool_body)
                if not function_name:
                    return None

                tool_call_arr = {
                    "name": function_name,
                    "parameters": arguments or {}
                }

                # Send the function name as soon as it's parsed.  The cursor
                # is deliberately not advanced here, so the next call parses
                # the same body again and can then emit the arguments.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(
                        tool_call_arr)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name))
                    ])

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(
                    tool_call_arr)

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request)
                    # Always emit the arguments, even when there are none
                    # ("{}"), so the streamed result agrees with
                    # extract_tool_calls, which serialises an empty
                    # parameter dict as "{}".
                    final_args_json = json.dumps(final_args,
                                                 ensure_ascii=False)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=final_args_json))
                    ])

                # If tool is not finished, return None to wait for more tokens.
                return None

            # Unrecognised text inside the tool block: nothing to emit.
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into tool calls and content.

        A tool block is only honoured when both ``TOOL_CALLS_BEGIN`` and
        ``TOOL_CALLS_END`` are present.  Within it, a call is kept only if it
        is closed by ``TOOL_CALL_END``, contains ``TOOL_SEP``, has the type
        ``function`` before the separator, and has a parsable invoke tag.

        Returns:
            ``tools_called=True`` with the parsed calls and the stripped text
            surrounding the tool block (``None`` if empty) when at least one
            call was parsed; otherwise ``tools_called=False`` with the
            unmodified ``model_output`` as content.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            # Skip empty fragments and calls that were never closed.
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(
                invoke_part)
            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict,
                                                   request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(function=FunctionCall(name=function_name,
                                                   arguments=params_str)))

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)