feat: tool call arg normalizer + smart-continue loop + XML fallback

This commit is contained in:
Roman | RyzenAdvanced
2026-05-26 14:50:43 +04:00
Unverified
parent 6f8d808db0
commit cae161023f

View File

@@ -2191,6 +2191,75 @@ def _inject_stored_reasoning(messages):
msg["reasoning_content"] = reasoning
return messages
def _normalize_tool_args(raw_args):
if not raw_args or raw_args == "{}":
return raw_args
try:
parsed = json.loads(raw_args)
if isinstance(parsed, dict):
if "Arguments" in parsed and "arguments" not in parsed:
inner = parsed["Arguments"]
if isinstance(inner, str):
inner = inner.strip()
for pfx in ("```json", "```"):
if inner.startswith(pfx):
inner = inner[len(pfx):].strip()
if inner.endswith("```"):
inner = inner[:-3].strip()
try:
inner_parsed = json.loads(inner)
if isinstance(inner_parsed, dict):
return json.dumps(inner_parsed)
except json.JSONDecodeError:
pass
if "cmd" not in parsed and "Arguments" in parsed:
inner = parsed["Arguments"]
if isinstance(inner, str):
inner = inner.strip()
for pfx in ("```json", "```"):
if inner.startswith(pfx):
inner = inner[len(pfx):].strip()
if inner.endswith("```"):
inner = inner[:-3].strip()
try:
inner_parsed = json.loads(inner)
if isinstance(inner_parsed, dict):
return json.dumps(inner_parsed)
except json.JSONDecodeError:
pass
return raw_args
except json.JSONDecodeError:
return raw_args
_XML_TC_RE = re.compile(r'<tool_call>(\w+)(.*?)</tool_call>', re.DOTALL)
_XML_ARG_VALUE_RE = re.compile(r'</?arg_value>\s*')
def _extract_xml_tool_calls(text):
if not text:
return []
results = []
for m in _XML_TC_RE.finditer(text):
name = m.group(1)
rest = _XML_ARG_VALUE_RE.sub("", m.group(2)).strip()
args_str = "{}"
try:
for pfx in ("```json", "```"):
if rest.startswith(pfx):
rest = rest[len(pfx):].strip()
if rest.endswith("```"):
rest = rest[:-3].strip()
if rest.startswith("{"):
json.loads(rest)
args_str = rest
else:
json.loads(rest)
args_str = rest
except Exception:
if rest.startswith("{"):
args_str = rest
results.append({"name": name, "args": args_str, "call_id": f"xml_{len(results)}"})
return results
def oa_input_to_messages(input_data):
msgs = []
tool_name_by_id = {}
@@ -2203,11 +2272,13 @@ def oa_input_to_messages(input_data):
t = item.get("type")
if t == "function_call":
tcid = item.get("call_id") or item.get("id") or uid("tc")
raw_args = item.get("arguments", "{}")
normalized_args = _normalize_tool_args(raw_args)
pending_tool_calls.append(
{"id": tcid,
"type": "function",
"function": {"name": item.get("name", ""),
"arguments": item.get("arguments", "{}")}})
"arguments": normalized_args}})
tool_name_by_id[tcid] = item.get("name", "")
continue
if pending_tool_calls:
@@ -5748,23 +5819,63 @@ class Handler(http.server.BaseHTTPRequestHandler):
except Exception as e:
print(f"[crof-adaptive] retry failed: {e}", file=sys.stderr)
# Smart continuation: if model returned finish_reason=stop with only text (no tool calls)
# during an active tool-using session, nudge it to continue working.
# Smart continuation: loop with escalating nudges when model stops text-only mid-task.
_smart_max = 2
_smart_attempt = 0
while _smart_attempt < _smart_max:
_has_tool_calls_in_output = any(o.get("type") == "function_call" for o in (last_output or []))
if (finish_reason == "stop" and has_content and not _has_tool_calls_in_output
if not (finish_reason == "stop" and has_content and not _has_tool_calls_in_output
and isinstance(input_data, list) and len(input_data) >= 3
and has_function_call_output(input_data)):
_nudge_msg = {
"role": "user",
"content": "Continue with the task. If you need to make changes or gather more information, use the appropriate tools. Do NOT just describe what to do — take action using tool calls."
}
break
_smart_attempt += 1
_nudges = [
"Continue with the task using tool calls. Do NOT describe what to do — call the appropriate functions.",
"You MUST use tool calls to complete the task. Read files, run commands, and make changes using tools. Do NOT output XML tool calls as text.",
]
nudge_text = _nudges[min(_smart_attempt - 1, len(_nudges) - 1)]
# Try extracting XML tool calls from text as fallback before nudging
last_text = ""
for o in (last_output or []):
if o.get("type") == "message":
for c in (o.get("content") or []):
if isinstance(c, dict) and c.get("type") == "output_text":
last_text += c.get("text", "")
xml_fc = _extract_xml_tool_calls(last_text)
if xml_fc:
print(f"[{self._session_id}] [smart-continue] extracted {len(xml_fc)} XML tool calls from text, injecting and retrying", file=sys.stderr)
fake_input = list(input_data)
for xfc in xml_fc:
fake_input.append({"type": "function_call", "id": uid("fcx"), "call_id": uid("fcx"),
"name": xfc["name"], "arguments": xfc["args"], "status": "completed"})
fake_messages = oa_input_to_messages(fake_input)
instructions = body.get("instructions", "").strip()
if instructions:
fake_messages.insert(0, {"role": "system", "content": instructions})
fake_chat_body = self._build_chat_body(model, fake_messages, body, stream)
fake_req = urllib.request.Request(target, data=json.dumps(fake_chat_body).encode(), headers=fwd)
try:
retry_upstream = urllib.request.urlopen(fake_req, timeout=_upstream_timeout(body, True))
collected_events = []
last_resp_id = last_output = last_status = None
finish_reason = None
has_content = False
for event in oa_stream_to_sse(retry_upstream, model, body.get("request_id") or body.get("id")):
collected_events.append(event)
_observe_event(event)
input_data = fake_input
continue
except Exception as e:
print(f"[{self._session_id}] [smart-continue] XML injection retry failed: {e}", file=sys.stderr)
break
_nudge_msg = {"role": "user", "content": nudge_text}
nudge_messages = oa_input_to_messages(input_data) + [_nudge_msg]
instructions = body.get("instructions", "").strip()
if instructions:
nudge_messages.insert(0, {"role": "system", "content": instructions})
nudge_chat_body = self._build_chat_body(model, nudge_messages, body, stream)
nudge_req = urllib.request.Request(target, data=json.dumps(nudge_chat_body).encode(), headers=fwd)
print(f"[{self._session_id}] [smart-continue] model stopped mid-task without tool calls, nudging continuation", file=sys.stderr)
print(f"[{self._session_id}] [smart-continue] attempt {_smart_attempt}/{_smart_max}: model stopped mid-task, nudging", file=sys.stderr)
try:
retry_upstream = urllib.request.urlopen(nudge_req, timeout=_upstream_timeout(body, True))
collected_events = []
@@ -5775,7 +5886,8 @@ class Handler(http.server.BaseHTTPRequestHandler):
collected_events.append(event)
_observe_event(event)
except Exception as e:
print(f"[{self._session_id}] [smart-continue] nudge retry failed: {e}", file=sys.stderr)
print(f"[{self._session_id}] [smart-continue] nudge attempt {_smart_attempt} failed: {e}", file=sys.stderr)
break
self.stream_buffered_events(collected_events)
else: