v1.3.0: Full Rig integration - Multi-agent AI framework

This commit is contained in:
admin
2026-02-26 11:56:08 +04:00
Unverified
parent 04e0fcd59e
commit 5455eaa125
11 changed files with 1958 additions and 0 deletions

49
rig-service/Cargo.toml Normal file
View File

@@ -0,0 +1,49 @@
[package]
name = "qwenclaw-rig"
version = "0.1.0"
edition = "2021"
description = "Rig-based AI agent service for QwenClaw"
authors = ["admin <admin@rommark.dev>"]
[dependencies]
# Rig core framework
rig-core = "0.16"
# Async runtime
tokio = { version = "1", features = ["full"] }
# HTTP server (Axum)
axum = { version = "0.7", features = ["macros"] }
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace"] }
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Error handling
anyhow = "1"
thiserror = "1"
# Environment variables
dotenvy = "0.15"
# UUID for session IDs
uuid = { version = "1", features = ["v4"] }
# Time
chrono = { version = "0.4", features = ["serde"] }
# Vector store (SQLite for simplicity)
rusqlite = { version = "0.31", features = ["bundled"] }
# Rig's SQLite vector-store integration (not referenced by the service code yet)
rig-sqlite = "0.1"
[profile.release]
opt-level = 3
lto = true

217
rig-service/src/agent.rs Normal file
View File

@@ -0,0 +1,217 @@
//! Agent management and multi-agent orchestration
use anyhow::Result;
use rig::{
agent::Agent,
client::{CompletionClient, ProviderClient},
completion::{Completion, Message},
providers::openai,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::tools::{Tool, ToolRegistry, ToolResult};
/// Agent configuration
///
/// Describes how a single agent is built each time it is prompted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
/// Label for the agent; prefixes this agent's answer in council output.
pub name: String,
/// System preamble handed to the agent builder.
pub preamble: String,
/// Model identifier handed to the provider client.
pub model: String,
/// Provider name used to pick the client (compared case-insensitively).
pub provider: String,
/// Sampling temperature handed to the agent builder.
pub temperature: f32,
/// Not read by `AgentManager` in this file; the HTTP handlers set it to 5.
pub max_turns: u32,
}
/// Agent session
///
/// One configured agent plus the prompts and replies recorded for it.
#[derive(Debug, Clone)]
pub struct AgentSession {
/// Random UUID (v4) rendered as a string.
pub id: String,
/// Configuration the agent is rebuilt from on every prompt.
pub config: AgentConfig,
/// Prompt/response pairs appended after each successful prompt.
pub messages: Vec<Message>,
}
/// Agent Council for multi-agent orchestration
///
/// A named group of agent sessions that are all given the same task.
#[derive(Debug, Clone)]
pub struct AgentCouncil {
/// Random UUID (v4) rendered as a string.
pub id: String,
/// Human-readable council name supplied by the caller.
pub name: String,
/// Member sessions, in the order their configs were supplied.
pub agents: Vec<AgentSession>,
}
/// Agent manager
///
/// Owns the in-memory lists of sessions and councils. The lists sit behind
/// `Arc<RwLock<..>>`, so clones of the manager share the same data.
#[derive(Debug, Clone)]
pub struct AgentManager {
// All known agent sessions.
sessions: Arc<RwLock<Vec<AgentSession>>>,
// All known councils.
councils: Arc<RwLock<Vec<AgentCouncil>>>,
// Stored at construction; not read by the methods in this file.
tool_registry: ToolRegistry,
}
impl AgentManager {
pub fn new(tool_registry: ToolRegistry) -> Self {
Self {
sessions: Arc::new(RwLock::new(Vec::new())),
councils: Arc::new(RwLock::new(Vec::new())),
tool_registry,
}
}
/// Create a new agent session
pub async fn create_session(&self, config: AgentConfig) -> Result<String> {
let session = AgentSession {
id: Uuid::new_v4().to_string(),
config,
messages: Vec::new(),
};
let id = session.id.clone();
let mut sessions = self.sessions.write().await;
sessions.push(session);
Ok(id)
}
/// Get a session by ID
pub async fn get_session(&self, id: &str) -> Option<AgentSession> {
let sessions = self.sessions.read().await;
sessions.iter().find(|s| s.id == id).cloned()
}
/// Execute agent prompt
pub async fn execute_prompt(
&self,
session_id: &str,
prompt: &str,
) -> Result<String> {
let session = self.get_session(session_id)
.await
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
// Create Rig client based on provider
let client = self.create_client(&session.config.provider)?;
// Build agent with Rig
let agent = client
.agent(&session.config.model)
.preamble(&session.config.preamble)
.temperature(session.config.temperature)
.build();
// Execute prompt
let response = agent.prompt(prompt).await?;
// Store message
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.iter_mut().find(|s| s.id == session_id) {
session.messages.push(Message {
role: "user".to_string(),
content: prompt.to_string(),
});
session.messages.push(Message {
role: "assistant".to_string(),
content: response.clone(),
});
}
Ok(response)
}
/// Create agent council
pub async fn create_council(
&self,
name: &str,
agent_configs: Vec<AgentConfig>,
) -> Result<String> {
let mut agents = Vec::new();
for config in agent_configs {
let session = AgentSession {
id: Uuid::new_v4().to_string(),
config,
messages: Vec::new(),
};
agents.push(session);
}
let council = AgentCouncil {
id: Uuid::new_v4().to_string(),
name: name.to_string(),
agents,
};
let council_id = council.id.clone();
let mut councils = self.councils.write().await;
councils.push(council);
Ok(council_id)
}
/// Execute council orchestration
pub async fn execute_council(
&self,
council_id: &str,
task: &str,
) -> Result<String> {
let council = self.councils.read()
.await
.iter()
.find(|c| c.id == council_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Council not found"))?;
let mut results = Vec::new();
// Execute task with each agent
for agent in &council.agents {
match self.execute_prompt(&agent.id, task).await {
Ok(result) => {
results.push(format!("{}: {}", agent.config.name, result));
}
Err(e) => {
results.push(format!("{}: Error - {}", agent.config.name, e));
}
}
}
// Synthesize results
Ok(results.join("\n\n"))
}
/// Create Rig client for provider
fn create_client(&self, provider: &str) -> Result<openai::Client> {
match provider.to_lowercase().as_str() {
"openai" => {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| anyhow::anyhow!("OPENAI_API_KEY not set"))?;
Ok(openai::Client::new(&api_key))
}
_ => {
// Default to OpenAI for now
let api_key = std::env::var("OPENAI_API_KEY")
.unwrap_or_else(|_| "dummy".to_string());
Ok(openai::Client::new(&api_key))
}
}
}
/// List all sessions
pub async fn list_sessions(&self) -> Vec<AgentSession> {
let sessions = self.sessions.read().await;
sessions.clone()
}
/// List all councils
pub async fn list_councils(&self) -> Vec<AgentCouncil> {
let councils = self.councils.read().await;
councils.clone()
}
/// Delete a session
pub async fn delete_session(&self, id: &str) -> Result<()> {
let mut sessions = self.sessions.write().await;
sessions.retain(|s| s.id != id);
Ok(())
}
}

406
rig-service/src/api.rs Normal file
View File

@@ -0,0 +1,406 @@
//! HTTP API server
use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use tracing::info;
use crate::{
agent::{AgentConfig, AgentManager},
config::Config,
tools::ToolRegistry,
vector_store::{simple_embed, Document, VectorStore},
};
/// Application state
///
/// Shared with every handler through axum's `State` extractor. The `Clone`
/// derive requires every field to be `Clone`.
#[derive(Clone)]
pub struct AppState {
/// Service configuration; handlers read the default model/provider from it.
pub config: Config,
/// Document store behind the `/api/documents` endpoints.
pub vector_store: VectorStore,
/// Tool catalogue behind the `/api/tools` endpoints.
pub tool_registry: ToolRegistry,
/// Session and council manager behind the agent and council endpoints.
pub agent_manager: AgentManager,
}
/// Create the Axum router
pub fn create_app(config: Config, vector_store: VectorStore, tool_registry: ToolRegistry) -> Router {
let agent_manager = AgentManager::new(tool_registry.clone());
let state = AppState {
config,
vector_store,
tool_registry,
agent_manager,
};
Router::new()
// Health check
.route("/health", get(health_check))
// Agent endpoints
.route("/api/agents", post(create_agent))
.route("/api/agents", get(list_agents))
.route("/api/agents/:id/prompt", post(execute_prompt))
.route("/api/agents/:id", get(get_agent))
.route("/api/agents/:id", delete(delete_agent))
// Council endpoints
.route("/api/councils", post(create_council))
.route("/api/councils", get(list_councils))
.route("/api/councils/:id/execute", post(execute_council))
// Tool endpoints
.route("/api/tools", get(list_tools))
.route("/api/tools/search", post(search_tools))
// Vector store endpoints
.route("/api/documents", post(add_document))
.route("/api/documents", get(list_documents))
.route("/api/documents/search", post(search_documents))
.route("/api/documents/:id", get(get_document))
.route("/api/documents/:id", delete(delete_document))
// State
.with_state(state)
// Middleware
.layer(TraceLayer::new_for_http())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
)
}
/// `GET /health` — liveness probe; always answers with a fixed JSON body.
async fn health_check() -> Json<serde_json::Value> {
    let body = serde_json::json!({
        "status": "ok",
        "service": "qwenclaw-rig"
    });
    Json(body)
}
// ============ Agent Endpoints ============
/// JSON body describing one agent; also used for each member of a council.
#[derive(Debug, Deserialize)]
struct CreateAgentRequest {
// Label for the agent.
name: String,
// System preamble for the agent.
preamble: String,
// Falls back to the configured default model when absent.
model: Option<String>,
// Falls back to the configured default provider when absent.
provider: Option<String>,
// Falls back to 0.7 when absent.
temperature: Option<f32>,
}
/// JSON response of `POST /api/agents`.
#[derive(Debug, Serialize)]
struct CreateAgentResponse {
// Id of the session that was created.
session_id: String,
}
/// `POST /api/agents` — creates an agent session and returns its id.
///
/// Missing `model` / `provider` fall back to the configured defaults and a
/// missing `temperature` to 0.7. The defaults are only cloned when they are
/// actually needed, matching how `create_council` resolves them.
///
/// Responds with 500 when the session cannot be created.
async fn create_agent(
    State(state): State<AppState>,
    Json(payload): Json<CreateAgentRequest>,
) -> Result<Json<CreateAgentResponse>, StatusCode> {
    let config = AgentConfig {
        name: payload.name,
        preamble: payload.preamble,
        model: payload
            .model
            .unwrap_or_else(|| state.config.default_model.clone()),
        provider: payload
            .provider
            .unwrap_or_else(|| state.config.default_provider.clone()),
        temperature: payload.temperature.unwrap_or(0.7),
        max_turns: 5,
    };

    let session_id = state
        .agent_manager
        .create_session(config)
        .await
        .map_err(|e| {
            tracing::error!("Failed to create agent: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    Ok(Json(CreateAgentResponse { session_id }))
}
/// `GET /api/agents` — lists id, name, model and provider of every session.
async fn list_agents(State(state): State<AppState>) -> Json<serde_json::Value> {
    let mut agents = Vec::new();
    for session in state.agent_manager.list_sessions().await {
        agents.push(serde_json::json!({
            "id": session.id,
            "name": session.config.name,
            "model": session.config.model,
            "provider": session.config.provider,
        }));
    }
    Json(serde_json::json!({ "agents": agents }))
}
/// JSON body of `POST /api/agents/:id/prompt`.
#[derive(Debug, Deserialize)]
struct PromptRequest {
// Text sent to the agent.
prompt: String,
}
/// `POST /api/agents/:id/prompt` — sends a prompt to an existing session.
///
/// Responds with 404 when the session id is unknown (consistent with
/// `get_agent`) and with 500 when running the prompt fails.
async fn execute_prompt(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
    Json(payload): Json<PromptRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // Distinguish "no such agent" from real failures; without this check an
    // unknown id surfaced as an internal server error.
    if state.agent_manager.get_session(&id).await.is_none() {
        return Err(StatusCode::NOT_FOUND);
    }

    let response = state
        .agent_manager
        .execute_prompt(&id, &payload.prompt)
        .await
        .map_err(|e| {
            tracing::error!("Failed to execute prompt: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    Ok(Json(serde_json::json!({
        "response": response
    })))
}
/// `GET /api/agents/:id` — returns one session's configuration, or 404.
async fn get_agent(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let Some(found) = state.agent_manager.get_session(&id).await else {
        return Err(StatusCode::NOT_FOUND);
    };

    let body = serde_json::json!({
        "agent": {
            "id": found.id,
            "name": found.config.name,
            "preamble": found.config.preamble,
            "model": found.config.model,
            "provider": found.config.provider,
            "temperature": found.config.temperature,
        }
    });
    Ok(Json(body))
}
/// `DELETE /api/agents/:id` — removes a session; 204 on success, 500 on error.
async fn delete_agent(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<StatusCode, StatusCode> {
    match state.agent_manager.delete_session(&id).await {
        Ok(()) => Ok(StatusCode::NO_CONTENT),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
// ============ Council Endpoints ============
/// JSON body of `POST /api/councils`.
#[derive(Debug, Deserialize)]
struct CreateCouncilRequest {
// Council name.
name: String,
// One entry per council member, same shape as a single-agent request.
agents: Vec<CreateAgentRequest>,
}
/// `POST /api/councils` — creates a council from the given agent specs.
///
/// Each member falls back to the configured default model/provider and a
/// temperature of 0.7 where the request leaves them out. Responds with 500
/// when the council cannot be created.
async fn create_council(
    State(state): State<AppState>,
    Json(payload): Json<CreateCouncilRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let defaults = &state.config;

    let mut agent_configs: Vec<AgentConfig> = Vec::with_capacity(payload.agents.len());
    for spec in payload.agents {
        let model = match spec.model {
            Some(model) => model,
            None => defaults.default_model.clone(),
        };
        let provider = match spec.provider {
            Some(provider) => provider,
            None => defaults.default_provider.clone(),
        };
        agent_configs.push(AgentConfig {
            name: spec.name,
            preamble: spec.preamble,
            model,
            provider,
            temperature: spec.temperature.unwrap_or(0.7),
            max_turns: 5,
        });
    }

    match state
        .agent_manager
        .create_council(&payload.name, agent_configs)
        .await
    {
        Ok(council_id) => Ok(Json(serde_json::json!({
            "council_id": council_id
        }))),
        Err(e) => {
            tracing::error!("Failed to create council: {}", e);
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}
/// `GET /api/councils` — lists every council with the id and name of each member.
async fn list_councils(State(state): State<AppState>) -> Json<serde_json::Value> {
    let mut councils = Vec::new();
    for council in state.agent_manager.list_councils().await {
        let members: Vec<_> = council
            .agents
            .iter()
            .map(|member| {
                serde_json::json!({
                    "id": member.id,
                    "name": member.config.name,
                })
            })
            .collect();

        councils.push(serde_json::json!({
            "id": council.id,
            "name": council.name,
            "agents": members
        }));
    }
    Json(serde_json::json!({ "councils": councils }))
}
/// JSON body of `POST /api/councils/:id/execute`.
#[derive(Debug, Deserialize)]
struct ExecuteCouncilRequest {
// Task text handed to every council member.
task: String,
}
/// `POST /api/councils/:id/execute` — runs a task through a council.
///
/// Any failure reported by the manager is logged and answered with 500.
async fn execute_council(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
    Json(payload): Json<ExecuteCouncilRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let outcome = state
        .agent_manager
        .execute_council(&id, &payload.task)
        .await;

    match outcome {
        Ok(response) => Ok(Json(serde_json::json!({
            "response": response
        }))),
        Err(e) => {
            tracing::error!("Failed to execute council: {}", e);
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}
// ============ Tool Endpoints ============
/// `GET /api/tools` — returns every registered tool.
async fn list_tools(State(state): State<AppState>) -> Json<serde_json::Value> {
    let registered = state.tool_registry.get_all_tools().await;
    let body = serde_json::json!({ "tools": registered });
    Json(body)
}
/// JSON body of `POST /api/tools/search`.
#[derive(Debug, Deserialize)]
struct SearchToolsRequest {
// Text to look for in tool names and descriptions.
query: String,
// Maximum number of results; the handler uses 10 when absent.
limit: Option<usize>,
}
async fn search_tools(
State(state): State<AppState>,
Json(payload): Json<SearchToolsRequest>,
) -> Json<serde_json::Value> {
let limit = payload.limit.unwrap_or(10);
let tools = state.tool_registry.search_tools(&payload.query, limit).await;
Json(serde_json::json!({
"tools": tools
}))
}
// ============ Document Endpoints ============
/// JSON body of `POST /api/documents`.
#[derive(Debug, Deserialize)]
struct AddDocumentRequest {
// Document text; also the input to the embedding function.
content: String,
// Arbitrary JSON metadata; the handler stores the default value when absent.
metadata: Option<serde_json::Value>,
}
async fn add_document(
State(state): State<AppState>,
Json(payload): Json<AddDocumentRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let doc = Document {
id: uuid::Uuid::new_v4().to_string(),
content: payload.content.clone(),
metadata: payload.metadata.unwrap_or_default(),
embedding: simple_embed(&payload.content),
};
state.vector_store
.add_document(&doc)
.map_err(|e| {
tracing::error!("Failed to add document: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(serde_json::json!({
"id": doc.id
})))
}
/// `GET /api/documents` — reports how many documents are stored.
///
/// Only the count is returned, not the documents themselves.
async fn list_documents(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    match state.vector_store.count() {
        Ok(count) => Ok(Json(serde_json::json!({
            "count": count
        }))),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// JSON body of `POST /api/documents/search`.
#[derive(Debug, Deserialize)]
struct SearchDocumentsRequest {
// Query text; embedded with the same function used for stored documents.
query: String,
// Maximum number of results; the handler uses 10 when absent.
limit: Option<usize>,
}
/// `POST /api/documents/search` — similarity search over stored documents.
///
/// Embeds the query, asks the store for up to `limit` matches (10 by
/// default) and returns id, content and metadata of each. Responds with 500
/// when the store fails.
async fn search_documents(
    State(state): State<AppState>,
    Json(payload): Json<SearchDocumentsRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let max_results = payload.limit.unwrap_or(10);
    let query_embedding = simple_embed(&payload.query);

    let hits = match state.vector_store.search(&query_embedding, max_results) {
        Ok(hits) => hits,
        Err(e) => {
            tracing::error!("Failed to search documents: {}", e);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let mut documents = Vec::with_capacity(hits.len());
    for hit in &hits {
        documents.push(serde_json::json!({
            "id": hit.id,
            "content": hit.content,
            "metadata": hit.metadata,
        }));
    }

    Ok(Json(serde_json::json!({ "documents": documents })))
}
/// `GET /api/documents/:id` — returns one document, 404 when it is missing
/// and 500 when the store fails.
async fn get_document(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let lookup = state
        .vector_store
        .get_document(&id)
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let Some(found) = lookup else {
        return Err(StatusCode::NOT_FOUND);
    };

    let body = serde_json::json!({
        "document": {
            "id": found.id,
            "content": found.content,
            "metadata": found.metadata,
        }
    });
    Ok(Json(body))
}
/// `DELETE /api/documents/:id` — removes a document; 204 on success, 500 on error.
async fn delete_document(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<StatusCode, StatusCode> {
    match state.vector_store.delete_document(&id) {
        Ok(()) => Ok(StatusCode::NO_CONTENT),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

43
rig-service/src/config.rs Normal file
View File

@@ -0,0 +1,43 @@
//! Configuration management
use anyhow::{Result, Context};
use serde::Deserialize;
/// Service configuration, populated from environment variables by `from_env`.
///
/// NOTE(review): the `Debug` derive prints the API key fields verbatim, so
/// avoid logging this struct with `{:?}`.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Host to bind to
pub host: String,
/// Port to listen on
pub port: u16,
/// Database path for vector store
pub database_path: String,
/// Default model provider
pub default_provider: String,
/// Default model name
pub default_model: String,
/// API keys for providers
/// (each is `None` when the matching environment variable is unset)
pub openai_api_key: Option<String>,
/// Read from `ANTHROPIC_API_KEY`.
pub anthropic_api_key: Option<String>,
/// Read from `QWEN_API_KEY`.
pub qwen_api_key: Option<String>,
}
impl Config {
    /// Reads the configuration from the process environment.
    ///
    /// Defaults when a variable is unset: `RIG_HOST` = `127.0.0.1`,
    /// `RIG_PORT` = `8080`, `RIG_DATABASE_PATH` = `rig-store.db`,
    /// `RIG_DEFAULT_PROVIDER` = `openai`, `RIG_DEFAULT_MODEL` = `gpt-4`.
    /// The provider API keys are optional.
    ///
    /// # Errors
    /// Fails when `RIG_PORT` is not a valid `u16`.
    pub fn from_env() -> Result<Self> {
        // Value of `key`, or `fallback` when the variable cannot be read.
        let var_or = |key: &str, fallback: &str| -> String {
            match std::env::var(key) {
                Ok(value) => value,
                Err(_) => fallback.to_string(),
            }
        };

        let port: u16 = var_or("RIG_PORT", "8080")
            .parse()
            .context("Invalid RIG_PORT")?;

        Ok(Self {
            host: var_or("RIG_HOST", "127.0.0.1"),
            port,
            database_path: var_or("RIG_DATABASE_PATH", "rig-store.db"),
            default_provider: var_or("RIG_DEFAULT_PROVIDER", "openai"),
            default_model: var_or("RIG_DEFAULT_MODEL", "gpt-4"),
            openai_api_key: std::env::var("OPENAI_API_KEY").ok(),
            anthropic_api_key: std::env::var("ANTHROPIC_API_KEY").ok(),
            qwen_api_key: std::env::var("QWEN_API_KEY").ok(),
        })
    }
}

57
rig-service/src/main.rs Normal file
View File

@@ -0,0 +1,57 @@
//! QwenClaw Rig Service
//!
//! A Rust-based AI agent service using Rig framework for:
//! - Multi-agent orchestration
//! - Dynamic tool calling
//! - RAG workflows
//! - Vector store integration
mod agent;
mod api;
mod tools;
mod vector_store;
mod config;
use anyhow::Result;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Service entry point: loads configuration, opens the vector store, builds
/// the router and serves it until the server stops.
///
/// # Errors
/// Fails when the configuration is invalid, the database cannot be opened,
/// the listen address cannot be bound, or the server exits with an error.
#[tokio::main]
async fn main() -> Result<()> {
    // Load environment variables first so values from `.env` (including
    // `RUST_LOG`) are already visible when the log filter is built below.
    dotenvy::dotenv().ok();

    // Initialize logging
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "qwenclaw_rig=debug,info".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    tracing::info!("🦀 Starting QwenClaw Rig Service...");

    // Initialize configuration
    let config = config::Config::from_env()?;

    // Initialize vector store
    let vector_store = vector_store::VectorStore::new(&config.database_path).await?;

    // Initialize tool registry
    let tool_registry = tools::ToolRegistry::new();

    // Build the listen address before `config` is moved into the router.
    let addr = format!("{}:{}", config.host, config.port);

    // Create API server
    let app = api::create_app(config, vector_store, tool_registry);

    let listener = tokio::net::TcpListener::bind(&addr).await?;
    tracing::info!("🚀 Rig service listening on http://{}", addr);
    // NOTE(review): no `/docs` route is registered by `api::create_app` in
    // this file set — confirm it is served elsewhere or drop this line.
    tracing::info!("📚 API docs: http://{}/docs", addr);

    // Start server
    axum::serve(listener, app).await?;
    Ok(())
}

180
rig-service/src/tools.rs Normal file
View File

@@ -0,0 +1,180 @@
//! Tool registry and dynamic tool resolution
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Tool definition
///
/// Metadata describing a tool. This file only stores and searches these
/// descriptions; it contains no code that executes a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Unique name; used as the registry key.
pub name: String,
/// Human-readable summary; included in text search.
pub description: String,
/// JSON object describing the accepted arguments (the built-in tools use a
/// JSON-Schema-style `type` / `properties` / `required` layout).
pub parameters: serde_json::Value,
/// Grouping label used by `get_tools_by_category`.
pub category: String,
}
/// Tool execution result
///
/// Not produced anywhere in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// Whether the tool ran successfully.
pub success: bool,
/// Output text of the tool.
pub output: String,
/// Error description, if any.
pub error: Option<String>,
}
/// Tool registry for managing available tools
///
/// The map sits behind `Arc<RwLock<..>>`, so clones of the registry share the
/// same set of tools.
#[derive(Debug, Clone)]
pub struct ToolRegistry {
// Tools keyed by `Tool::name`.
tools: Arc<RwLock<HashMap<String, Tool>>>,
}
impl ToolRegistry {
    /// Creates a registry pre-populated with the built-in tools.
    ///
    /// The built-ins are inserted into the map before it is wrapped in the
    /// async lock. Going through the async `register_tool` from this
    /// synchronous constructor would only create futures that are never
    /// polled, leaving the registry empty.
    pub fn new() -> Self {
        let tools: HashMap<String, Tool> = Self::builtin_tools()
            .into_iter()
            .map(|tool| (tool.name.clone(), tool))
            .collect();

        Self {
            tools: Arc::new(RwLock::new(tools)),
        }
    }

    /// Definitions of the built-in tools
    fn builtin_tools() -> Vec<Tool> {
        vec![
            // Calculator tool
            Tool {
                name: "calculator".to_string(),
                description: "Perform mathematical calculations".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "expression": {
                            "type": "string",
                            "description": "Mathematical expression to evaluate"
                        }
                    },
                    "required": ["expression"]
                }),
                category: "utility".to_string(),
            },
            // Web search tool
            Tool {
                name: "web_search".to_string(),
                description: "Search the web for information".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query"
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Number of results"
                        }
                    },
                    "required": ["query"]
                }),
                category: "research".to_string(),
            },
            // File operations tool
            Tool {
                name: "file_operations".to_string(),
                description: "Read, write, and manage files".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "operation": {
                            "type": "string",
                            "enum": ["read", "write", "delete", "list"]
                        },
                        "path": {
                            "type": "string",
                            "description": "File or directory path"
                        },
                        "content": {
                            "type": "string",
                            "description": "Content to write (for write operation)"
                        }
                    },
                    "required": ["operation", "path"]
                }),
                category: "filesystem".to_string(),
            },
            // Code execution tool
            Tool {
                name: "code_execution".to_string(),
                description: "Execute code snippets in various languages".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "language": {
                            "type": "string",
                            "enum": ["python", "javascript", "rust", "bash"]
                        },
                        "code": {
                            "type": "string",
                            "description": "Code to execute"
                        }
                    },
                    "required": ["language", "code"]
                }),
                category: "development".to_string(),
            },
        ]
    }

    /// Register a tool
    ///
    /// A tool with the same name replaces the existing entry.
    pub async fn register_tool(&self, tool: Tool) {
        let mut tools = self.tools.write().await;
        tools.insert(tool.name.clone(), tool);
    }

    /// Get all tools
    ///
    /// The order of the returned tools is unspecified (hash-map iteration).
    pub async fn get_all_tools(&self) -> Vec<Tool> {
        let tools = self.tools.read().await;
        tools.values().cloned().collect()
    }

    /// Get tools by category
    ///
    /// The category must match exactly.
    pub async fn get_tools_by_category(&self, category: &str) -> Vec<Tool> {
        let tools = self.tools.read().await;
        tools
            .values()
            .filter(|t| t.category == category)
            .cloned()
            .collect()
    }

    /// Search for tools by query (simple text search)
    ///
    /// Case-insensitive substring match on name and description. Returns at
    /// most `limit` tools; only the tools that are returned get cloned.
    pub async fn search_tools(&self, query: &str, limit: usize) -> Vec<Tool> {
        let tools = self.tools.read().await;
        let query_lower = query.to_lowercase();
        tools
            .values()
            .filter(|t| {
                t.name.to_lowercase().contains(&query_lower)
                    || t.description.to_lowercase().contains(&query_lower)
            })
            .take(limit)
            .cloned()
            .collect()
    }

    /// Get a specific tool by name
    pub async fn get_tool(&self, name: &str) -> Option<Tool> {
        let tools = self.tools.read().await;
        tools.get(name).cloned()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}

View File

@@ -0,0 +1,214 @@
//! Vector store for RAG and semantic search
use anyhow::Result;
use rusqlite::{Connection, params};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Document for vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
/// Primary key of the document row.
pub id: String,
/// Document text.
pub content: String,
/// Arbitrary JSON metadata; stored as its JSON string form.
pub metadata: serde_json::Value,
/// Embedding vector; stored as a blob of little-endian `f32` values.
pub embedding: Vec<f32>,
}
/// Vector store using SQLite with simple embeddings
pub struct VectorStore {
conn: Connection,
}
impl VectorStore {
/// Create new vector store
pub async fn new(db_path: &str) -> Result<Self> {
let conn = Connection::open(db_path)?;
// Create tables
conn.execute(
"CREATE TABLE IF NOT EXISTS documents (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)",
[],
)?;
conn.execute(
"CREATE TABLE IF NOT EXISTS embeddings (
document_id TEXT PRIMARY KEY,
embedding BLOB NOT NULL,
FOREIGN KEY (document_id) REFERENCES documents(id)
)",
[],
)?;
Ok(Self { conn })
}
/// Add document to store
pub fn add_document(&self, doc: &Document) -> Result<()> {
let tx = self.conn.transaction()?;
// Insert document
tx.execute(
"INSERT OR REPLACE INTO documents (id, content, metadata) VALUES (?1, ?2, ?3)",
params![doc.id, doc.content, serde_json::to_string(&doc.metadata)?],
)?;
// Insert embedding (store as blob)
let embedding_bytes: Vec<u8> = doc.embedding
.iter()
.flat_map(|&f| f.to_le_bytes().to_vec())
.collect();
tx.execute(
"INSERT OR REPLACE INTO embeddings (document_id, embedding) VALUES (?1, ?2)",
params![doc.id, embedding_bytes],
)?;
tx.commit()?;
Ok(())
}
/// Search documents by similarity (simple cosine similarity)
pub fn search(&self, query_embedding: &[f32], limit: usize) -> Result<Vec<Document>> {
let mut stmt = self.conn.prepare(
"SELECT d.id, d.content, d.metadata, e.embedding
FROM documents d
JOIN embeddings e ON d.id = e.document_id
ORDER BY d.created_at DESC
LIMIT ?1"
)?;
let docs = stmt.query_map(params![limit], |row| {
let id: String = row.get(0)?;
let content: String = row.get(1)?;
let metadata: String = row.get(2)?;
let embedding_blob: Vec<u8> = row.get(3)?;
// Convert blob back to f32 vector
let embedding: Vec<f32> = embedding_blob
.chunks(4)
.map(|chunk| {
let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
f32::from_le_bytes(bytes)
})
.collect();
Ok(Document {
id,
content,
metadata: serde_json::from_str(&metadata).unwrap_or_default(),
embedding,
})
})?;
let mut results: Vec<Document> = docs.filter_map(|r| r.ok()).collect();
// Sort by cosine similarity
results.sort_by(|a, b| {
let sim_a = cosine_similarity(query_embedding, &a.embedding);
let sim_b = cosine_similarity(query_embedding, &b.embedding);
sim_b.partial_cmp(&sim_a).unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
/// Get document by ID
pub fn get_document(&self, id: &str) -> Result<Option<Document>> {
let mut stmt = self.conn.prepare(
"SELECT d.id, d.content, d.metadata, e.embedding
FROM documents d
JOIN embeddings e ON d.id = e.document_id
WHERE d.id = ?1"
)?;
let doc = stmt.query_row(params![id], |row| {
let id: String = row.get(0)?;
let content: String = row.get(1)?;
let metadata: String = row.get(2)?;
let embedding_blob: Vec<u8> = row.get(3)?;
let embedding: Vec<f32> = embedding_blob
.chunks(4)
.map(|chunk| {
let bytes: [u8; 4] = chunk.try_into().unwrap_or([0; 4]);
f32::from_le_bytes(bytes)
})
.collect();
Ok(Document {
id,
content,
metadata: serde_json::from_str(&metadata).unwrap_or_default(),
embedding,
})
}).ok();
Ok(doc.ok().flatten())
}
/// Delete document
pub fn delete_document(&self, id: &str) -> Result<()> {
self.conn.execute(
"DELETE FROM documents WHERE id = ?1",
params![id],
)?;
self.conn.execute(
"DELETE FROM embeddings WHERE document_id = ?1",
params![id],
)?;
Ok(())
}
/// Get document count
pub fn count(&self) -> Result<usize> {
let count: usize = self.conn.query_row(
"SELECT COUNT(*) FROM documents",
[],
|row| row.get(0),
)?;
Ok(count)
}
}
/// Cosine of the angle between `a` and `b`.
///
/// The dot product runs over the overlapping prefix of the two slices, while
/// each norm uses its whole slice. Returns `0.0` when either norm is zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    if norm_a != 0.0 && norm_b != 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
/// Placeholder embedding: folds the bytes of `text` into a 384-dimensional,
/// L2-normalised vector.
///
/// This is a byte-position hash, NOT a semantic embedding; a real embedding
/// model should replace it in production. Empty input yields the all-zero
/// vector.
pub fn simple_embed(text: &str) -> Vec<f32> {
    const DIMENSIONS: usize = 384;

    // Byte `i` adds its value, scaled by 1/255, to slot `i % DIMENSIONS`.
    let mut buckets = vec![0.0_f32; DIMENSIONS];
    for (position, byte) in text.bytes().enumerate() {
        buckets[position % DIMENSIONS] += byte as f32 / 255.0;
    }

    // Scale to unit length unless the vector is all zeros.
    let magnitude = buckets.iter().map(|v| v * v).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        buckets.iter_mut().for_each(|v| *v /= magnitude);
    }
    buckets
}