The multi-query vector search with blocking urllib.request.urlopen calls was stalling the single-threaded uvicorn event loop. Now uses async httpx.AsyncClient with asyncio.gather for parallel requests. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
566 lines
21 KiB
Python
566 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
"""Wiki VectorDB Chat — Multi-Provider AI Chat with RAG (KB + VectorDB).
|
|
|
|
Serves at port 8770, proxied via nginx at /zportal/wiki/api/chat
|
|
Uses wiki-api (:8097) for KB search and vector-db (:8099) for vector search.
|
|
"""
|
|
|
|
import asyncio
import json
import os
import re
import time
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path


# Both names point at the same JSON file.  Only CUSTOM_PROVIDERS_FILE is read
# or written in this module; PROVIDERS_FILE is not referenced here (presumably
# kept for external importers — TODO confirm before removing).
PROVIDERS_FILE = Path("/opt/blog/wiki-chat-providers.json")
CUSTOM_PROVIDERS_FILE = Path("/opt/blog/wiki-chat-providers.json")

# Local backend services: wiki-api serves KB search, vector-db serves
# embedding search.
WIKI_API = "http://127.0.0.1:8097"
VECTOR_DB = "http://127.0.0.1:8099"

# Shared API token for wiki-api and vector-db, read once at import time.
# Best-effort: a missing/unreadable/undecodable token file leaves it empty so
# the service still starts; any other error is a real bug and is not hidden.
_API_TOKEN = ""
try:
    _API_TOKEN = Path("/opt/blog/.wiki-api-token").read_text().strip()
except (OSError, UnicodeDecodeError):
    pass


# Built-in providers.  They take precedence over custom providers with the
# same id (see get_all_providers) and carry no api_key.
#
# Icons outside the Basic Multilingual Plane must use the 8-digit "\U" escape:
# "\u" consumes exactly four hex digits, so "\u1f512" would be "\u1f51" + "2".
PRESETS = [
    {
        "id": "zai-coding",
        "name": "Z.ai Coding Plan",
        "base_url": "https://api.z.ai/api/coding/paas/v4",
        "model": "glm-4-plus",
        "format": "openai",
        "icon": "\u26a1",  # high voltage
        "description": "Official Z.ai coding plan API",
    },
    {
        "id": "openadapter",
        "name": "OpenAdapter",
        "base_url": "https://api.openadapter.com/v1",
        "model": "gpt-4o-mini",
        "format": "openai",
        "icon": "\U0001f512",  # lock
        "description": "OpenAdapter unified API",
    },
    {
        "id": "openrouter",
        "name": "OpenRouter",
        "base_url": "https://openrouter.ai/api/v1",
        "model": "anthropic/claude-sonnet-4",
        "format": "openrouter",
        "icon": "\U0001f6e3",  # motorway
        "description": "Model router across providers",
    },
    {
        "id": "crofai",
        "name": "Crof.AI",
        "base_url": "https://api.crof.ai/v1",
        "model": "crof-4-plus",
        "format": "openai",
        "icon": "\U0001f42a",  # camel
        "description": "Crof AI models",
    },
    {
        "id": "opencode-zen",
        "name": "Opencode Zen",
        "base_url": "https://api.zen.opencode.com/v1",
        "model": "glm-4-plus",
        "format": "openai",
        "icon": "\U0001f9e0",  # brain
        "description": "Opencode Zen hosted models",
    },
]


def load_custom_providers(path=None):
    """Return the user-defined providers stored on disk.

    Best-effort: a missing, unreadable, or malformed file yields ``[]`` so the
    service keeps working with the built-in presets only.

    Args:
        path: File to read; defaults to ``CUSTOM_PROVIDERS_FILE``.

    Returns:
        A list of provider dicts.  A file whose top level is not a JSON list
        counts as empty, and non-object entries are dropped, because callers
        call ``.get()`` on every element.
    """
    target = CUSTOM_PROVIDERS_FILE if path is None else Path(path)
    try:
        data = json.loads(target.read_text())
    except (OSError, ValueError):
        # Missing file / permission error, invalid JSON, or bad encoding
        # (JSONDecodeError and UnicodeDecodeError are both ValueErrors).
        return []
    if not isinstance(data, list):
        return []
    return [entry for entry in data if isinstance(entry, dict)]


def save_custom_providers(providers, path=None):
    """Persist *providers* as indented JSON, replacing the file atomically.

    The JSON is written to a sibling ``.tmp`` file which is then renamed over
    the target with ``os.replace``.  A crash mid-write therefore cannot leave a
    truncated file behind; a truncated file would read back as "no custom
    providers" and lose every saved entry, including its API key.

    Args:
        providers: JSON-serialisable list of provider dicts.
        path: Destination file; defaults to ``CUSTOM_PROVIDERS_FILE``.
    """
    import shutil

    target = CUSTOM_PROVIDERS_FILE if path is None else Path(path)
    payload = json.dumps(providers, indent=2)  # serialise before touching disk
    tmp = target.with_name(target.name + ".tmp")
    try:
        tmp.write_text(payload)
        if target.exists():
            # The file stores API keys: keep whatever permissions it had.
            shutil.copymode(target, tmp)
        os.replace(tmp, target)
    finally:
        # Only present if something failed before the rename.
        if tmp.exists():
            tmp.unlink()


def get_all_providers():
    """Return presets + custom providers.

    Built-in presets come first and win on id collisions: a custom entry
    whose ``id`` matches a preset (or an earlier custom entry) is skipped.
    Custom entries without an ``id`` are skipped as well; the chat endpoints
    look providers up by id, and such entries used to raise ``KeyError``.
    """
    # NOTE(review): because presets take precedence, saving a custom provider
    # under a preset id (e.g. to attach an api_key) has no visible effect —
    # confirm this is intended.
    result = list(PRESETS)
    seen = {p["id"] for p in PRESETS}
    for p in load_custom_providers():
        pid = p.get("id") if isinstance(p, dict) else None
        if pid is None or pid in seen:
            continue
        seen.add(pid)
        result.append(p)
    return result


def detect_provider_format(base_url: str) -> str:
    """Guess a provider's wire format from the hostname of *base_url*.

    Returns ``"ollama"`` for hosts containing "ollama" or for localhost /
    127.0.0.1, ``"anthropic"``, ``"openrouter"`` or ``"groq"`` when the host
    contains that name, and ``"openai"`` otherwise.  A URL with no hostname
    (empty, relative, or missing its scheme) falls back to ``"openai"``
    instead of raising.
    """
    from urllib.parse import urlparse

    # urlparse(...).hostname is None when there is no network location.
    host = (urlparse(base_url).hostname or "").lower()
    if "ollama" in host or host in ("localhost", "127.0.0.1"):
        return "ollama"
    if "anthropic" in host:
        return "anthropic"
    if "openrouter" in host:
        return "openrouter"
    if "groq" in host:
        return "groq"
    return "openai"


async def search_kb(query: str, limit: int = 5) -> str:
    """Search wiki-kb.json via wiki-api with multiple query variants.

    Sends up to two variants concurrently: the full query and, when present,
    just its capitalised words.  Hits are de-duplicated on the first 80
    characters of their question text (first variant wins), and the first
    *limit* are formatted as ``[topic] Q: ...\\nA: ...`` blocks separated by
    blank lines.  A variant whose request fails or returns a non-JSON / error
    response is skipped without discarding the other variant's hits.

    Returns:
        The formatted results, ``""`` when nothing matched, or
        ``"(KB search error: ...)"`` if the search could not run at all.
    """
    try:
        import httpx

        # Capitalised words are usually names / proper nouns that match KB
        # entries more directly than the full sentence.
        capitalised = re.findall(r"\b[A-Z][a-z]+\b", query)
        queries = [query]
        if capitalised:
            queries.append(" ".join(capitalised))
        queries = list(dict.fromkeys(queries))[:2]

        # Encode the token too so characters like "&" or "+" cannot corrupt
        # the query string.
        token = urllib.parse.quote(_API_TOKEN, safe="")
        urls = [
            f"{WIKI_API}/search?q={urllib.parse.quote(q)}&limit={limit}&token={token}"
            for q in queries
        ]
        async with httpx.AsyncClient(timeout=5) as client:
            responses = await asyncio.gather(
                *(client.get(u) for u in urls), return_exceptions=True
            )

        all_results = {}
        for resp in responses:
            # BaseException: gather() can also hand back a CancelledError.
            if isinstance(resp, BaseException) or resp.status_code >= 400:
                continue
            try:
                data = resp.json()
            except ValueError:
                continue
            if not isinstance(data, dict):
                continue
            for r in data.get("results") or []:
                if not isinstance(r, dict):
                    continue
                key = str(r.get("q") or "")[:80]
                all_results.setdefault(key, r)

        results = list(all_results.values())[:limit]
        if not results:
            return ""
        return "\n\n".join(
            f"[{r.get('topic', '')}] Q: {r.get('q', '')}\nA: {r.get('a', '')}"
            for r in results
        )
    except Exception as e:
        return f"(KB search error: {e})"


async def search_vector(query: str, top_k: int = 10) -> str:
    """Search vector-db with multiple query variants for better recall.

    Up to four query variants (see _vector_query_variants) are POSTed
    concurrently, each asking for 30 hits.  Hits are de-duplicated on the
    first 80 characters of their text (keeping the best score), re-ranked by
    score plus 0.15 per topic word found in the text, and the top 15 are
    formatted one per paragraph as
    ``[source] @author (YYYY-MM-DD): preview (score: 0.00)``.

    NOTE(review): *top_k* is currently unused — the request size (30) and the
    result cap (15) are fixed.  Confirm whether the caller's value should
    cap the output before wiring it in.

    Returns:
        The formatted hits, ``""`` when nothing matched, or
        ``"(Vector search error: ...)"`` if the search could not run at all.
    """
    try:
        import httpx

        queries, topic_words = _vector_query_variants(query)
        headers = {"Content-Type": "application/json", "x-api-key": _API_TOKEN}
        # Async client + gather so the fan-out never blocks the event loop.
        async with httpx.AsyncClient(timeout=10) as client:
            responses = await asyncio.gather(
                *(
                    client.post(
                        f"{VECTOR_DB}/vector/search",
                        json={"query": q, "top_k": 30},
                        headers=headers,
                    )
                    for q in queries
                ),
                return_exceptions=True,
            )

        # Best-scoring hit per distinct text prefix.
        all_hits = {}
        for resp in responses:
            # A failed variant is skipped; the other variants still count.
            if isinstance(resp, BaseException) or resp.status_code >= 400:
                continue
            try:
                result = resp.json()
            except ValueError:
                continue
            if not isinstance(result, dict):
                continue
            for h in result.get("results") or []:
                if not isinstance(h, dict):
                    continue
                key = _hit_text(h)[:80]
                if key not in all_hits or _hit_score(h) > _hit_score(all_hits[key]):
                    all_hits[key] = h

        def topic_relevance(h):
            # Prefer hits that mention more of the query's topic words; uses the
            # same content/text field as the de-duplication above.
            text = _hit_text(h).lower()
            return _hit_score(h) + 0.15 * sum(1 for tw in topic_words if tw in text)

        hits = sorted(all_hits.values(), key=topic_relevance, reverse=True)[:15]
        if not hits:
            return ""

        lines = []
        for h in hits:
            preview = _hit_text(h).replace("\n", " ")[:500]
            # str(): tolerate a missing/None/numeric timestamp.
            ts = str(h.get("timestamp") or "")[:10]
            source = h.get("source", "unknown")
            author = h.get("author", "")
            lines.append(
                f"[{source}] @{author} ({ts}): {preview} (score: {_hit_score(h):.2f})"
            )
        return "\n\n".join(lines)
    except Exception as e:
        return f"(Vector search error: {e})"


# Capitalised words that are question words, not named entities.
_VECTOR_QUESTION_WORDS = ("what", "when", "where", "how", "why", "that", "this")

# Words that carry no topic information for the topic-only variant.
_VECTOR_STOPWORDS = frozenset({
    "what", "about", "that", "this", "with", "from", "have", "been",
    "said", "recently", "announced", "did", "does", "when", "where",
    "how", "why", "the", "and", "for", "are", "was",
})


def _vector_query_variants(query: str):
    """Return ``(queries, topic_words)`` for a multi-query vector search.

    *queries* holds up to four de-duplicated variants: the full query, its
    capitalised entity words, up to 6 topic words, and the entities plus the
    first 3 topic words.  *topic_words* are lower-cased words longer than two
    characters, stripped of edge punctuation, that are neither stopwords nor
    entity names.
    """
    import string

    entities = [
        w for w in re.findall(r"\b[A-Z][a-z]+\b", query)
        if w.lower() not in _VECTOR_QUESTION_WORDS
    ]
    entity_lower = {e.lower() for e in entities}
    # Strip edge punctuation so the trailing "program?" of a question becomes
    # "program" and can actually match hit text.
    words = (w.strip(string.punctuation) for w in query.lower().split())
    topic_words = [
        w for w in words
        if len(w) > 2 and w not in _VECTOR_STOPWORDS and w not in entity_lower
    ]

    queries = [query]  # full original
    if entities:
        queries.append(" ".join(entities))  # e.g. "Jodie"
    if topic_words:
        # Topic-only query (without entity names) catches messages BY the person.
        queries.append(" ".join(topic_words[:6]))  # e.g. "ambassador program"
    if entities and topic_words:
        queries.append(" ".join(entities + topic_words[:3]))  # combined
    return list(dict.fromkeys(queries))[:4], topic_words


def _hit_text(hit: dict) -> str:
    """Return a hit's body: its ``content``, else its ``text``, else ``""``."""
    return hit.get("content") or hit.get("text") or ""


def _hit_score(hit: dict):
    """Return a hit's similarity score, treating a missing/None score as 0."""
    return hit.get("score") or 0


async def build_rag_context(user_message: str, rag_wiki: bool = True, rag_vector: bool = True) -> str:
    """Assemble the system prompt from fixed instructions plus retrieved context.

    The enabled searches run concurrently: the wiki KB with a limit of 5 and
    the vector DB with top_k=10.  Each enabled source gets a titled section,
    with a placeholder line when its search returned nothing.  When both
    sources are disabled, a single "disabled" note is appended instead.
    """

    async def _skipped() -> str:
        # Stand-in for a disabled source so both slots can be gathered.
        return ""

    kb_results, vec_results = await asyncio.gather(
        search_kb(user_message, 5) if rag_wiki else _skipped(),
        search_vector(user_message, 10) if rag_vector else _skipped(),
    )

    sections = [
        "You are Z.ai Wiki Assistant. Use ALL the knowledge sources below to answer the user's question.",
        "Draw from both the Wiki KB and Community Messages. Synthesize information even from partial matches.",
        "If the context mentions anything relevant, include it in your answer. Be specific — quote authors, dates, and details when available.",
        "IMPORTANT: People often talk ABOUT someone (like Jodie, Cobra, Agnes) without that person posting directly. If multiple users reference what someone said or did, reconstruct and summarize that information. Do NOT say 'no information found' if other users are clearly discussing the topic.",
        "When users ask 'what did X say about Y', look for messages FROM others referencing X's statements or actions about Y.",
        "Only say you don't have information if the sources are truly empty or completely unrelated.",
        "",
    ]
    if rag_wiki:
        sections.extend((
            "=== Wiki Knowledge Base ===",
            kb_results or "(no KB results found)",
            "",
        ))
    if rag_vector:
        sections.extend((
            "=== Related Community Messages (Discord/Reddit) ===",
            vec_results or "(no community messages found)",
        ))
    if not (rag_wiki or rag_vector):
        sections.append("(RAG sources disabled for this session)")
    return "\n".join(sections)


# ── LLM Provider Calls ──


def format_messages_openai(system: str, messages: list, model: str) -> dict:
    """Build the JSON body for an OpenAI-compatible ``/chat/completions`` call.

    *system* becomes the first message and *messages* follow unchanged.
    Sampling is fixed at temperature 0.7 / 2048 max tokens, and streaming is
    always requested.
    """
    system_turn = [{"role": "system", "content": system}]
    conversation = system_turn + messages
    return dict(
        model=model,
        messages=conversation,
        temperature=0.7,
        max_tokens=2048,
        stream=True,
    )


def format_messages_anthropic(system: str, messages: list, model: str) -> dict:
    """Build the JSON body for Anthropic's Messages API.

    *system* goes in the top-level ``system`` field.  Each message keeps only
    ``role`` and ``content``; any role other than ``"user"`` is sent as
    ``"assistant"``.  Streaming is always requested with 2048 max tokens.
    """

    def as_anthropic(msg):
        speaker = "user" if msg["role"] == "user" else "assistant"
        return {"role": speaker, "content": msg["content"]}

    converted = [as_anthropic(m) for m in messages]
    return {
        "model": model,
        "system": system,
        "messages": converted,
        "max_tokens": 2048,
        "stream": True,
    }


def format_messages_ollama(system: str, messages: list, model: str) -> dict:
    """Format for Ollama /api/chat endpoint.

    A non-empty *system* prompt is sent as a leading ``system`` message.  In
    this module that prompt carries the RAG context, so dropping it (as the
    payload used to) meant Ollama models never saw the retrieved sources.
    Every other message keeps only ``role`` and ``content``, with any role
    other than ``"user"`` sent as ``"assistant"``.
    """
    ollama_msgs = [{"role": "system", "content": system}] if system else []
    for m in messages:
        role = "user" if m["role"] == "user" else "assistant"
        ollama_msgs.append({"role": role, "content": m["content"]})
    return {"model": model, "messages": ollama_msgs, "stream": True}


async def call_llm_stream(provider: dict, system: str, messages: list):
    """Call LLM provider and yield SSE delta chunks.

    The request is streamed with ``httpx.AsyncClient``.  A blocking
    ``urlopen().read()`` loop inside this async generator would freeze the
    whole event loop for the duration of every generation.

    Yields event dicts:
      * ``{"type": "delta", "delta": text}`` for generated text,
      * ``{"type": "tool", "delta": ...}`` when a tool call is announced,
      * ``{"type": "raw", "delta": data}`` for a data line that is not JSON,
      * a terminating ``{"type": "done"}`` or
        ``{"type": "error", "delta": message}``.

    A non-2xx status is reported as ``"HTTP <status>: <first 500 chars of
    body>"``.  Exceptions derived from ``Exception`` (connection errors,
    timeouts, a missing httpx) become an ``error`` event instead of
    propagating.
    """
    try:
        import httpx

        fmt, url, headers, payload = _build_llm_request(provider, system, messages)
        body = json.dumps(payload).encode()
        # Same 60 s limit the previous urlopen(timeout=60) used.
        async with httpx.AsyncClient(timeout=60) as client:
            async with client.stream("POST", url, content=body, headers=headers) as resp:
                if not 200 <= resp.status_code < 300:
                    err_body = (await resp.aread()).decode("utf-8", errors="replace")[:500]
                    yield {"type": "error", "delta": f"HTTP {resp.status_code}: {err_body}"}
                    return
                async for raw_line in resp.aiter_lines():
                    for event in _stream_events(fmt, raw_line):
                        yield event
                        if event["type"] in ("done", "error"):
                            return
    except Exception as e:
        # Some httpx exceptions (e.g. timeouts) have an empty str().
        yield {"type": "error", "delta": str(e) or type(e).__name__}


def _build_llm_request(provider: dict, system: str, messages: list):
    """Return ``(fmt, url, headers, payload)`` for one streaming chat call.

    *fmt* is the provider's declared ``format``; the base URL is only sniffed
    with detect_provider_format when no format is declared.
    """
    base_url = provider["base_url"].rstrip("/")
    fmt = provider.get("format") or detect_provider_format(base_url)
    api_key = provider.get("api_key", "")
    model = provider.get("model", "gpt-4o-mini")

    if fmt == "anthropic":
        payload = format_messages_anthropic(system, messages, model)
        url = f"{base_url}/v1/messages"
        headers = {"x-api-key": api_key, "Content-Type": "application/json",
                   "anthropic-version": "2023-06-01"}
    elif fmt == "ollama":
        payload = format_messages_ollama(system, messages, model)
        url = f"{base_url}/api/chat"
        headers = {"Content-Type": "application/json"}
    else:
        # openai / openrouter / groq / default
        payload = format_messages_openai(system, messages, model)
        url = f"{base_url}/chat/completions"
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        if fmt == "openrouter":
            # NOTE(review): OpenRouter documents HTTP-Referer / X-Title as its
            # optional headers, not this one — confirm it is still needed.
            headers["HTTP-OpenRouter-AI-Model"] = model
        headers["Content-Type"] = "application/json"
    return fmt, url, headers, payload


def _stream_events(fmt: str, raw_line: str):
    """Yield chat event dicts decoded from one line of a streaming response.

    Handles SSE ``data:`` frames (OpenAI-compatible and Anthropic) and, for
    ``fmt == "ollama"``, the bare NDJSON objects Ollama's /api/chat streams.
    Blank lines and SSE ``event:`` lines produce nothing.
    """
    line = raw_line.strip()
    if not line or line.startswith("event:"):
        return

    if line.startswith("data:"):
        data_str = line[len("data:"):].strip()
    elif fmt == "ollama":
        data_str = line  # NDJSON: the whole line is the JSON object
    else:
        return

    if data_str == "[DONE]":
        yield {"type": "done"}
        return
    try:
        chunk = json.loads(data_str)
    except json.JSONDecodeError:
        yield {"delta": data_str, "type": "raw"}
        return
    if not isinstance(chunk, dict):
        yield {"delta": data_str, "type": "raw"}
        return

    if fmt == "ollama":
        if chunk.get("error"):
            yield {"type": "error", "delta": str(chunk["error"])}
            return
        text = (chunk.get("message") or {}).get("content") or ""
        if text:
            yield {"delta": text, "type": "delta"}
        if chunk.get("done"):
            yield {"type": "done"}
        return

    if fmt == "anthropic":
        evt_type = chunk.get("type", "")
        if evt_type == "content_block_delta":
            text = (chunk.get("delta") or {}).get("text", "")
            if text:
                yield {"delta": text, "type": "delta"}
        elif evt_type == "message_stop":
            yield {"type": "done"}
        elif evt_type == "error":
            err_msg = (chunk.get("error") or {}).get("message", str(chunk))
            yield {"type": "error", "delta": err_msg}
        return

    # OpenAI-compatible chunk.  "choices" may be empty (usage-only chunks), and
    # finish_reason lives on the choice object rather than the top level.
    choice = (chunk.get("choices") or [{}])[0] or {}
    deltas = choice.get("delta") or {}
    content = deltas.get("content") or ""
    if content:
        yield {"delta": content, "type": "delta"}
    # Streamed tool calls usually name the function only in their first
    # fragment; later argument fragments have no name and are not announced.
    names = [
        (tc.get("function") or {}).get("name")
        for tc in deltas.get("tool_calls") or []
        if isinstance(tc, dict)
    ]
    names = [n for n in names if n]
    if names:
        yield {"delta": f"\n[Using tools: {', '.join(names)}]", "type": "tool"}
    if choice.get("finish_reason") or chunk.get("finish_reason"):
        yield {"type": "done"}


# ── FastAPI App ──

# NOTE(review): installing packages at import time is fragile in production;
# prefer declaring fastapi / uvicorn / httpx as deployment dependencies.
try:
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse, JSONResponse, Response
    from pydantic import BaseModel
    import httpx  # noqa: F401 -- probe only; the search/LLM helpers import it lazily
except ImportError:
    import subprocess
    import sys

    print("Installing fastapi...")
    # Use *this* interpreter's pip (a bare "pip" on PATH may belong to another
    # Python).  A failed install is not fatal here, matching the previous
    # os.system call: the imports below raise if fastapi is still missing.
    subprocess.call(
        [sys.executable, "-m", "pip", "install", "fastapi", "uvicorn", "httpx", "-q"]
    )
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse, JSONResponse, Response
    from pydantic import BaseModel

app = FastAPI(title="Wiki VectorDB Chat")


class ChatMessage(BaseModel):
    """Request body for ``POST /chat/message`` and ``POST /chat/tunnel``."""

    message: str  # the user's new message; appended as the final "user" turn
    provider_id: str = "zai-coding"  # id matched against get_all_providers()
    # Prior turns; the handlers forward at most the last 10.  Entries are
    # expected to be dicts with "role" and "content" keys (the message
    # formatters index both).
    history: list = []
    rag_wiki: bool = True  # include wiki-api KB results in the system prompt
    rag_vector: bool = True  # include vector-db results in the system prompt
    # Selects a system-prompt hint in /chat/message: "chat" (none), "code",
    # or "brain"; any other value adds no hint.
    mode: str = "chat"


class ProviderSave(BaseModel):
    """Request body for ``POST /providers/save`` (one custom LLM provider)."""

    id: str  # key; saving an existing custom id replaces that entry
    name: str  # display name
    base_url: str  # API root; call_llm_stream appends the format-specific path
    model: str  # model name sent in the request payload
    api_key: str = ""  # stored as plain text in the custom providers JSON file
    # Wire format: "anthropic", "ollama", or an OpenAI-compatible value such as
    # "openai" / "openrouter" / "groq".
    format: str = "openai"
    icon: str = "\u2b99"  # single-character icon shown by the UI
    description: str = ""


@app.get("/providers/presets")
async def get_presets():
    """List the built-in provider presets only (no custom providers)."""
    presets = PRESETS
    return presets


@app.get("/providers")
async def list_providers():
    """List every usable provider: built-in presets followed by custom ones.

    NOTE(review): custom entries are returned exactly as stored, including
    their ``api_key``.  Confirm this route is not reachable by untrusted
    clients, or redact keys before returning them.
    """
    providers = get_all_providers()
    return providers


@app.post("/providers/save")
async def save_provider(p: ProviderSave):
    """Create or replace the custom provider whose id is ``p.id``.

    The first stored entry with the same id is overwritten in place, keeping
    its position; otherwise the provider is appended.  The stored record,
    including its ``api_key``, is echoed back.
    """
    providers = load_custom_providers()
    record = p.model_dump()
    slot = next(
        (i for i, current in enumerate(providers) if current.get("id") == p.id),
        None,
    )
    if slot is None:
        providers.append(record)
    else:
        providers[slot] = record
    save_custom_providers(providers)
    return {"ok": True, "provider": record}


@app.delete("/providers/{provider_id}")
async def delete_provider(provider_id: str):
    """Remove every custom provider with id *provider_id*.

    Presets live in code and are unaffected.  The response reports success
    even when no custom entry matched.
    """
    remaining = [
        entry for entry in load_custom_providers()
        if entry.get("id") != provider_id
    ]
    save_custom_providers(remaining)
    return {"ok": True}


@app.post("/chat/message")
async def chat_message(msg: ChatMessage):
    """Stream a RAG-augmented chat reply as server-sent events.

    Each frame is ``data: <json>`` holding one event dict from
    call_llm_stream.  Unless the provider id is unknown (a single error
    frame), the stream ends with exactly one ``{"type": "done"}`` frame, even
    if the provider itself signals completion.
    """

    async def generate():
        providers = get_all_providers()
        provider = next((p for p in providers if p.get("id") == msg.provider_id), None)
        if not provider:
            yield f"data: {json.dumps({'type':'error','delta':'Provider not found'})}\n\n"
            return

        # Conversation history: only well-formed user/assistant turns, reduced
        # to role + content.  The formatters index both keys, and a
        # client-supplied "system" turn would otherwise be forwarded verbatim
        # to OpenAI-compatible providers.
        messages = [
            {"role": h["role"], "content": h["content"]}
            for h in msg.history[-10:]
            if isinstance(h, dict)
            and h.get("role") in ("user", "assistant")
            and "content" in h
        ]
        messages.append({"role": "user", "content": msg.message})

        # Build RAG context with per-session toggles
        rag_context = await build_rag_context(msg.message, msg.rag_wiki, msg.rag_vector)

        # Mode-specific system prompt additions
        mode_hints = {
            "chat": "",
            "code": "\n\nMODE: Coding. The user is working on code. Provide precise, well-structured code examples with explanations. Use markdown code blocks. Be concise and technical.",
            "brain": "\n\nMODE: Brainstorm. The user wants creative exploration. Think freely, offer multiple perspectives, suggest unconventional approaches. Be enthusiastic and expansive.",
        }
        system_prompt = rag_context + mode_hints.get(msg.mode, "")

        done_sent = False
        async for chunk in call_llm_stream(provider, system_prompt, messages):
            if chunk.get("type") == "done":
                if done_sent:
                    continue  # the provider signalled completion more than once
                done_sent = True
            yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
        if not done_sent:
            yield f"data: {json.dumps({'type':'done'})}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})


@app.post("/chat/tunnel")
async def chat_tunnel(msg: ChatMessage):
    """Server-side token chat — uses ZAI_API_TOKEN env var if available.

    Uses the same request body and SSE frames as /chat/message, including the
    per-session RAG toggles and the ``mode`` hint.  The server's
    ``ZAI_API_TOKEN`` is attached only when the selected provider has no
    ``api_key`` of its own *and* its ``base_url`` host is ``z.ai`` or a
    subdomain.  Custom providers may point at any URL, so injecting the token
    unconditionally would hand it to arbitrary third-party hosts.
    """
    from urllib.parse import urlparse

    def is_zai_host(base_url):
        host = urlparse(str(base_url or "")).hostname or ""
        return host == "z.ai" or host.endswith(".z.ai")

    async def generate():
        providers = get_all_providers()
        provider = next((p for p in providers if p.get("id") == msg.provider_id), None)
        if not provider:
            yield f"data: {json.dumps({'type':'error','delta':'Provider not found'})}\n\n"
            return

        # Use server-side token if available (for tunnel mode), Z.ai hosts only.
        token = os.environ.get("ZAI_API_TOKEN", "")
        if token and not provider.get("api_key") and is_zai_host(provider.get("base_url")):
            provider = dict(provider)  # never mutate the shared preset dict
            provider["api_key"] = token

        # Only well-formed user/assistant turns, reduced to role + content.
        messages = [
            {"role": h["role"], "content": h["content"]}
            for h in msg.history[-10:]
            if isinstance(h, dict)
            and h.get("role") in ("user", "assistant")
            and "content" in h
        ]
        messages.append({"role": "user", "content": msg.message})

        rag_context = await build_rag_context(msg.message, msg.rag_wiki, msg.rag_vector)
        mode_hints = {
            "chat": "",
            "code": "\n\nMODE: Coding. The user is working on code. Provide precise, well-structured code examples with explanations. Use markdown code blocks. Be concise and technical.",
            "brain": "\n\nMODE: Brainstorm. The user wants creative exploration. Think freely, offer multiple perspectives, suggest unconventional approaches. Be enthusiastic and expansive.",
        }
        system_prompt = rag_context + mode_hints.get(msg.mode, "")

        done_sent = False
        async for chunk in call_llm_stream(provider, system_prompt, messages):
            if chunk.get("type") == "done":
                if done_sent:
                    continue  # the provider signalled completion more than once
                done_sent = True
            yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
        if not done_sent:
            yield f"data: {json.dumps({'type':'done'})}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})


@app.get("/health")
async def health():
    """Liveness probe: a fixed status plus the number of usable providers."""
    provider_count = len(get_all_providers())
    return {"status": "ok", "providers": provider_count}


class WikiSave(BaseModel):
    """Request body for ``POST /chat/save-to-wiki``."""

    question: str  # stored as the KB entry's "q" field
    answer: str  # stored as the KB entry's "a" field
    topic: str = "chat-saved"  # stored as the KB entry's "topic" field


@app.post("/chat/save-to-wiki")
async def save_to_wiki(item: WikiSave):
    """Save a Q&A pair directly to wiki-kb.json.

    Appends one entry (author "chat-assistant", source "chat-saved", local
    ``YYYY-MM-DD HH:MM`` timestamp) and rewrites the file:
      * explicitly as UTF-8, since the JSON is written with
        ``ensure_ascii=False`` and the locale default may not be UTF-8;
      * atomically via a sibling ``.tmp`` file and ``os.replace``, so a crash
        mid-write cannot truncate the whole knowledge base.

    Returns ``{"ok": True, "total": <entry count>}``, or
    ``{"ok": False, "error": <message>}`` on any failure.
    """
    import shutil

    kb_path = Path("/opt/blog/wiki-kb.json")
    tmp_path = kb_path.with_name(kb_path.name + ".tmp")
    try:
        kb = json.loads(kb_path.read_text(encoding="utf-8"))
        if not isinstance(kb, list):
            raise ValueError("wiki-kb.json does not contain a JSON list")
        entry = {
            "q": item.question,
            "a": item.answer,
            "topic": item.topic,
            "author": "chat-assistant",
            "source": "chat-saved",
            "timestamp": time.strftime("%Y-%m-%d %H:%M"),
        }
        kb.append(entry)
        tmp_path.write_text(json.dumps(kb, ensure_ascii=False, indent=2), encoding="utf-8")
        shutil.copymode(kb_path, tmp_path)  # keep the KB file's permissions
        os.replace(tmp_path, kb_path)
        return {"ok": True, "total": len(kb)}
    except Exception as e:
        # Best-effort cleanup of a half-written temp file.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        return {"ok": False, "error": str(e)}


def main():
    """Serve the app with uvicorn on 127.0.0.1.

    The port defaults to 8770; a ``--port N`` pair on the command line
    overrides it, and the last such pair wins.
    """
    import sys
    import uvicorn

    port = _port_from_argv(sys.argv)
    print(f"Wiki VectorDB Chat starting on port {port}")
    uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning")


def _port_from_argv(argv, default=8770):
    """Return the value of the last ``--port N`` pair in *argv*, else *default*."""
    port = default
    # Adjacent (flag, value) pairs; a trailing "--port" has no value and is ignored.
    for flag, value in zip(argv, argv[1:]):
        if flag == "--port":
            port = int(value)
    return port


if __name__ == "__main__":
    main()