feat: tool call arg normalizer + smart-continue loop + XML fallback
This commit is contained in:
@@ -2191,6 +2191,75 @@ def _inject_stored_reasoning(messages):
|
|||||||
msg["reasoning_content"] = reasoning
|
msg["reasoning_content"] = reasoning
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
def _normalize_tool_args(raw_args):
    """Unwrap tool-call arguments that were nested under an ``"Arguments"`` key.

    Some models emit ``{"Arguments": "<json text>"}`` (optionally wrapped in a
    Markdown code fence) instead of the argument object itself. When the
    wrapped text decodes to a JSON object, that object is returned re-encoded
    with ``json.dumps``; in every other case the input is returned untouched.

    Args:
        raw_args: The ``arguments`` value of a function call, normally a JSON
            string. Any other value (``None``, an already-decoded dict, ...)
            is passed through unchanged.

    Returns:
        The unwrapped arguments as a JSON string, or ``raw_args`` itself when
        no unwrapping applies.
    """
    if not raw_args or raw_args == "{}":
        return raw_args

    try:
        parsed = json.loads(raw_args)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (malformed text); TypeError
        # covers non-str/bytes input, which must not crash the caller.
        return raw_args

    if not isinstance(parsed, dict) or "Arguments" not in parsed:
        return raw_args

    # The wrapper is left alone only when both a lowercase "arguments" key and
    # a "cmd" key are present alongside "Arguments".
    if "arguments" in parsed and "cmd" in parsed:
        return raw_args

    inner = parsed["Arguments"]
    if not isinstance(inner, str):
        return raw_args

    # Strip a surrounding Markdown code fence, if any.
    inner = inner.strip()
    for pfx in ("```json", "```"):
        if inner.startswith(pfx):
            inner = inner[len(pfx):].strip()
    if inner.endswith("```"):
        inner = inner[:-3].strip()

    try:
        inner_parsed = json.loads(inner)
    except ValueError:
        return raw_args

    if isinstance(inner_parsed, dict):
        return json.dumps(inner_parsed)
    return raw_args
|
||||||
|
|
||||||
|
# Matches <tool_call>NAME ...payload...</tool_call>. Whitespace after the
# opening tag is tolerated and NAME may contain hyphens as well as word
# characters. Group 1 is the tool name, group 2 the raw payload.
_XML_TC_RE = re.compile(r'<tool_call>\s*([\w-]+)(.*?)</tool_call>', re.DOTALL)
# Matches opening/closing <arg_value> tags (plus trailing whitespace) so they
# can be removed from the payload.
_XML_ARG_VALUE_RE = re.compile(r'</?arg_value>\s*')


def _extract_xml_tool_calls(text):
    """Extract tool calls that a model wrote as XML text instead of calling tools.

    Each ``<tool_call>NAME payload</tool_call>`` occurrence yields one entry.
    ``<arg_value>`` tags and a surrounding Markdown code fence are stripped
    from the payload before it is interpreted as the call's JSON arguments.

    Args:
        text: Model output text; falsy values produce an empty list.

    Returns:
        A list of dicts with keys ``"name"``, ``"args"`` (a string) and
        ``"call_id"`` (``"xml_<index>"``, in order of appearance). ``"args"``
        is the payload when it decodes to a JSON object, or when it fails to
        decode but starts with ``{`` (best effort: the text is forwarded
        as-is). Otherwise it is ``"{}"``.
    """
    if not text:
        return []

    results = []
    for m in _XML_TC_RE.finditer(text):
        rest = _XML_ARG_VALUE_RE.sub("", m.group(2)).strip()

        # Strip a surrounding Markdown code fence, if any.
        for pfx in ("```json", "```"):
            if rest.startswith(pfx):
                rest = rest[len(pfx):].strip()
        if rest.endswith("```"):
            rest = rest[:-3].strip()

        args_str = "{}"
        if rest:
            try:
                # Function arguments must be a JSON object; scalars and arrays
                # are rejected rather than forwarded.
                usable = isinstance(json.loads(rest), dict)
            except (ValueError, RecursionError):
                # Keep object-looking but malformed text as a best effort.
                usable = rest.startswith("{")
            if usable:
                args_str = rest

        results.append({"name": m.group(1), "args": args_str, "call_id": f"xml_{len(results)}"})
    return results
|
||||||
|
|
||||||
def oa_input_to_messages(input_data):
|
def oa_input_to_messages(input_data):
|
||||||
msgs = []
|
msgs = []
|
||||||
tool_name_by_id = {}
|
tool_name_by_id = {}
|
||||||
@@ -2203,11 +2272,13 @@ def oa_input_to_messages(input_data):
|
|||||||
t = item.get("type")
|
t = item.get("type")
|
||||||
if t == "function_call":
|
if t == "function_call":
|
||||||
tcid = item.get("call_id") or item.get("id") or uid("tc")
|
tcid = item.get("call_id") or item.get("id") or uid("tc")
|
||||||
|
raw_args = item.get("arguments", "{}")
|
||||||
|
normalized_args = _normalize_tool_args(raw_args)
|
||||||
pending_tool_calls.append(
|
pending_tool_calls.append(
|
||||||
{"id": tcid,
|
{"id": tcid,
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {"name": item.get("name", ""),
|
"function": {"name": item.get("name", ""),
|
||||||
"arguments": item.get("arguments", "{}")}})
|
"arguments": normalized_args}})
|
||||||
tool_name_by_id[tcid] = item.get("name", "")
|
tool_name_by_id[tcid] = item.get("name", "")
|
||||||
continue
|
continue
|
||||||
if pending_tool_calls:
|
if pending_tool_calls:
|
||||||
@@ -5748,23 +5819,63 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[crof-adaptive] retry failed: {e}", file=sys.stderr)
|
print(f"[crof-adaptive] retry failed: {e}", file=sys.stderr)
|
||||||
|
|
||||||
# Smart continuation: if model returned finish_reason=stop with only text (no tool calls)
|
# Smart continuation: loop with escalating nudges when model stops text-only mid-task.
|
||||||
# during an active tool-using session, nudge it to continue working.
|
_smart_max = 2
|
||||||
_has_tool_calls_in_output = any(o.get("type") == "function_call" for o in (last_output or []))
|
_smart_attempt = 0
|
||||||
if (finish_reason == "stop" and has_content and not _has_tool_calls_in_output
|
while _smart_attempt < _smart_max:
|
||||||
and isinstance(input_data, list) and len(input_data) >= 3
|
_has_tool_calls_in_output = any(o.get("type") == "function_call" for o in (last_output or []))
|
||||||
and has_function_call_output(input_data)):
|
if not (finish_reason == "stop" and has_content and not _has_tool_calls_in_output
|
||||||
_nudge_msg = {
|
and isinstance(input_data, list) and len(input_data) >= 3
|
||||||
"role": "user",
|
and has_function_call_output(input_data)):
|
||||||
"content": "Continue with the task. If you need to make changes or gather more information, use the appropriate tools. Do NOT just describe what to do — take action using tool calls."
|
break
|
||||||
}
|
_smart_attempt += 1
|
||||||
|
_nudges = [
|
||||||
|
"Continue with the task using tool calls. Do NOT describe what to do — call the appropriate functions.",
|
||||||
|
"You MUST use tool calls to complete the task. Read files, run commands, and make changes using tools. Do NOT output XML tool calls as text.",
|
||||||
|
]
|
||||||
|
nudge_text = _nudges[min(_smart_attempt - 1, len(_nudges) - 1)]
|
||||||
|
# Try extracting XML tool calls from text as fallback before nudging
|
||||||
|
last_text = ""
|
||||||
|
for o in (last_output or []):
|
||||||
|
if o.get("type") == "message":
|
||||||
|
for c in (o.get("content") or []):
|
||||||
|
if isinstance(c, dict) and c.get("type") == "output_text":
|
||||||
|
last_text += c.get("text", "")
|
||||||
|
xml_fc = _extract_xml_tool_calls(last_text)
|
||||||
|
if xml_fc:
|
||||||
|
print(f"[{self._session_id}] [smart-continue] extracted {len(xml_fc)} XML tool calls from text, injecting and retrying", file=sys.stderr)
|
||||||
|
fake_input = list(input_data)
|
||||||
|
for xfc in xml_fc:
|
||||||
|
fake_input.append({"type": "function_call", "id": uid("fcx"), "call_id": uid("fcx"),
|
||||||
|
"name": xfc["name"], "arguments": xfc["args"], "status": "completed"})
|
||||||
|
fake_messages = oa_input_to_messages(fake_input)
|
||||||
|
instructions = body.get("instructions", "").strip()
|
||||||
|
if instructions:
|
||||||
|
fake_messages.insert(0, {"role": "system", "content": instructions})
|
||||||
|
fake_chat_body = self._build_chat_body(model, fake_messages, body, stream)
|
||||||
|
fake_req = urllib.request.Request(target, data=json.dumps(fake_chat_body).encode(), headers=fwd)
|
||||||
|
try:
|
||||||
|
retry_upstream = urllib.request.urlopen(fake_req, timeout=_upstream_timeout(body, True))
|
||||||
|
collected_events = []
|
||||||
|
last_resp_id = last_output = last_status = None
|
||||||
|
finish_reason = None
|
||||||
|
has_content = False
|
||||||
|
for event in oa_stream_to_sse(retry_upstream, model, body.get("request_id") or body.get("id")):
|
||||||
|
collected_events.append(event)
|
||||||
|
_observe_event(event)
|
||||||
|
input_data = fake_input
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{self._session_id}] [smart-continue] XML injection retry failed: {e}", file=sys.stderr)
|
||||||
|
break
|
||||||
|
_nudge_msg = {"role": "user", "content": nudge_text}
|
||||||
nudge_messages = oa_input_to_messages(input_data) + [_nudge_msg]
|
nudge_messages = oa_input_to_messages(input_data) + [_nudge_msg]
|
||||||
instructions = body.get("instructions", "").strip()
|
instructions = body.get("instructions", "").strip()
|
||||||
if instructions:
|
if instructions:
|
||||||
nudge_messages.insert(0, {"role": "system", "content": instructions})
|
nudge_messages.insert(0, {"role": "system", "content": instructions})
|
||||||
nudge_chat_body = self._build_chat_body(model, nudge_messages, body, stream)
|
nudge_chat_body = self._build_chat_body(model, nudge_messages, body, stream)
|
||||||
nudge_req = urllib.request.Request(target, data=json.dumps(nudge_chat_body).encode(), headers=fwd)
|
nudge_req = urllib.request.Request(target, data=json.dumps(nudge_chat_body).encode(), headers=fwd)
|
||||||
print(f"[{self._session_id}] [smart-continue] model stopped mid-task without tool calls, nudging continuation", file=sys.stderr)
|
print(f"[{self._session_id}] [smart-continue] attempt {_smart_attempt}/{_smart_max}: model stopped mid-task, nudging", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
retry_upstream = urllib.request.urlopen(nudge_req, timeout=_upstream_timeout(body, True))
|
retry_upstream = urllib.request.urlopen(nudge_req, timeout=_upstream_timeout(body, True))
|
||||||
collected_events = []
|
collected_events = []
|
||||||
@@ -5775,7 +5886,8 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
|||||||
collected_events.append(event)
|
collected_events.append(event)
|
||||||
_observe_event(event)
|
_observe_event(event)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{self._session_id}] [smart-continue] nudge retry failed: {e}", file=sys.stderr)
|
print(f"[{self._session_id}] [smart-continue] nudge attempt {_smart_attempt} failed: {e}", file=sys.stderr)
|
||||||
|
break
|
||||||
|
|
||||||
self.stream_buffered_events(collected_events)
|
self.stream_buffered_events(collected_events)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user