#!/usr/bin/env python3
"""
|
|
Unit tests for the Antigravity gRPC fallback module.
|
|
|
|
Tests cover:
|
|
1. Module import and availability detection
|
|
2. Protobuf conversion helpers (JSON <-> protobuf)
|
|
3. Request building from wrapped REST dict
|
|
4. Reverse alias map correctness
|
|
5. GrpcFallbackResult type
|
|
6. Integration: _try_grpc_fallback triggers correctly on REST 404
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
# Put the sibling ``src`` directory (relative to this file) at the front of
# sys.path so the antigravity_grpc package can be imported by the tests below.
_here = os.path.dirname(os.path.abspath(__file__))
_src_dir = os.path.join(_here, "..", "src")
if sys.path.count(_src_dir) == 0:
    sys.path.insert(0, _src_dir)
|
|
|
|
|
|
class TestGrpcModuleAvailability(unittest.TestCase):
    """Tests for is_grpc_available() and module loading."""

    def test_is_grpc_available_returns_bool(self):
        """is_grpc_available should return a boolean."""
        from antigravity_grpc import is_grpc_available

        result = is_grpc_available()
        self.assertIsInstance(result, bool)

    def test_is_grpc_available_true_when_installed(self):
        """If grpcio is installed and stubs are loadable, should return True.

        The test is skipped when the ``grpc`` package (provided by grpcio)
        cannot be found, so a missing optional dependency is reported as a
        skip rather than as a failure.
        """
        import importlib.util

        if importlib.util.find_spec("grpc") is None:
            self.skipTest("grpcio is not installed")

        from antigravity_grpc import is_grpc_available

        self.assertTrue(is_grpc_available())

    def test_client_instantiation(self):
        """AntigravityGrpcClient should be instantiatable."""
        from antigravity_grpc import AntigravityGrpcClient

        client = AntigravityGrpcClient()
        self.assertIsNotNone(client)

    def test_get_client_singleton(self):
        """get_client should return the same singleton."""
        from antigravity_grpc import get_client

        c1 = get_client()
        c2 = get_client()
        self.assertIs(c1, c2)
|
|
|
|
|
|
class TestGrpcFallbackResult(unittest.TestCase):
    """Exercise construction and repr() of GrpcFallbackResult."""

    def test_default_values(self):
        """A result built without arguments exposes the expected defaults."""
        from antigravity_grpc import GrpcFallbackResult

        result = GrpcFallbackResult()
        self.assertFalse(result.ok)
        for attr in ("response_data", "stream_chunks"):
            self.assertIsNone(getattr(result, attr))
        for attr in ("error_message", "endpoint_used", "model_used"):
            self.assertEqual(getattr(result, attr), "")
        self.assertEqual(result.elapsed_s, 0.0)

    def test_success_result(self):
        """Keyword arguments of a successful result are stored on the object."""
        from antigravity_grpc import GrpcFallbackResult

        fields = {
            "ok": True,
            "response_data": {"response": {"candidates": []}},
            "endpoint_used": "daily-cloudcode-pa.googleapis.com:443",
            "model_used": "Gemini 3.5 Flash (High)",
            "elapsed_s": 2.5,
        }
        result = GrpcFallbackResult(**fields)
        self.assertTrue(result.ok)
        self.assertIsNotNone(result.response_data)
        self.assertEqual(result.elapsed_s, 2.5)

    def test_failure_result(self):
        """A failed result keeps ok False and carries the error message."""
        from antigravity_grpc import GrpcFallbackResult

        result = GrpcFallbackResult(
            ok=False,
            error_message="All gRPC endpoints failed",
        )
        self.assertFalse(result.ok)
        self.assertIn("failed", result.error_message)

    def test_repr(self):
        """repr() contains "OK" for a success and "FAIL" for a failure."""
        from antigravity_grpc import GrpcFallbackResult

        succeeded = GrpcFallbackResult(
            ok=True,
            response_data={"response": {"candidates": []}},
        )
        self.assertIn("OK", repr(succeeded))

        failed = GrpcFallbackResult(ok=False, error_message="timeout")
        self.assertIn("FAIL", repr(failed))
|
|
|
|
|
|
class TestReverseAliasMap(unittest.TestCase):
    """Tests for the _GRPC_REVERSE_ALIAS map in translate-proxy.py."""

    def _load_proxy_module(self):
        """Execute ``translate-proxy.py`` from the src directory and return it.

        The file is loaded by path via importlib under the module name
        ``translate_proxy``. ``importlib.util`` is imported explicitly because
        ``import importlib`` alone does not guarantee the submodule is loaded.
        """
        import importlib.util

        spec = importlib.util.spec_from_file_location(
            "translate_proxy",
            os.path.join(_src_dir, "translate-proxy.py"),
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def test_import_reverse_alias(self):
        """The reverse alias map should be importable from the proxy module."""
        tp = self._load_proxy_module()
        self.assertIsInstance(tp._GRPC_REVERSE_ALIAS, dict)

    def test_key_models_have_reverse_aliases(self):
        """All key REST model slugs should have gRPC display name mappings."""
        tp = self._load_proxy_module()

        required_slugs = [
            "gemini-3-flash",
            "gemini-3.5-flash-low",
            "gemini-3.1-pro-low",
            "claude-sonnet-4-6",
            "claude-opus-4-6-thinking",
            "gemini-2.5-flash",
        ]
        for slug in required_slugs:
            self.assertIn(slug, tp._GRPC_REVERSE_ALIAS,
                          f"Missing reverse alias for REST slug '{slug}'")

    def test_reverse_alias_values_are_display_names(self):
        """Every reverse alias value should differ from its REST slug key.

        Only inequality is asserted; the shape of the display name itself
        (spaces, parentheses) is not checked.
        """
        tp = self._load_proxy_module()

        for slug, display_name in tp._GRPC_REVERSE_ALIAS.items():
            # Display names typically have spaces (e.g. "Gemini 3.5 Flash (High)")
            # while slugs use hyphens (e.g. "gemini-3-flash")
            self.assertNotEqual(slug, display_name,
                                f"Reverse alias for '{slug}' should differ from slug (gRPC uses display names)")
|
|
|
|
|
|
class TestProtobufConversion(unittest.TestCase):
    """Checks for the JSON <-> protobuf conversion helpers in antigravity_grpc.client."""

    def test_struct_to_protobuf(self):
        """A flat dict converts to a Struct and reads back with the same values."""
        from antigravity_grpc.client import _struct_to_protobuf

        struct = _struct_to_protobuf({"key": "value", "num": 42})
        self.assertIsNotNone(struct)

        # Convert back and compare field by field.
        from antigravity_grpc.client import _protobuf_struct_to_dict

        round_tripped = _protobuf_struct_to_dict(struct)
        self.assertEqual(round_tripped["key"], "value")
        self.assertEqual(round_tripped["num"], 42.0)

    def test_struct_round_trip_nested(self):
        """Nested dicts and lists survive a dict -> Struct -> dict round trip."""
        from antigravity_grpc.client import _struct_to_protobuf, _protobuf_struct_to_dict

        source = {"outer": {"inner": "hello"}, "list_val": [1, 2, 3]}
        restored = _protobuf_struct_to_dict(_struct_to_protobuf(source))
        self.assertEqual(restored["outer"]["inner"], "hello")
        self.assertEqual(restored["list_val"], [1.0, 2.0, 3.0])

    def test_json_parts_to_proto_text(self):
        """A text part becomes a single protobuf Part carrying that text."""
        from antigravity_grpc.client import _json_parts_to_proto

        text_part = {"text": "Hello world"}
        converted = _json_parts_to_proto([text_part])
        self.assertEqual(len(converted), 1)
        self.assertEqual(converted[0].text, "Hello world")

    def test_json_parts_to_proto_function_call(self):
        """A functionCall part populates the function_call field, name and id."""
        from antigravity_grpc.client import _json_parts_to_proto

        call_part = {
            "functionCall": {
                "name": "exec_command",
                "args": {"cmd": "ls -la"},
                "id": "call_123",
            }
        }
        converted = _json_parts_to_proto([call_part])
        self.assertEqual(len(converted), 1)
        only_part = converted[0]
        self.assertTrue(only_part.HasField("function_call"))
        self.assertEqual(only_part.function_call.name, "exec_command")
        self.assertEqual(only_part.function_call.id, "call_123")

    def test_json_parts_to_proto_function_response(self):
        """A functionResponse part populates the function_response field and name."""
        from antigravity_grpc.client import _json_parts_to_proto

        response_part = {
            "functionResponse": {
                "name": "exec_command",
                "response": {"result": "file1.txt"},
                "id": "call_123",
            }
        }
        converted = _json_parts_to_proto([response_part])
        self.assertEqual(len(converted), 1)
        only_part = converted[0]
        self.assertTrue(only_part.HasField("function_response"))
        self.assertEqual(only_part.function_response.name, "exec_command")

    def test_json_contents_to_proto(self):
        """Each JSON content entry becomes one protobuf Content keeping its role."""
        from antigravity_grpc.client import _json_contents_to_proto

        conversation = [
            {"role": "user", "parts": [{"text": "Hello"}]},
            {"role": "model", "parts": [{"text": "Hi there"}]},
        ]
        converted = _json_contents_to_proto(conversation)
        self.assertEqual(len(converted), 2)
        self.assertEqual(converted[0].role, "user")
        self.assertEqual(converted[1].role, "model")

    def test_proto_candidate_to_json(self):
        """A protobuf Candidate converts to a dict with camelCase JSON keys."""
        # NOTE(review): _json_contents_to_proto is not used below; the import
        # is kept so this test behaves exactly as before.
        from antigravity_grpc.client import _json_contents_to_proto, _proto_candidate_to_json
        from antigravity_grpc import cloudcode_pb2 as pb2

        # Assemble the candidate by hand rather than going through the client.
        proto_candidate = pb2.Candidate()
        proto_candidate.content.role = "model"
        proto_candidate.content.parts.add().text = "Hello from gRPC"
        proto_candidate.finish_reason = "STOP"
        proto_candidate.index = 0

        as_json = _proto_candidate_to_json(proto_candidate)
        self.assertEqual(as_json["finishReason"], "STOP")
        content = as_json["content"]
        self.assertEqual(content["role"], "model")
        self.assertEqual(content["parts"][0]["text"], "Hello from gRPC")
|
|
|
|
|
|
class TestGrpcRequestBuilding(unittest.TestCase):
    """Checks that _build_request turns a wrapped REST dict into a protobuf request."""

    def _get_client(self):
        """Return a fresh AntigravityGrpcClient for a single test."""
        from antigravity_grpc import AntigravityGrpcClient

        return AntigravityGrpcClient()

    def test_build_request_basic(self):
        """Top-level fields and the contents list are copied into the request."""
        client = self._get_client()
        inner_request = {
            "contents": [
                {"role": "user", "parts": [{"text": "Say hello"}]}
            ],
            "safetySettings": [
                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
            ],
        }
        payload = {
            "project": "test-project-123",
            "model": "Gemini 3.5 Flash (High)",
            "requestType": "agent",
            "userAgent": "antigravity/2.0.6",
            "requestId": "agent-test123",
            "request": inner_request,
        }

        built = client._build_request(payload)

        self.assertEqual(built.project, "test-project-123")
        self.assertEqual(built.model, "Gemini 3.5 Flash (High)")
        self.assertEqual(built.request_type, "agent")
        self.assertEqual(len(built.request.contents), 1)
        self.assertEqual(built.request.contents[0].role, "user")

    def test_build_request_with_tools(self):
        """A tool entry yields one Tool whose first declaration keeps its name."""
        client = self._get_client()
        declaration = {
            "name": "exec_command",
            "description": "Run a shell command",
            "parameters": {"type": "object", "properties": {"cmd": {"type": "string"}}},
        }
        payload = {
            "project": "test-project",
            "model": "gemini-3-flash",
            "request": {
                "contents": [],
                "tools": [{"functionDeclarations": [declaration]}],
            },
        }

        built = client._build_request(payload)

        self.assertEqual(len(built.request.tools), 1)
        self.assertEqual(built.request.tools[0].function_declarations[0].name, "exec_command")

    def test_build_request_with_generation_config(self):
        """generationConfig values, including thinkingConfig, reach the request."""
        client = self._get_client()
        generation_config = {
            "maxOutputTokens": 64000,
            "temperature": 0.7,
            "stopSequences": ["\n\nHuman:"],
            "thinkingConfig": {
                "includeThoughts": True,
                "thinkingBudget": 8192,
            },
        }
        payload = {
            "project": "test-project",
            "model": "gemini-3-flash",
            "request": {
                "contents": [],
                "generationConfig": generation_config,
            },
        }

        built = client._build_request(payload)

        self.assertEqual(built.request.generation_config.max_output_tokens, 64000)
        self.assertAlmostEqual(built.request.generation_config.temperature, 0.7, places=2)
        self.assertTrue(built.request.generation_config.thinking_config.include_thoughts)
        self.assertEqual(built.request.generation_config.thinking_config.thinking_budget, 8192)

    def test_build_request_with_function_call_history(self):
        """A functionCall / functionResponse pair in the history is kept in order."""
        client = self._get_client()
        user_turn = {"role": "user", "parts": [{"text": "List files"}]}
        model_call_turn = {"role": "model", "parts": [{
            "functionCall": {"name": "exec_command", "args": {"cmd": "ls"}, "id": "call_1"}
        }]}
        tool_result_turn = {"role": "user", "parts": [{
            "functionResponse": {"name": "exec_command", "response": {"result": "file.txt"}, "id": "call_1"}
        }]}
        payload = {
            "project": "test-project",
            "model": "gemini-3-flash",
            "request": {
                "contents": [user_turn, model_call_turn, tool_result_turn],
            },
        }

        built = client._build_request(payload)

        self.assertEqual(len(built.request.contents), 3)
        # The second turn must still carry the function call.
        self.assertTrue(built.request.contents[1].parts[0].HasField("function_call"))
        self.assertEqual(built.request.contents[1].parts[0].function_call.name, "exec_command")
        # The third turn must still carry the function response.
        self.assertTrue(built.request.contents[2].parts[0].HasField("function_response"))
        self.assertEqual(built.request.contents[2].parts[0].function_response.name, "exec_command")
|
|
|
|
|
|
class TestGrpcEndpointsConfig(unittest.TestCase):
    """Checks on the endpoint constants exposed by antigravity_grpc.client."""

    def test_default_endpoints(self):
        """At least two endpoints exist, including the daily and the plain host."""
        from antigravity_grpc.client import _GRPC_ENDPOINTS as endpoints

        self.assertGreaterEqual(len(endpoints), 2)
        # Keep only the host portion in front of the first ":".
        hosts = [endpoint.partition(":")[0] for endpoint in endpoints]
        for expected_host in (
            "daily-cloudcode-pa.googleapis.com",
            "cloudcode-pa.googleapis.com",
        ):
            self.assertIn(expected_host, hosts)

    def test_staging_env_var(self):
        """The staging toggle constant names the expected environment variable."""
        from antigravity_grpc.client import _ALLOW_STAGING_ENV as staging_env_name

        self.assertEqual(staging_env_name, "ALLOW_ANTIGRAVITY_STAGING")
|
|
|
|
|
|
class TestProxyIntegration(unittest.TestCase):
    """Tests for the proxy's gRPC fallback integration."""

    def _load_proxy_module(self):
        """Execute ``translate-proxy.py`` from the src directory and return it.

        The file is loaded by path via importlib under the module name
        ``translate_proxy``. ``importlib.util`` is imported explicitly because
        ``import importlib`` alone does not guarantee the submodule is loaded.
        """
        import importlib.util

        spec = importlib.util.spec_from_file_location(
            "translate_proxy",
            os.path.join(_src_dir, "translate-proxy.py"),
        )
        tp = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(tp)
        return tp

    def test_get_grpc_client_function_exists(self):
        """_get_grpc_client should exist as a module-level function."""
        tp = self._load_proxy_module()
        self.assertTrue(callable(tp._get_grpc_client))

    def test_grpc_fallback_errors_set(self):
        """_GRPC_FALLBACK_REST_ERRORS should include 404."""
        tp = self._load_proxy_module()
        self.assertIn(404, tp._GRPC_FALLBACK_REST_ERRORS)

    def test_versions_bug_fixed(self):
        """The _versions[0] NameError should be fixed (should be _fetched_ver)."""
        # Read the source file and verify _versions is not used incorrectly.
        # An explicit encoding keeps the read independent of the locale default.
        proxy_path = os.path.join(_src_dir, "translate-proxy.py")
        with open(proxy_path, encoding="utf-8") as f:
            source = f.read()
        # The bug was: ver={_versions[0]} -- should be ver={_fetched_ver}
        self.assertNotIn("_versions[0]", source,
                         "Bug: _versions[0] should have been replaced with _fetched_ver")
|
|
|
|
|
|
if __name__ == "__main__":
    # Print a three-line banner plus a blank line, then run the tests verbosely.
    rule = "=" * 70
    print(rule, "Antigravity gRPC Fallback - Unit Tests", rule, "", sep="\n")

    unittest.main(verbosity=2)
|