feat: tool call arg normalizer + smart-continue loop + XML fallback
This commit is contained in:
@@ -2191,6 +2191,75 @@ def _inject_stored_reasoning(messages):
|
||||
msg["reasoning_content"] = reasoning
|
||||
return messages
|
||||
|
||||
def _strip_markdown_fence(text):
    """Return *text* stripped of whitespace and of a surrounding Markdown code fence.

    A leading ```` ```json ```` or ```` ``` ```` marker and a trailing
    ```` ``` ```` marker are removed when present; text without fences is
    only whitespace-stripped.
    """
    text = text.strip()
    for prefix in ("```json", "```"):
        if text.startswith(prefix):
            text = text[len(prefix):].strip()
    if text.endswith("```"):
        text = text[:-3].strip()
    return text


def _normalize_tool_args(raw_args):
    """Unwrap tool-call arguments that arrive nested under an ``"Arguments"`` key.

    When *raw_args* is a JSON object holding an ``"Arguments"`` string whose
    content (optionally wrapped in a Markdown code fence) is itself a JSON
    object, that inner object is re-serialised and returned. In every other
    case *raw_args* is returned unchanged.

    Args:
        raw_args: The arguments payload of a function call, normally a JSON
            string. Falsy values, ``"{}"`` and values that are not JSON text
            are passed through untouched.

    Returns:
        A JSON string of the unwrapped inner object, or *raw_args* itself
        when nothing could be unwrapped.
    """
    if not raw_args or raw_args == "{}":
        return raw_args

    try:
        parsed = json.loads(raw_args)
    except (TypeError, ValueError):
        # TypeError: raw_args is not str/bytes (e.g. an already-decoded dict).
        # ValueError: malformed JSON (json.JSONDecodeError is a subclass).
        return raw_args

    if not isinstance(parsed, dict) or "Arguments" not in parsed:
        return raw_args

    # The wrapper is left alone only when both "arguments" and "cmd" are
    # already present next to "Arguments"; if either is missing we unwrap.
    if "arguments" in parsed and "cmd" in parsed:
        return raw_args

    inner = parsed["Arguments"]
    if not isinstance(inner, str):
        # Only string payloads are unwrapped; anything else is kept as-is.
        return raw_args

    try:
        inner_parsed = json.loads(_strip_markdown_fence(inner))
    except ValueError:
        return raw_args

    if isinstance(inner_parsed, dict):
        return json.dumps(inner_parsed)
    return raw_args
|
||||
|
||||
# Matches "<tool_call>NAME payload</tool_call>" blocks embedded in model text.
# Group 1 is the tool name (word characters and hyphens, optionally preceded
# by whitespace); group 2 is everything up to the closing tag. DOTALL lets
# the payload span multiple lines.
_XML_TC_RE = re.compile(r'<tool_call>\s*([\w-]+)(.*?)</tool_call>', re.DOTALL)
# Matches opening/closing <arg_value> tags together with trailing whitespace,
# so they can be removed from the payload.
_XML_ARG_VALUE_RE = re.compile(r'</?arg_value>\s*')


def _extract_xml_tool_calls(text):
    """Extract tool calls that were written as ``<tool_call>`` XML in plain text.

    For every ``<tool_call>NAME ...</tool_call>`` block found in *text*, the
    payload has its ``<arg_value>`` tags and any surrounding Markdown code
    fence removed and is then used as the call's JSON arguments string.

    Args:
        text: Text to scan. Falsy values yield an empty list.

    Returns:
        A list of dicts with the keys ``"name"``, ``"args"`` (a JSON string,
        ``"{}"`` when no usable object payload was found) and ``"call_id"``
        (``"xml_<index>"``, numbered in order of appearance).
    """
    if not text:
        return []

    results = []
    for m in _XML_TC_RE.finditer(text):
        name = m.group(1)
        rest = _XML_ARG_VALUE_RE.sub("", m.group(2)).strip()

        # Drop a Markdown code fence wrapped around the payload, if any.
        for pfx in ("```json", "```"):
            if rest.startswith(pfx):
                rest = rest[len(pfx):].strip()
        if rest.endswith("```"):
            rest = rest[:-3].strip()

        args_str = "{}"
        try:
            # Only a JSON object is accepted as arguments; scalars and arrays
            # fall back to "{}".
            if isinstance(json.loads(rest), dict):
                args_str = rest
        except (ValueError, RecursionError):
            # Best effort: text that looks like an object but does not parse
            # is still forwarded verbatim.
            if rest.startswith("{"):
                args_str = rest

        results.append({"name": name, "args": args_str, "call_id": f"xml_{len(results)}"})
    return results
|
||||
|
||||
def oa_input_to_messages(input_data):
|
||||
msgs = []
|
||||
tool_name_by_id = {}
|
||||
@@ -2203,11 +2272,13 @@ def oa_input_to_messages(input_data):
|
||||
t = item.get("type")
|
||||
if t == "function_call":
|
||||
tcid = item.get("call_id") or item.get("id") or uid("tc")
|
||||
raw_args = item.get("arguments", "{}")
|
||||
normalized_args = _normalize_tool_args(raw_args)
|
||||
pending_tool_calls.append(
|
||||
{"id": tcid,
|
||||
"type": "function",
|
||||
"function": {"name": item.get("name", ""),
|
||||
"arguments": item.get("arguments", "{}")}})
|
||||
"arguments": normalized_args}})
|
||||
tool_name_by_id[tcid] = item.get("name", "")
|
||||
continue
|
||||
if pending_tool_calls:
|
||||
@@ -5748,23 +5819,63 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
except Exception as e:
|
||||
print(f"[crof-adaptive] retry failed: {e}", file=sys.stderr)
|
||||
|
||||
# Smart continuation: if model returned finish_reason=stop with only text (no tool calls)
|
||||
# during an active tool-using session, nudge it to continue working.
|
||||
_has_tool_calls_in_output = any(o.get("type") == "function_call" for o in (last_output or []))
|
||||
if (finish_reason == "stop" and has_content and not _has_tool_calls_in_output
|
||||
and isinstance(input_data, list) and len(input_data) >= 3
|
||||
and has_function_call_output(input_data)):
|
||||
_nudge_msg = {
|
||||
"role": "user",
|
||||
"content": "Continue with the task. If you need to make changes or gather more information, use the appropriate tools. Do NOT just describe what to do — take action using tool calls."
|
||||
}
|
||||
# Smart continuation: loop with escalating nudges when model stops text-only mid-task.
|
||||
_smart_max = 2
|
||||
_smart_attempt = 0
|
||||
while _smart_attempt < _smart_max:
|
||||
_has_tool_calls_in_output = any(o.get("type") == "function_call" for o in (last_output or []))
|
||||
if not (finish_reason == "stop" and has_content and not _has_tool_calls_in_output
|
||||
and isinstance(input_data, list) and len(input_data) >= 3
|
||||
and has_function_call_output(input_data)):
|
||||
break
|
||||
_smart_attempt += 1
|
||||
_nudges = [
|
||||
"Continue with the task using tool calls. Do NOT describe what to do — call the appropriate functions.",
|
||||
"You MUST use tool calls to complete the task. Read files, run commands, and make changes using tools. Do NOT output XML tool calls as text.",
|
||||
]
|
||||
nudge_text = _nudges[min(_smart_attempt - 1, len(_nudges) - 1)]
|
||||
# Try extracting XML tool calls from text as fallback before nudging
|
||||
last_text = ""
|
||||
for o in (last_output or []):
|
||||
if o.get("type") == "message":
|
||||
for c in (o.get("content") or []):
|
||||
if isinstance(c, dict) and c.get("type") == "output_text":
|
||||
last_text += c.get("text", "")
|
||||
xml_fc = _extract_xml_tool_calls(last_text)
|
||||
if xml_fc:
|
||||
print(f"[{self._session_id}] [smart-continue] extracted {len(xml_fc)} XML tool calls from text, injecting and retrying", file=sys.stderr)
|
||||
fake_input = list(input_data)
|
||||
for xfc in xml_fc:
|
||||
fake_input.append({"type": "function_call", "id": uid("fcx"), "call_id": uid("fcx"),
|
||||
"name": xfc["name"], "arguments": xfc["args"], "status": "completed"})
|
||||
fake_messages = oa_input_to_messages(fake_input)
|
||||
instructions = body.get("instructions", "").strip()
|
||||
if instructions:
|
||||
fake_messages.insert(0, {"role": "system", "content": instructions})
|
||||
fake_chat_body = self._build_chat_body(model, fake_messages, body, stream)
|
||||
fake_req = urllib.request.Request(target, data=json.dumps(fake_chat_body).encode(), headers=fwd)
|
||||
try:
|
||||
retry_upstream = urllib.request.urlopen(fake_req, timeout=_upstream_timeout(body, True))
|
||||
collected_events = []
|
||||
last_resp_id = last_output = last_status = None
|
||||
finish_reason = None
|
||||
has_content = False
|
||||
for event in oa_stream_to_sse(retry_upstream, model, body.get("request_id") or body.get("id")):
|
||||
collected_events.append(event)
|
||||
_observe_event(event)
|
||||
input_data = fake_input
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"[{self._session_id}] [smart-continue] XML injection retry failed: {e}", file=sys.stderr)
|
||||
break
|
||||
_nudge_msg = {"role": "user", "content": nudge_text}
|
||||
nudge_messages = oa_input_to_messages(input_data) + [_nudge_msg]
|
||||
instructions = body.get("instructions", "").strip()
|
||||
if instructions:
|
||||
nudge_messages.insert(0, {"role": "system", "content": instructions})
|
||||
nudge_chat_body = self._build_chat_body(model, nudge_messages, body, stream)
|
||||
nudge_req = urllib.request.Request(target, data=json.dumps(nudge_chat_body).encode(), headers=fwd)
|
||||
print(f"[{self._session_id}] [smart-continue] model stopped mid-task without tool calls, nudging continuation", file=sys.stderr)
|
||||
print(f"[{self._session_id}] [smart-continue] attempt {_smart_attempt}/{_smart_max}: model stopped mid-task, nudging", file=sys.stderr)
|
||||
try:
|
||||
retry_upstream = urllib.request.urlopen(nudge_req, timeout=_upstream_timeout(body, True))
|
||||
collected_events = []
|
||||
@@ -5775,7 +5886,8 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
collected_events.append(event)
|
||||
_observe_event(event)
|
||||
except Exception as e:
|
||||
print(f"[{self._session_id}] [smart-continue] nudge retry failed: {e}", file=sys.stderr)
|
||||
print(f"[{self._session_id}] [smart-continue] nudge attempt {_smart_attempt} failed: {e}", file=sys.stderr)
|
||||
break
|
||||
|
||||
self.stream_buffered_events(collected_events)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user