v3.12.0: gRPC auto-fallback for Antigravity (PR #13)
This commit is contained in:
@@ -165,6 +165,56 @@ import tempfile
|
||||
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# Lazy gRPC import for Antigravity fallback
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
_antigravity_grpc_client = None
|
||||
_antigravity_grpc_available = None
|
||||
|
||||
def _get_grpc_client():
|
||||
"""Lazy-load the Antigravity gRPC client. Returns None if grpcio is not installed."""
|
||||
global _antigravity_grpc_client, _antigravity_grpc_available
|
||||
if _antigravity_grpc_available is False:
|
||||
return None
|
||||
if _antigravity_grpc_client is not None:
|
||||
return _antigravity_grpc_client
|
||||
try:
|
||||
# Add the src directory to sys.path so antigravity_grpc package is found
|
||||
_src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if _src_dir not in sys.path:
|
||||
sys.path.insert(0, _src_dir)
|
||||
from antigravity_grpc import is_grpc_available, AntigravityGrpcClient, get_client
|
||||
if is_grpc_available():
|
||||
_antigravity_grpc_client = get_client()
|
||||
_antigravity_grpc_available = True
|
||||
print("[antigravity-grpc] gRPC fallback module loaded OK", file=sys.stderr)
|
||||
return _antigravity_grpc_client
|
||||
else:
|
||||
_antigravity_grpc_available = False
|
||||
print("[antigravity-grpc] grpcio available but stubs failed to load, gRPC fallback disabled", file=sys.stderr)
|
||||
return None
|
||||
except ImportError as e:
|
||||
_antigravity_grpc_available = False
|
||||
print(f"[antigravity-grpc] grpcio not installed ({e}), gRPC fallback disabled", file=sys.stderr)
|
||||
return None
|
||||
|
||||
# Reverse alias map: REST slug → gRPC display name
|
||||
# gRPC uses display names (e.g. "Gemini 3.5 Flash (High)") while REST uses slugs (e.g. "gemini-3-flash")
|
||||
_GRPC_REVERSE_ALIAS = {
|
||||
"gemini-3-flash": "Gemini 3.5 Flash (High)",
|
||||
"gemini-3.5-flash-low": "Gemini 3.5 Flash (Low)",
|
||||
"gemini-3.1-pro-low": "Gemini 3.1 Pro (High)",
|
||||
"claude-sonnet-4-6": "Claude Sonnet 4.6 (Thinking)",
|
||||
"claude-opus-4-6-thinking": "Claude Opus 4.6 (Thinking)",
|
||||
"gpt-oss-120b-medium": "GPT-OSS 120B (Medium)",
|
||||
"gemini-2.5-flash": "Gemini 2.5 Flash",
|
||||
"gemini-2.5-pro": "Gemini 2.5 Pro",
|
||||
"gemini-2.5-flash-lite": "Gemini 2.5 Flash Lite",
|
||||
}
|
||||
|
||||
# Errors from REST that should trigger gRPC fallback
|
||||
_GRPC_FALLBACK_REST_ERRORS = {404} # Model not found via REST (model exists in gRPC but not REST)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# Config
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
@@ -5762,7 +5812,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
_antigravity_endpoints.append("https://autopush-cloudcode-pa.sandbox.googleapis.com")
|
||||
|
||||
body_b = json.dumps(wrapped).encode()
|
||||
print(f"[{self._session_id}] [antigravity-v2] model={model} stream={stream} contents={len(contents)} tools={bool(gemini_tools)} project={project_id} ver={_versions[0]}", file=sys.stderr)
|
||||
print(f"[{self._session_id}] [antigravity-v2] model={model} stream={stream} contents={len(contents)} tools={bool(gemini_tools)} project={project_id} ver={_fetched_ver}", file=sys.stderr)
|
||||
try:
|
||||
debug_path = os.path.join(_LOG_DIR, f"antigravity-v2-request-{self._session_id}.json")
|
||||
with open(debug_path, "w") as dbg:
|
||||
@@ -5863,6 +5913,14 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
continue
|
||||
|
||||
if upstream is None:
|
||||
# ─── gRPC FALLBACK ─────────────────────────────────────────
|
||||
# If REST failed with 404 (model not available via REST API),
|
||||
# try gRPC which supports display names and has a wider model catalog.
|
||||
if _all_404:
|
||||
grpc_result = self._try_grpc_fallback(wrapped, access_token, stream, tracker)
|
||||
if grpc_result is not None:
|
||||
return # gRPC succeeded, response already sent
|
||||
# ─── END gRPC FALLBACK ─────────────────────────────────────
|
||||
return self.send_json(502, {"error": {"type": "proxy_error", "message": "All endpoints failed"}})
|
||||
|
||||
if stream:
|
||||
@@ -5870,6 +5928,190 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
else:
|
||||
self._forward_gemini_json(upstream, model, body, input_data)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# gRPC Fallback for Antigravity
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
|
||||
def _try_grpc_fallback(self, wrapped_dict, access_token, stream, tracker=None):
|
||||
"""
|
||||
Try gRPC fallback when REST API returns 404 (model not found).
|
||||
|
||||
gRPC uses display names (e.g. "Gemini 3.5 Flash (High)") instead of
|
||||
REST slugs (e.g. "gemini-3-flash"), so models unavailable via REST
|
||||
may work via gRPC.
|
||||
|
||||
Returns None if gRPC is unavailable or also failed (caller should
|
||||
send its own error response). Returns True if gRPC succeeded and
|
||||
the response was already sent to the client.
|
||||
"""
|
||||
grpc_client = _get_grpc_client()
|
||||
if grpc_client is None:
|
||||
print(f"[{self._session_id}] [antigravity-grpc] gRPC fallback not available (grpcio not installed), skipping", file=sys.stderr)
|
||||
return None
|
||||
|
||||
# gRPC uses display names, not REST slugs — remap the model ID
|
||||
grpc_wrapped = dict(wrapped_dict)
|
||||
rest_model = grpc_wrapped.get("model", "")
|
||||
grpc_model = _GRPC_REVERSE_ALIAS.get(rest_model, rest_model)
|
||||
grpc_wrapped["model"] = grpc_model
|
||||
if grpc_model != rest_model:
|
||||
print(f"[{self._session_id}] [antigravity-grpc] model remapped for gRPC: REST={rest_model} -> gRPC={grpc_model}", file=sys.stderr)
|
||||
|
||||
print(f"[{self._session_id}] [antigravity-grpc] REST 404, trying gRPC fallback with model={grpc_model} stream={stream}", file=sys.stderr)
|
||||
|
||||
try:
|
||||
result = grpc_client.try_generate(
|
||||
grpc_wrapped,
|
||||
stream=stream,
|
||||
access_token=access_token,
|
||||
timeout_s=180,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self._session_id}] [antigravity-grpc] gRPC call exception: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
if not result.ok:
|
||||
print(f"[{self._session_id}] [antigravity-grpc] gRPC fallback also failed: {result.error_message}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
print(f"[{self._session_id}] [antigravity-grpc] gRPC fallback OK! endpoint={result.endpoint_used} model={result.model_used} elapsed={result.elapsed_s:.1f}s", file=sys.stderr)
|
||||
|
||||
# Process the gRPC response through the same forwarding paths as REST
|
||||
if stream and result.stream_chunks is not None:
|
||||
self._forward_grpc_sse(result, grpc_model)
|
||||
elif not stream and result.response_data is not None:
|
||||
self._forward_grpc_json(result, grpc_model)
|
||||
else:
|
||||
print(f"[{self._session_id}] [antigravity-grpc] unexpected result shape, no data to forward", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return True # Response sent successfully via gRPC
|
||||
|
||||
def _forward_grpc_sse(self, grpc_result, model):
|
||||
"""
|
||||
Forward a gRPC streaming result to the client as SSE events.
|
||||
The gRPC result contains stream_chunks that match the REST SSE chunk shape,
|
||||
so we can process them through the same _forward_gemini_sse logic.
|
||||
"""
|
||||
resp_id = f"resp-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self.send_header("Connection", "keep-alive")
|
||||
self.end_headers()
|
||||
|
||||
full_text = ""
|
||||
output_items = []
|
||||
current_tool_calls = {}
|
||||
message_started = False
|
||||
message_id = f"msg-{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def flush_event(event_type, data):
|
||||
self.wfile.write(f"event: {event_type}\ndata: {json.dumps(data)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
flush_event("response.created", {"type": "response.created", "response": {"id": resp_id, "object": "response", "model": model, "status": "in_progress", "created": created, "output": []}})
|
||||
flush_event("response.in_progress", {"type": "response.in_progress", "response": {"id": resp_id}})
|
||||
|
||||
# Process each gRPC chunk (same shape as REST SSE chunks)
|
||||
for chunk in grpc_result.stream_chunks:
|
||||
candidates = chunk.get("response", chunk).get("candidates", [])
|
||||
if not candidates:
|
||||
continue
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
for part in parts:
|
||||
sig = _extract_gemini_sig(part)
|
||||
if sig:
|
||||
if part.get("functionCall"):
|
||||
fc_id = part["functionCall"].get("id") or part["functionCall"].get("name")
|
||||
fc_name = part["functionCall"].get("name")
|
||||
if fc_id:
|
||||
_gemini_store_sig(f"fc:{fc_id}", sig)
|
||||
if fc_name:
|
||||
_gemini_store_sig(f"fc:{fc_name}", sig)
|
||||
_gemini_store_sig(f"turn:{resp_id}", sig)
|
||||
if part.get("thought"):
|
||||
sig_from_thought = _extract_gemini_sig(part)
|
||||
if sig_from_thought:
|
||||
_gemini_store_sig(f"turn:{resp_id}", sig_from_thought)
|
||||
continue
|
||||
if "text" in part and not part.get("functionCall"):
|
||||
text_delta = part["text"]
|
||||
if not text_delta:
|
||||
continue
|
||||
full_text += text_delta
|
||||
if not message_started:
|
||||
flush_event("response.output_item.added", {"type": "response.output_item.added", "output_index": 0, "item": {"type": "message", "id": message_id, "role": "assistant", "content": []}})
|
||||
flush_event("response.content_part.added", {"type": "response.content_part.added", "output_index": 0, "content_index": 0, "part": {"type": "output_text", "text": ""}})
|
||||
output_items.append({"text": True})
|
||||
message_started = True
|
||||
flush_event("response.output_text.delta", {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": text_delta})
|
||||
elif part.get("functionCall"):
|
||||
fc = part["functionCall"]
|
||||
call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
args_str = json.dumps(fc.get("args", fc.get("arguments", {})))
|
||||
output_index = len(output_items)
|
||||
flush_event("response.output_item.added", {"type": "response.output_item.added", "output_index": output_index, "item": {"type": "function_call", "id": call_id, "call_id": call_id, "name": fc.get("name", ""), "arguments": ""}})
|
||||
flush_event("response.function_call_arguments.delta", {"type": "response.function_call_arguments.delta", "output_index": output_index, "item_id": call_id, "delta": args_str})
|
||||
flush_event("response.function_call_arguments.done", {"type": "response.function_call_arguments.done", "output_index": output_index, "item_id": call_id, "arguments": args_str})
|
||||
current_tool_calls[call_id] = fc
|
||||
output_items.append({"tool": True})
|
||||
|
||||
# Build final response
|
||||
out = []
|
||||
if full_text:
|
||||
out.append({"type": "message", "id": message_id, "role": "assistant", "content": [{"type": "output_text", "text": full_text}]})
|
||||
tool_outputs = []
|
||||
for cid, fc in current_tool_calls.items():
|
||||
tool_outputs.append({"type": "function_call", "id": cid, "call_id": cid, "name": fc.get("name", ""), "arguments": json.dumps(fc.get("args", fc.get("arguments", {})))})
|
||||
out.extend(tool_outputs)
|
||||
|
||||
final_resp = {"id": resp_id, "object": "response", "model": model, "status": "completed", "created": created, "output": out}
|
||||
if full_text:
|
||||
flush_event("response.output_text.done", {"type": "response.output_text.done", "output_index": 0, "content_index": 0, "text": full_text})
|
||||
flush_event("response.content_part.done", {"type": "response.content_part.done", "output_index": 0, "content_index": 0, "part": {"type": "output_text", "text": full_text}})
|
||||
flush_event("response.output_item.done", {"type": "response.output_item.done", "output_index": 0, "item": out[0]})
|
||||
for idx, item in enumerate(tool_outputs, start=(1 if full_text else 0)):
|
||||
flush_event("response.output_item.done", {"type": "response.output_item.done", "output_index": idx, "item": item})
|
||||
flush_event("response.completed", {"type": "response.completed", "response": final_resp})
|
||||
self.close_connection = True
|
||||
|
||||
with _response_store_lock:
|
||||
_response_store[resp_id] = final_resp
|
||||
while len(_response_store) > _MAX_STORED:
|
||||
_response_store.popitem(last=False)
|
||||
|
||||
def _forward_grpc_json(self, grpc_result, model):
|
||||
"""Forward a gRPC non-streaming result to the client as JSON."""
|
||||
resp_id = f"resp-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
out = []
|
||||
full_text = ""
|
||||
data = grpc_result.response_data
|
||||
candidates = data.get("response", data).get("candidates", [])
|
||||
if candidates:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if part.get("thought"):
|
||||
continue
|
||||
if "text" in part and not part.get("functionCall"):
|
||||
text_parts.append(part["text"])
|
||||
elif part.get("functionCall"):
|
||||
fc = part["functionCall"]
|
||||
call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
out.append({"type": "function_call", "id": call_id, "call_id": call_id, "name": fc.get("name", ""), "arguments": json.dumps(fc.get("args", fc.get("arguments", {})))})
|
||||
if text_parts:
|
||||
full_text = "".join(text_parts)
|
||||
out.insert(0, {"type": "message", "id": f"msg-{uuid.uuid4().hex[:24]}", "role": "assistant", "content": [{"type": "output_text", "text": full_text}]})
|
||||
resp = {"id": resp_id, "object": "response", "model": model, "status": "completed", "created": created, "output": out}
|
||||
with _response_store_lock:
|
||||
_response_store[resp_id] = resp
|
||||
while len(_response_store) > _MAX_STORED:
|
||||
_response_store.popitem(last=False)
|
||||
self.send_json(200, resp)
|
||||
|
||||
def _handle_gemini_oauth(self, body, model, stream, tracker=None):
|
||||
input_data = body.get("input", "")
|
||||
policy = provider_policy()
|
||||
|
||||
Reference in New Issue
Block a user