v3.0.0: ThreadingHTTPServer, dynamic ports, health gating, atomic config, safe cleanup, buffered SSE, batched stats, graceful shutdown
This commit is contained in:
@@ -11,7 +11,8 @@ Usage:
|
||||
python3 translate-proxy.py --backend openai-compat --target-url https://... --api-key sk-...
|
||||
"""
|
||||
|
||||
import json, http.server, urllib.request, time, uuid, os, sys, argparse, threading, socket
|
||||
import json, http.server, socketserver, urllib.request, urllib.parse, urllib.error
|
||||
import time, uuid, os, sys, argparse, threading, socket, collections, contextlib, signal
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# Config
|
||||
@@ -74,25 +75,64 @@ def load_config():
|
||||
|
||||
return cfg
|
||||
|
||||
CONFIG = load_config()
|
||||
PORT = CONFIG["port"]
|
||||
BACKEND = CONFIG["backend_type"]
|
||||
TARGET_URL = CONFIG["target_url"].rstrip("/")
|
||||
API_KEY = CONFIG["api_key"]
|
||||
OAUTH_PROVIDER = CONFIG.get("oauth_provider", "")
|
||||
MODELS = CONFIG["models"]
|
||||
CC_VERSION = CONFIG.get("cc_version", "")
|
||||
REASONING_ENABLED = CONFIG.get("reasoning_enabled", True)
|
||||
REASONING_EFFORT = CONFIG.get("reasoning_effort", "medium")
|
||||
BGP_ROUTES = CONFIG.get("bgp_routes", [])
|
||||
BGP_MODELS = []
|
||||
for _r in BGP_ROUTES:
|
||||
for _m in _r.get("models", [{"id": _r.get("model", "unknown")}]):
|
||||
if _m.get("id", _m) not in BGP_MODELS:
|
||||
BGP_MODELS.append(_m.get("id", _m) if isinstance(_m, dict) else _m)
|
||||
if BGP_ROUTES and not MODELS:
|
||||
MODELS = [{"id": m, "object": "model", "created": 1700000000, "owned_by": "bgp"} for m in BGP_MODELS]
|
||||
CONFIG["models"] = MODELS
|
||||
CONFIG = None
|
||||
PORT = 8080
|
||||
BACKEND = "openai-compat"
|
||||
TARGET_URL = ""
|
||||
API_KEY = ""
|
||||
OAUTH_PROVIDER = ""
|
||||
MODELS = []
|
||||
CC_VERSION = ""
|
||||
REASONING_ENABLED = True
|
||||
REASONING_EFFORT = "medium"
|
||||
BGP_ROUTES = []
|
||||
SERVER = None
|
||||
|
||||
_LOG_DIR = os.path.join(os.path.expanduser("~"), ".cache", "codex-proxy")
|
||||
os.makedirs(_LOG_DIR, exist_ok=True)
|
||||
_stats_path = os.path.join(_LOG_DIR, "usage-stats.json")
|
||||
_stats_lock = threading.Lock()
|
||||
_stats_pending = []
|
||||
_stats_flush_timer = None
|
||||
_STATS_FLUSH_INTERVAL = 5.0
|
||||
|
||||
_response_store = collections.OrderedDict()
|
||||
_response_store_lock = threading.Lock()
|
||||
_MAX_STORED = 50
|
||||
|
||||
_crof_lock = threading.Lock()
|
||||
|
||||
_shutdown_requested = False
|
||||
_active_connections = 0
|
||||
_active_connections_lock = threading.Lock()
|
||||
|
||||
_pool = uuid.uuid4().hex[:8]
|
||||
|
||||
def _init_runtime():
|
||||
global CONFIG, PORT, BACKEND, TARGET_URL, API_KEY, OAUTH_PROVIDER
|
||||
global MODELS, CC_VERSION, REASONING_ENABLED, REASONING_EFFORT, BGP_ROUTES
|
||||
|
||||
CONFIG = load_config()
|
||||
PORT = CONFIG["port"]
|
||||
BACKEND = CONFIG["backend_type"]
|
||||
TARGET_URL = CONFIG["target_url"].rstrip("/")
|
||||
API_KEY = CONFIG["api_key"]
|
||||
OAUTH_PROVIDER = CONFIG.get("oauth_provider", "")
|
||||
MODELS = CONFIG["models"]
|
||||
CC_VERSION = CONFIG.get("cc_version", "")
|
||||
REASONING_ENABLED = CONFIG.get("reasoning_enabled", True)
|
||||
REASONING_EFFORT = CONFIG.get("reasoning_effort", "medium")
|
||||
BGP_ROUTES = CONFIG.get("bgp_routes", [])
|
||||
|
||||
bgp_models = []
|
||||
for _r in BGP_ROUTES:
|
||||
for _m in _r.get("models", [{"id": _r.get("model", "unknown")}]):
|
||||
mid = _m.get("id", _m) if isinstance(_m, dict) else _m
|
||||
if mid not in bgp_models:
|
||||
bgp_models.append(mid)
|
||||
if BGP_ROUTES and not MODELS:
|
||||
MODELS = [{"id": m, "object": "model", "created": 1700000000, "owned_by": "bgp"} for m in bgp_models]
|
||||
CONFIG["models"] = MODELS
|
||||
|
||||
def _refresh_oauth_token():
|
||||
return _refresh_oauth_token_for(API_KEY, OAUTH_PROVIDER)
|
||||
@@ -138,14 +178,6 @@ def _refresh_oauth_token_for(api_key, oauth_provider):
|
||||
|
||||
_pool = uuid.uuid4().hex[:8]
|
||||
|
||||
_response_store = {}
|
||||
_MAX_STORED = 50
|
||||
|
||||
_LOG_DIR = os.path.join(os.path.expanduser("~"), ".cache", "codex-proxy")
|
||||
os.makedirs(_LOG_DIR, exist_ok=True)
|
||||
_stats_path = os.path.join(_LOG_DIR, "usage-stats.json")
|
||||
_stats_lock = threading.Lock()
|
||||
|
||||
def _load_stats():
|
||||
try:
|
||||
if os.path.exists(_stats_path):
|
||||
@@ -154,46 +186,78 @@ def _load_stats():
|
||||
pass
|
||||
return {"providers": {}, "updated": None}
|
||||
|
||||
def _record_usage(provider, model, success, duration_s, tokens_in=0, tokens_out=0, error_type=None):
|
||||
def _atomic_write_json(path, obj):
|
||||
tmp = path + ".tmp"
|
||||
with open(tmp, "w") as f:
|
||||
json.dump(obj, f, indent=2, ensure_ascii=False)
|
||||
os.replace(tmp, path)
|
||||
|
||||
def _flush_stats():
|
||||
global _stats_flush_timer
|
||||
with _stats_lock:
|
||||
stats = _load_stats()
|
||||
batch = list(_stats_pending)
|
||||
_stats_pending.clear()
|
||||
_stats_flush_timer = None
|
||||
if not batch:
|
||||
return
|
||||
stats = _load_stats()
|
||||
for entry in batch:
|
||||
provider = entry["provider"]
|
||||
model = entry["model"]
|
||||
p = stats["providers"].setdefault(provider, {
|
||||
"total_requests": 0, "successes": 0, "failures": 0,
|
||||
"total_tokens_in": 0, "total_tokens_out": 0,
|
||||
"total_duration_s": 0.0, "models": {}, "last_used": None, "last_error": None,
|
||||
})
|
||||
p["total_requests"] += 1
|
||||
p["total_tokens_in"] += tokens_in
|
||||
p["total_tokens_out"] += tokens_out
|
||||
p["total_duration_s"] += duration_s
|
||||
p["last_used"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
if success:
|
||||
p["total_tokens_in"] += entry["tokens_in"]
|
||||
p["total_tokens_out"] += entry["tokens_out"]
|
||||
p["total_duration_s"] += entry["duration_s"]
|
||||
p["last_used"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(entry["ts"]))
|
||||
if entry["success"]:
|
||||
p["successes"] += 1
|
||||
else:
|
||||
p["failures"] += 1
|
||||
p["last_error"] = error_type or "unknown"
|
||||
p["last_error"] = entry.get("error_type") or "unknown"
|
||||
m = p["models"].setdefault(model, {"requests": 0, "tokens_in": 0, "tokens_out": 0})
|
||||
m["requests"] += 1
|
||||
m["tokens_in"] += tokens_in
|
||||
m["tokens_out"] += tokens_out
|
||||
stats["updated"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
with open(_stats_path, "w") as f:
|
||||
json.dump(stats, f, indent=2)
|
||||
m["tokens_in"] += entry["tokens_in"]
|
||||
m["tokens_out"] += entry["tokens_out"]
|
||||
stats["updated"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
_atomic_write_json(_stats_path, stats)
|
||||
|
||||
def _record_usage(provider, model, success, duration_s, tokens_in=0, tokens_out=0, error_type=None):
|
||||
global _stats_flush_timer
|
||||
entry = {
|
||||
"provider": provider or "unknown", "model": model or "unknown",
|
||||
"success": bool(success), "duration_s": float(duration_s or 0),
|
||||
"tokens_in": int(tokens_in or 0), "tokens_out": int(tokens_out or 0),
|
||||
"error_type": error_type, "ts": time.time(),
|
||||
}
|
||||
with _stats_lock:
|
||||
_stats_pending.append(entry)
|
||||
if _stats_flush_timer is None:
|
||||
_stats_flush_timer = threading.Timer(_STATS_FLUSH_INTERVAL, _flush_stats)
|
||||
_stats_flush_timer.daemon = True
|
||||
_stats_flush_timer.start()
|
||||
|
||||
def store_response(resp_id, input_data, output_items):
|
||||
if not resp_id:
|
||||
return
|
||||
_response_store[resp_id] = {"input": input_data, "output": output_items}
|
||||
if len(_response_store) > _MAX_STORED:
|
||||
oldest = list(_response_store.keys())[0]
|
||||
del _response_store[oldest]
|
||||
with _response_store_lock:
|
||||
_response_store[resp_id] = {"input": input_data, "output": output_items, "ts": time.time()}
|
||||
while len(_response_store) > _MAX_STORED:
|
||||
_response_store.popitem(last=False)
|
||||
|
||||
def resolve_previous_response(body):
|
||||
prev_id = body.get("previous_response_id")
|
||||
input_data = body.get("input", "")
|
||||
if not prev_id or prev_id not in _response_store:
|
||||
if not prev_id:
|
||||
return input_data
|
||||
with _response_store_lock:
|
||||
stored = _response_store.get(prev_id)
|
||||
if not stored:
|
||||
return input_data
|
||||
stored = _response_store[prev_id]
|
||||
prev_input = stored["input"]
|
||||
prev_output = stored["output"]
|
||||
new_input = input_data if isinstance(input_data, list) else []
|
||||
@@ -983,18 +1047,60 @@ def _log_resp(resp_id, status, output):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class ConnectionTracker:
|
||||
def __enter__(self):
|
||||
global _active_connections
|
||||
with _active_connections_lock:
|
||||
_active_connections += 1
|
||||
def __exit__(self, *a):
|
||||
global _active_connections
|
||||
with _active_connections_lock:
|
||||
_active_connections -= 1
|
||||
|
||||
def _handle_shutdown_signal(signum, frame):
|
||||
global _shutdown_requested
|
||||
_shutdown_requested = True
|
||||
print("[proxy] shutdown requested; draining connections", file=sys.stderr)
|
||||
def _drain():
|
||||
deadline = time.time() + 5
|
||||
while time.time() < deadline:
|
||||
with _active_connections_lock:
|
||||
if _active_connections == 0:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
if SERVER is not None:
|
||||
SERVER.shutdown()
|
||||
threading.Thread(target=_drain, daemon=True).start()
|
||||
|
||||
def _upstream_timeout(body, stream):
|
||||
input_data = body.get("input", "")
|
||||
n_items = len(input_data) if isinstance(input_data, list) else 1
|
||||
has_tools = bool(body.get("tools"))
|
||||
if stream:
|
||||
return min((180 if has_tools else 120) + n_items * 2, 300)
|
||||
return min(60 + n_items * 2, 120)
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
protocol_version = "HTTP/1.1"
|
||||
|
||||
def do_GET(self):
|
||||
if self.path in ("/v1/models", "/models"):
|
||||
self.send_json(200, {"object": "list", "data": MODELS})
|
||||
elif self.path in ("/health", "/v1/health"):
|
||||
self.send_json(200, {"ok": True, "backend": BACKEND,
|
||||
"target_url": TARGET_URL,
|
||||
"models": [m.get("id") for m in MODELS],
|
||||
"bgp_routes": len(BGP_ROUTES)})
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def do_POST(self):
|
||||
if _shutdown_requested:
|
||||
return self.send_json(503, {"error": {"type": "proxy_shutting_down",
|
||||
"message": "Proxy is shutting down"}})
|
||||
if self.path in ("/v1/responses", "/responses"):
|
||||
self._handle()
|
||||
with ConnectionTracker():
|
||||
self._handle()
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
@@ -1082,7 +1188,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
for attempt in range(max_retries + 1):
|
||||
req = urllib.request.Request(target, data=chat_body_b, headers=fwd)
|
||||
try:
|
||||
upstream = urllib.request.urlopen(req, timeout=180)
|
||||
upstream = urllib.request.urlopen(req, timeout=_upstream_timeout(body, stream))
|
||||
except urllib.error.HTTPError as e:
|
||||
err_body = e.read().decode()
|
||||
if e.code in (429, 502, 503) and attempt < max_retries:
|
||||
@@ -1163,7 +1269,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
route_ok = False
|
||||
for attempt in range(3):
|
||||
try:
|
||||
upstream = urllib.request.urlopen(req, timeout=180)
|
||||
upstream = urllib.request.urlopen(req, timeout=_upstream_timeout(body, stream))
|
||||
print(f"[bgp] route '{route.get('name', r_url)}' connected OK", file=sys.stderr)
|
||||
self._forward_oa_compat(upstream, stream, r_model, chat_body, body, input_data, fwd, target)
|
||||
return
|
||||
@@ -1284,7 +1390,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
|
||||
def _forward_oa_compat_retry(self, req, model, chat_body, body, input_data):
|
||||
try:
|
||||
upstream = urllib.request.urlopen(req, timeout=180)
|
||||
upstream = urllib.request.urlopen(req, timeout=_upstream_timeout(body, True))
|
||||
except Exception as e:
|
||||
print(f"[crof-adaptive] retry failed: {e}", file=sys.stderr)
|
||||
return
|
||||
@@ -1427,7 +1533,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
|
||||
if stream:
|
||||
try:
|
||||
upstream = urllib.request.urlopen(req, timeout=180)
|
||||
upstream = urllib.request.urlopen(req, timeout=_upstream_timeout(body, True))
|
||||
except urllib.error.HTTPError as e:
|
||||
err = e.read().decode()
|
||||
return self.send_json(e.code, {"error": {"type": "upstream_error", "message": err}})
|
||||
@@ -1461,7 +1567,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
store_response(last_resp_id, body.get("input", ""), last_output)
|
||||
else:
|
||||
try:
|
||||
upstream = urllib.request.urlopen(req, timeout=180)
|
||||
upstream = urllib.request.urlopen(req, timeout=_upstream_timeout(body, False))
|
||||
except urllib.error.HTTPError as e:
|
||||
err = e.read().decode()
|
||||
return self.send_json(e.code, {"error": {"type": "upstream_error", "message": err}})
|
||||
@@ -1478,7 +1584,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
|
||||
def _forward(self, req, stream, model, nonstream_fn, stream_fn, input_data=None):
|
||||
try:
|
||||
upstream = urllib.request.urlopen(req, timeout=180)
|
||||
upstream = urllib.request.urlopen(req, timeout=_upstream_timeout({}, stream))
|
||||
except urllib.error.HTTPError as e:
|
||||
err = e.read().decode()
|
||||
return self.send_json(e.code, {"error": {"type": "upstream_error", "message": err}})
|
||||
@@ -1533,17 +1639,54 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
|
||||
def stream_buffered_events(self, event_iter, flush_interval=0.03, max_bytes=4096):
|
||||
buf = bytearray()
|
||||
last_flush = time.monotonic()
|
||||
def _flush():
|
||||
nonlocal buf, last_flush
|
||||
if buf:
|
||||
self.wfile.write(buf)
|
||||
self.wfile.flush()
|
||||
buf.clear()
|
||||
last_flush = time.monotonic()
|
||||
for event in event_iter:
|
||||
encoded = event.encode("utf-8") if isinstance(event, str) else event
|
||||
buf.extend(encoded)
|
||||
urgent = ("response.completed" in event or "response.output_text.done" in event
|
||||
or "response.output_item.done" in event
|
||||
or "function_call_arguments.done" in event)
|
||||
if urgent or len(buf) >= max_bytes or time.monotonic() - last_flush >= flush_interval:
|
||||
_flush()
|
||||
_flush()
|
||||
|
||||
def log_message(self, fmt, *args):
|
||||
msg = fmt % args if args else fmt
|
||||
print(f"[translate-proxy] {BACKEND} {msg}", file=sys.stderr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
class ReusableHTTPServer(http.server.HTTPServer):
|
||||
def main():
|
||||
global SERVER
|
||||
_init_runtime()
|
||||
signal.signal(signal.SIGTERM, _handle_shutdown_signal)
|
||||
signal.signal(signal.SIGINT, _handle_shutdown_signal)
|
||||
try:
|
||||
from http.server import ThreadingHTTPServer as _BaseSrv
|
||||
except ImportError:
|
||||
class _BaseSrv(socketserver.ThreadingMixIn, http.server.HTTPServer):
|
||||
daemon_threads = True
|
||||
class ReusableHTTPServer(_BaseSrv):
|
||||
allow_reuse_address = True
|
||||
server = ReusableHTTPServer(("127.0.0.1", PORT), Handler)
|
||||
daemon_threads = True
|
||||
request_queue_size = 64
|
||||
SERVER = ReusableHTTPServer(("127.0.0.1", PORT), Handler)
|
||||
print(f"translate-proxy ({BACKEND}) listening on http://127.0.0.1:{PORT}", flush=True)
|
||||
print(f"Target: {TARGET_URL}", flush=True)
|
||||
print(f"Models: {[m['id'] for m in MODELS]}", flush=True)
|
||||
if BGP_ROUTES:
|
||||
print(f"BGP routes: {len(BGP_ROUTES)} ({[r.get('name','?') for r in BGP_ROUTES]})", flush=True)
|
||||
server.serve_forever()
|
||||
try:
|
||||
SERVER.serve_forever()
|
||||
finally:
|
||||
_flush_stats()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user