@ToolParserManager.register_module("hunyuan_a13b")
class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
    """Create the parser and reset all streaming bookkeeping.

    Args:
        tokenizer: Tokenizer forwarded unchanged to the base-class
            constructor.
    """
    super().__init__(tokenizer)

    # ---- Markers and compiled patterns used to locate tool calls ----
    # Literal tag that opens a tool-call section.
    self.bot_string = "<tool_calls>"
    # Non-greedy capture of everything between the opening and closing tag.
    self.answer_tool_calls_pattern = re.compile(
        r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
    # Group 1 is the value of a `"name": "..."` pair.
    self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')
    # A name immediately followed by an empty `"arguments": {}` object.
    self.tool_empty_arg_reg = re.compile(
        r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
    # A name followed by an arguments object (group 1 is the object).
    # TODO: not support nested json object in fc arguments.
    self.tool_non_empty_arg_reg = re.compile(
        r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
    )

    # ---- Streaming-mode state ----
    # Cursor plus per-tool records; the lists are grown on demand while
    # streaming.
    self.streaming_state: dict[str, Any] = {
        "current_tool_index": -1,
        "tool_ids": [],
        "sent_tools": [],
    }
    self.current_tool_id = -1
    self.current_tool_name_sent = False
    self.prev_tool_calls: list[dict] = []
    # Argument text already sent, one entry per tool.
    self.streamed_args: list[str] = []
    # Kept for backward compatibility with tests.
    self.current_tools_sent: list[bool] = []
    # Kept for backward compatibility with serving code.
    self.prev_tool_call_arr = []
def preprocess_model_output(
        self, model_output: str) -> tuple[Optional[str], Optional[str]]:
    """Split model output into leading content and tool-call JSON text.

    Walks every ``<tool_calls>...</tool_calls>`` section in order and
    returns the first one that is not strictly inside a
    ``<think>...</think>`` region and whose stripped payload parses as
    JSON.

    Args:
        model_output: Complete text produced by the model.

    Returns:
        ``(content, tool_calls_json)`` where ``content`` is the text that
        precedes the accepted section and ``tool_calls_json`` is the
        stripped payload between the tags. If no section qualifies,
        ``(model_output, None)``.
    """
    # The <think> regions do not depend on which tool-call section is being
    # examined, so they are scanned at most once (and only if a section
    # exists at all) rather than once per section.
    think_regions = None

    for match in self.answer_tool_calls_pattern.finditer(model_output):
        if think_regions is None:
            think_regions = [
                m.span() for m in re.finditer(
                    r"<think>(.*?)</think>", model_output, flags=re.DOTALL)
            ]

        start, end = match.span()
        # A section nested inside a think region is ignored.
        in_think = any(t_start < start and end < t_end
                       for t_start, t_end in think_regions)
        if in_think:
            continue

        tool_calls_content = match.group(1).strip()
        try:
            # Only validates the payload; the caller parses it again.
            json.loads(tool_calls_content)
        except (ValueError, RecursionError):
            # Not valid JSON (or pathologically nested): try the next
            # section.
            continue
        return model_output[:start], tool_calls_content

    return model_output, None
def extract_tool_calls(
        self, model_output: str,
        request: ChatCompletionRequest) -> ExtractedToolCallInformation:
    """Extract tool calls from a complete (non-streaming) model output.

    Args:
        model_output: Full text generated by the model.
        request: The originating chat completion request (not read here).

    Returns:
        An ``ExtractedToolCallInformation``. ``tools_called`` is true only
        when at least one entry with both ``name`` and ``arguments`` was
        found; ``content`` is the text preceding the tool-call section, or
        ``None`` when that text is empty or whitespace. If anything goes
        wrong, the whole ``model_output`` is returned as plain content.
    """
    try:
        content, potential_tool_calls = self.preprocess_model_output(
            model_output)

        if not potential_tool_calls:
            # some text should be filtered out for no function call
            # this text is in a13b's chat template.
            if content:
                content = content.replace("助手:", "", 1)
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=content)

        tool_calls_data = json.loads(potential_tool_calls)

        # Only a JSON array of call objects is accepted.
        if not isinstance(tool_calls_data, list):
            logger.debug("Tool calls data is not an array")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=content or model_output,
            )

        tool_calls: list[ToolCall] = []
        for call in tool_calls_data:
            # Skip entries that are not objects carrying both keys.
            if (not isinstance(call, dict) or "name" not in call
                    or "arguments" not in call):
                continue

            raw_arguments = call["arguments"]
            if isinstance(raw_arguments, str):
                arguments = raw_arguments
            else:
                # Serialise every non-string value (dict, list, number,
                # null) so one unusual entry does not make the whole
                # extraction fail. Non-ASCII characters are kept verbatim,
                # matching the streaming path which forwards the model's
                # raw text.
                arguments = json.dumps(raw_arguments, ensure_ascii=False)

            tool_calls.append(
                ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        arguments=arguments,
                    ),
                ))

        # Whitespace-only leading text is reported as no content.
        if not content or not content.strip():
            content = None

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=content,
        )
    except Exception:
        # Best-effort boundary: never fail the request, but leave a trace
        # instead of swallowing the error silently.
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Produce the next streaming delta for the text accumulated so far.

    Only ``current_text`` and ``delta_text`` are read; the remaining
    parameters are part of the interface but unused here.

    Returns:
        A content delta carrying ``delta_text`` when the text (after
        optional whitespace and the opening tool-call tag) does not start
        with ``[``; otherwise the first non-empty result of the
        compatibility, name and argument handlers, or ``None`` when there
        is nothing to send.
    """
    # Skip leading whitespace, then the opening tag if it is present.
    json_start = consume_space(0, current_text)
    if current_text[json_start:].startswith(self.bot_string):
        json_start = consume_space(json_start + len(self.bot_string),
                                   current_text)

    starts_with_array = (current_text and json_start < len(current_text)
                         and current_text[json_start] == '[')
    if not starts_with_array:
        # Not a tool-call array: pass the delta through as plain content.
        return DeltaMessage(content=delta_text)

    # Remember the parsed array whenever the text is already valid JSON.
    self._try_parse_json_tools(current_text[json_start:])

    compat_delta = self._handle_test_compatibility(current_text)
    if compat_delta:
        return compat_delta

    found_names = list(self.tool_name_reg.finditer(current_text))
    if not found_names:
        return None
    known_tools = len(found_names)

    self._ensure_state_arrays(known_tools)
    active_idx = self.streaming_state["current_tool_index"]

    # A newly visible tool name takes priority over argument text.
    name_delta = self._handle_tool_name_streaming(active_idx, known_tools,
                                                  found_names)
    if name_delta:
        return name_delta

    args_delta = self._handle_tool_args_streaming(current_text, active_idx,
                                                  known_tools)
    return args_delta if args_delta else None
def _try_parse_json_tools(self, current_text: str):
    """Store the parsed tool list if ``current_text`` is complete JSON.

    Text that is not yet valid JSON is ignored, as is valid JSON whose top
    level is not a list; ``prev_tool_call_arr`` is left untouched in both
    cases.
    """
    try:
        candidate = json.loads(current_text)
    except json.JSONDecodeError:
        # Expected while the JSON is still being streamed.
        return
    if isinstance(candidate, list):
        self.prev_tool_call_arr = candidate
def _handle_test_compatibility(self, current_text: str):
    """Emit the first tool name when ``current_tools_sent`` is ``[False]``.

    This path is taken only when ``current_tools_sent`` holds exactly one
    entry and that entry is ``False``. If a tool name is then found in
    ``current_text``, a delta announcing it at index 0 is returned and the
    streaming state is updated to record that the name was sent.

    Returns:
        The name delta, or ``None`` when the path does not apply or no
        name is present yet.
    """
    flags = self.current_tools_sent
    if len(flags) != 1 or flags[0] is not False:
        return None

    found = self.tool_name_reg.search(current_text)
    if found is None:
        return None

    call_id = f"chatcmpl-tool-{random_uuid()}"
    name_payload = DeltaFunctionCall(
        name=found.group(1)).model_dump(exclude_none=True)
    delta = DeltaMessage(tool_calls=[
        DeltaToolCall(
            index=0,
            type="function",
            id=call_id,
            function=name_payload,
        )
    ])

    # Record that tool 0 has been announced.
    self.current_tools_sent = [True]
    self.current_tool_id = 0
    self.streaming_state["current_tool_index"] = 0

    sent_tools = self.streaming_state["sent_tools"]
    if sent_tools:
        sent_tools[0]["sent_name"] = True
    else:
        sent_tools.append({
            "sent_name": True,
            "sent_arguments_prefix": False,
            "sent_arguments": "",
        })
    self.current_tool_name_sent = True
    return delta
def _ensure_state_arrays(self, tool_count: int):
    """Grow the per-tool streaming lists to at least ``tool_count`` slots.

    Existing entries are kept; the lists are never shortened.
    """
    sent_tools = self.streaming_state["sent_tools"]
    # One fresh record per missing tool (a new dict each time).
    sent_tools.extend({
        "sent_name": False,
        "sent_arguments_prefix": False,
        "sent_arguments": "",
    } for _ in range(tool_count - len(sent_tools)))

    tool_ids = self.streaming_state["tool_ids"]
    # A non-positive multiplier yields an empty list, so this is a no-op
    # when enough slots already exist.
    tool_ids.extend([None] * (tool_count - len(tool_ids)))
def _handle_tool_name_streaming(self, current_idx: int, tool_count: int,
                                name_matches):
    """Announce the next tool's name if it has not been sent yet.

    Args:
        current_idx: Index of the tool currently being streamed, or ``-1``
            before any tool has started.
        tool_count: Number of tool names found in the text so far.
        name_matches: Regex matches for the tool names; group 1 of the
            match at the new index supplies the name.

    Returns:
        A delta carrying the new tool's index, id and name, or ``None``
        when there is no unsent next tool.
    """
    # The cursor already sits on the last known tool: nothing new.
    if current_idx != -1 and current_idx >= tool_count - 1:
        return None

    upcoming = current_idx + 1
    if upcoming >= tool_count:
        return None

    sent_tools = self.streaming_state["sent_tools"]
    if sent_tools[upcoming]["sent_name"]:
        return None

    # Move the cursor to the new tool before building its delta.
    self.streaming_state["current_tool_index"] = upcoming
    self.current_tool_id = upcoming

    announced_name = name_matches[upcoming].group(1)
    call_id = f"call_{upcoming}_{random_uuid()}"
    self.streaming_state["tool_ids"][upcoming] = call_id

    name_payload = DeltaFunctionCall(
        name=announced_name).model_dump(exclude_none=True)
    delta = DeltaMessage(tool_calls=[
        DeltaToolCall(
            index=upcoming,
            type="function",
            id=call_id,
            function=name_payload,
        )
    ])

    sent_tools[upcoming]["sent_name"] = True
    self.current_tool_name_sent = True

    # Make sure an argument accumulator exists for this tool.
    missing = upcoming + 1 - len(self.streamed_args)
    self.streamed_args.extend([""] * missing)
    return delta
def _handle_tool_args_streaming(self, current_text: str, current_idx: int,
                                tool_count: int):
    """Stream the next piece of the current tool's arguments, if any.

    Both argument patterns are matched starting at the current tool's own
    ``"name"`` occurrence, so an ``"arguments": {}`` (or any completed
    arguments object) belonging to a different tool is never attributed to
    the tool being streamed.

    Args:
        current_text: All text generated so far.
        current_idx: Index of the tool whose arguments are being streamed.
        tool_count: Number of tool names found in ``current_text``.

    Returns:
        A delta with the next argument fragment (``"{}"`` for empty
        arguments, ``"{"`` as the opening prefix, or the not-yet-sent
        remainder), or ``None`` when there is nothing new to send. When the
        arguments are complete and fully sent, the tool cursor is advanced
        if another tool exists.
    """
    if not 0 <= current_idx < tool_count:
        return None

    def emit(fragment: str):
        # Record the fragment in the per-tool accumulator and wrap it in a
        # delta for the current tool.
        while len(self.streamed_args) <= current_idx:
            self.streamed_args.append("")
        self.streamed_args[current_idx] += fragment
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=current_idx,
                function=DeltaFunctionCall(
                    arguments=fragment).model_dump(exclude_none=True),
            )
        ])

    def advance():
        # Move the cursor to the next tool, if one is already known.
        if current_idx < tool_count - 1:
            self.streaming_state["current_tool_index"] += 1
            self.current_tool_id = self.streaming_state[
                "current_tool_index"]

    # Locate where the current tool starts so the argument patterns (which
    # both begin with the same `"name": "..."` prefix) can be anchored
    # there.
    name_matches = list(self.tool_name_reg.finditer(current_text))
    if current_idx >= len(name_matches):
        return None
    tool_start = name_matches[current_idx].start()

    # Empty arguments: send "{}" in one piece.
    empty_args_match = self.tool_empty_arg_reg.match(current_text,
                                                     tool_start)
    if empty_args_match and empty_args_match.start() > 0:
        tool_state = self.streaming_state["sent_tools"][current_idx]
        if not tool_state["sent_arguments_prefix"]:
            tool_state["sent_arguments_prefix"] = True
            tool_state["sent_arguments"] = "{}"
            delta = emit("{}")
            advance()
            return delta

    # Arguments object; the pattern only matches once the closing brace of
    # the object has arrived.
    args_match = self.tool_non_empty_arg_reg.match(current_text, tool_start)
    if args_match is None:
        return None

    args_text = args_match.group(1)
    if current_idx != tool_count - 1:
        # NOTE(review): when a compact `},{` separator follows, this slice
        # appears to include the tool object's own closing brace, and
        # `split('"arguments":')` assumes no space before the colon.
        # Behaviour kept as-is; confirm the intent before changing it.
        next_tool_pos = current_text.find("},{", args_match.start())
        if next_tool_pos != -1:
            args_end_pos = next_tool_pos + 1
            args_text = (current_text[args_match.start():args_end_pos].
                         split('"arguments":')[1].strip())

    tool_state = self.streaming_state["sent_tools"][current_idx]
    sent_args = tool_state["sent_arguments"]

    # First fragment for this tool: send only the opening brace.
    if not tool_state["sent_arguments_prefix"] and args_text.startswith(
            "{"):
        tool_state["sent_arguments_prefix"] = True
        tool_state["sent_arguments"] = "{"
        return emit("{")

    # Send whatever follows the part that was already sent.
    if args_text.startswith(sent_args):
        args_diff = args_text[len(sent_args):]
        if args_diff:
            tool_state["sent_arguments"] = args_text
            return emit(args_diff)

    # Everything has been sent: move on to the next tool.
    if args_text.endswith("}") and args_text == sent_args:
        advance()

    return None