Improve VectorDB search: multi-query, topic boosting, wider recall
- Generate multiple query variants (entities, topic words, combined) - Search with top_k=30 per sub-query for wider recall - Boost results matching multiple topic words for relevance - Deduplicate and merge across all sub-queries - Return top 15 results (up from 10) for richer RAG context Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -114,52 +114,102 @@ def detect_provider_format(base_url: str) -> str:
|
||||
return "openai"
|
||||
|
||||
|
||||
async def search_kb(query: str, limit: int = 3) -> str:
|
||||
"""Search wiki-kb.json via wiki-api."""
|
||||
async def search_kb(query: str, limit: int = 5) -> str:
|
||||
"""Search wiki-kb.json via wiki-api with multiple query variants."""
|
||||
try:
|
||||
url = f"{WIKI_API}/search?q={urllib.parse.quote(query)}&limit={limit}&token={_API_TOKEN}"
|
||||
req = urllib.request.Request(url)
|
||||
with urllib.request.urlopen(req, timeout=5) as resp:
|
||||
data = json.loads(resp.read())
|
||||
results = data.get("results", [])
|
||||
if not results:
|
||||
return ""
|
||||
lines = []
|
||||
for r in results[:limit]:
|
||||
q_text = r.get("q", "")
|
||||
a_text = r.get("a", "")
|
||||
topic = r.get("topic", "")
|
||||
score = r.get("_score", 0)
|
||||
lines.append(f"[{topic}] Q: {q_text}\nA: {a_text}")
|
||||
return "\n\n".join(lines)
|
||||
import re
|
||||
words = re.findall(r'\b[A-Z][a-z]+\b', query)
|
||||
queries = [query]
|
||||
if words:
|
||||
queries.append(" ".join(words))
|
||||
queries = list(dict.fromkeys(queries))[:2]
|
||||
|
||||
all_results = {}
|
||||
for q in queries:
|
||||
url = f"{WIKI_API}/search?q={urllib.parse.quote(q)}&limit={limit}&token={_API_TOKEN}"
|
||||
req = urllib.request.Request(url)
|
||||
with urllib.request.urlopen(req, timeout=5) as resp:
|
||||
data = json.loads(resp.read())
|
||||
for r in data.get("results", []):
|
||||
key = r.get("q", "")[:80]
|
||||
if key not in all_results:
|
||||
all_results[key] = r
|
||||
|
||||
results = list(all_results.values())[:limit]
|
||||
if not results:
|
||||
return ""
|
||||
lines = []
|
||||
for r in results:
|
||||
lines.append(f"[{r.get('topic', '')}] Q: {r.get('q', '')}\nA: {r.get('a', '')}")
|
||||
return "\n\n".join(lines)
|
||||
except Exception as e:
|
||||
return f"(KB search error: {e})"
|
||||
|
||||
|
||||
async def search_vector(query: str, top_k: int = 5) -> str:
|
||||
"""Search vector-db for related Discord/Reddit messages."""
|
||||
async def search_vector(query: str, top_k: int = 10) -> str:
|
||||
"""Search vector-db with multiple query variants for better recall."""
|
||||
try:
|
||||
data = json.dumps({"query": query, "top_k": top_k}).encode()
|
||||
req = urllib.request.Request(
|
||||
f"{VECTOR_DB}/vector/search",
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json", "x-api-key": _API_TOKEN},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=8) as resp:
|
||||
result = json.loads(resp.read())
|
||||
hits = result.get("results", [])
|
||||
if not hits:
|
||||
return ""
|
||||
lines = []
|
||||
for h in hits[:top_k]:
|
||||
text = h.get("content", "") or h.get("text", "")
|
||||
score = h.get("score", 0)
|
||||
source = h.get("source", "unknown")
|
||||
author = h.get("author", "")
|
||||
channel = h.get("channel", "")
|
||||
preview = text.replace("\n", " ")[:200]
|
||||
lines.append(f"[{source}] @{author} in #{channel}: {preview} (score: {score:.2f})")
|
||||
return "\n\n".join(lines)
|
||||
import re
|
||||
# Extract named entities (capitalized) and topic words separately
|
||||
entities = [w for w in re.findall(r'\b[A-Z][a-z]+\b', query)
|
||||
if w.lower() not in ('what', 'when', 'where', 'how', 'why', 'that', 'this')]
|
||||
stop = {'what', 'about', 'that', 'this', 'with', 'from', 'have', 'been',
|
||||
'said', 'recently', 'announced', 'did', 'does', 'when', 'where',
|
||||
'how', 'why', 'the', 'and', 'for', 'are', 'was'}
|
||||
entity_lower = {e.lower() for e in entities}
|
||||
topic_words = [w for w in query.lower().split()
|
||||
if len(w) > 2 and w not in stop and w not in entity_lower]
|
||||
|
||||
queries = [query] # full original
|
||||
if entities:
|
||||
queries.append(" ".join(entities)) # e.g. "Jodie"
|
||||
# Topic-only query (without entity names) — catches messages BY the person
|
||||
if topic_words:
|
||||
queries.append(" ".join(topic_words[:6])) # e.g. "ambassador program"
|
||||
# Combined
|
||||
if entities and topic_words:
|
||||
queries.append(" ".join(entities + topic_words[:3]))
|
||||
|
||||
queries = list(dict.fromkeys(queries))[:4]
|
||||
|
||||
all_hits = {}
|
||||
for q in queries:
|
||||
data = json.dumps({"query": q, "top_k": 30}).encode()
|
||||
req = urllib.request.Request(
|
||||
f"{VECTOR_DB}/vector/search",
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json", "x-api-key": _API_TOKEN},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=8) as resp:
|
||||
result = json.loads(resp.read())
|
||||
for h in result.get("results", []):
|
||||
# Deduplicate by content
|
||||
text = h.get("content", "") or h.get("text", "")
|
||||
key = text[:80]
|
||||
if key not in all_hits or h.get("score", 0) > all_hits[key].get("score", 0):
|
||||
all_hits[key] = h
|
||||
|
||||
# Sort by score descending, take top_k
|
||||
# Score boost: prefer results that match multiple topic words
|
||||
def topic_relevance(h):
|
||||
text = (h.get("content", "") or "").lower()
|
||||
bonus = sum(1 for tw in topic_words if tw in text)
|
||||
return h.get("score", 0) + bonus * 0.15
|
||||
|
||||
hits = sorted(all_hits.values(), key=topic_relevance, reverse=True)[:15]
|
||||
if not hits:
|
||||
return ""
|
||||
lines = []
|
||||
for h in hits:
|
||||
text = h.get("content", "") or h.get("text", "")
|
||||
score = h.get("score", 0)
|
||||
source = h.get("source", "unknown")
|
||||
author = h.get("author", "")
|
||||
channel = h.get("channel", "")
|
||||
ts = h.get("timestamp", "")[:10]
|
||||
preview = text.replace("\n", " ")[:500]
|
||||
lines.append(f"[{source}] @{author} ({ts}): {preview} (score: {score:.2f})")
|
||||
return "\n\n".join(lines)
|
||||
except Exception as e:
|
||||
return f"(Vector search error: {e})"
|
||||
|
||||
@@ -170,9 +220,9 @@ async def build_rag_context(user_message: str, rag_wiki: bool = True, rag_vector
|
||||
vec_results = ""
|
||||
tasks = []
|
||||
if rag_wiki:
|
||||
tasks.append(search_kb(user_message, 3))
|
||||
tasks.append(search_kb(user_message, 5))
|
||||
if rag_vector:
|
||||
tasks.append(search_vector(user_message, 5))
|
||||
tasks.append(search_vector(user_message, 10))
|
||||
|
||||
if tasks:
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -185,7 +235,9 @@ async def build_rag_context(user_message: str, rag_wiki: bool = True, rag_vector
|
||||
parts = [
|
||||
"You are Z.ai Wiki Assistant. Use ALL the knowledge sources below to answer the user's question.",
|
||||
"Draw from both the Wiki KB and Community Messages. Synthesize information even from partial matches.",
|
||||
"If the context mentions anything relevant, include it in your answer. Be specific — quote authors, channels, and details when available.",
|
||||
"If the context mentions anything relevant, include it in your answer. Be specific — quote authors, dates, and details when available.",
|
||||
"IMPORTANT: People often talk ABOUT someone (like Jodie, Cobra, Agnes) without that person posting directly. If multiple users reference what someone said or did, reconstruct and summarize that information. Do NOT say 'no information found' if other users are clearly discussing the topic.",
|
||||
"When users ask 'what did X say about Y', look for messages FROM others referencing X's statements or actions about Y.",
|
||||
"Only say you don't have information if the sources are truly empty or completely unrelated.",
|
||||
"",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user