Files
QwenClaw-with-Auth/rig-service/src/agent.rs

244 lines
7.5 KiB
Rust

//! Agent management and multi-agent orchestration
use anyhow::{Context, Result};
use rig::{
    agent::Agent,
    completion::{Completion, Message, Prompt},
    providers::openai,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;

use crate::tools::{Tool, ToolRegistry, ToolResult};
/// Settings that define one agent: its label, system prompt, model and provider choice.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentConfig {
    /// Display name for the agent.
    pub name: String,
    /// System prompt text for the agent.
    pub preamble: String,
    /// Model identifier requested from the provider.
    pub model: String,
    /// Provider selector string (for example `"qwen"`, `"openai"`, `"ollama"`).
    pub provider: String,
    /// Sampling temperature requested for this agent.
    pub temperature: f32,
    /// Requested maximum number of turns.
    pub max_turns: u32,
}
/// One live agent: an identifier, the configuration it was created from,
/// and the messages recorded for it so far.
#[derive(Clone, Debug)]
pub struct AgentSession {
    /// Unique session identifier (a v4 UUID string when created by `AgentManager`).
    pub id: String,
    /// Configuration this session was created with.
    pub config: AgentConfig,
    /// Recorded conversation, in the order messages were appended.
    pub messages: Vec<Message>,
}
/// A named group of agents that can all be given the same task
/// (multi-agent orchestration).
#[derive(Clone, Debug)]
pub struct AgentCouncil {
    /// Unique council identifier.
    pub id: String,
    /// Human-readable council name.
    pub name: String,
    /// Member agents, each with its own configuration and history.
    pub agents: Vec<AgentSession>,
}
/// Registry of agent sessions and councils.
///
/// Sessions and councils live behind `Arc<RwLock<..>>`, so clones of the
/// manager share the same underlying lists.
#[derive(Clone, Debug)]
pub struct AgentManager {
    /// All standalone agent sessions.
    sessions: Arc<RwLock<Vec<AgentSession>>>,
    /// All councils, each owning its member agent sessions.
    councils: Arc<RwLock<Vec<AgentCouncil>>>,
    /// Tool registry supplied at construction.
    tool_registry: ToolRegistry,
}
impl AgentManager {
pub fn new(tool_registry: ToolRegistry) -> Self {
Self {
sessions: Arc::new(RwLock::new(Vec::new())),
councils: Arc::new(RwLock::new(Vec::new())),
tool_registry,
}
}
/// Create a new agent session
pub async fn create_session(&self, config: AgentConfig) -> Result<String> {
let session = AgentSession {
id: Uuid::new_v4().to_string(),
config,
messages: Vec::new(),
};
let id = session.id.clone();
let mut sessions = self.sessions.write().await;
sessions.push(session);
Ok(id)
}
/// Get a session by ID
pub async fn get_session(&self, id: &str) -> Option<AgentSession> {
let sessions = self.sessions.read().await;
sessions.iter().find(|s| s.id == id).cloned()
}
/// Execute agent prompt using Rig
pub async fn execute_prompt(
&self,
session_id: &str,
prompt: &str,
) -> Result<String> {
let session = self.get_session(session_id)
.await
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
// Get provider config (API key + optional base URL)
let (api_key, base_url) = self.get_provider_config(&session.config.provider)?;
// Create Rig agent with OpenAI-compatible provider
// Qwen uses OpenAI-compatible API, so we can use the OpenAI client
let mut client_builder = openai::ClientBuilder::new(&api_key);
// Use custom base URL if provided (for Qwen or other compatible APIs)
if let Some(url) = base_url {
client_builder = client_builder.base_url(&url);
}
let client = client_builder.build();
let agent = client
.agent(&session.config.model)
.preamble(&session.config.preamble)
.build();
// Execute prompt
let response = agent.prompt(prompt).await
.map_err(|e| anyhow::anyhow!("Rig prompt execution failed: {}", e))?;
// Store message
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.iter_mut().find(|s| s.id == session_id) {
session.messages.push(Message {
role: "user".to_string(),
content: prompt.to_string(),
});
session.messages.push(Message {
role: "assistant".to_string(),
content: response.clone(),
});
}
Ok(response)
}
/// Create agent council
pub async fn create_council(
&self,
name: &str,
agent_configs: Vec<AgentConfig>,
) -> Result<String> {
let mut agents = Vec::new();
for config in agent_configs {
let session = AgentSession {
id: Uuid::new_v4().to_string(),
config,
messages: Vec::new(),
};
agents.push(session);
}
let council = AgentCouncil {
id: Uuid::new_v4().to_string(),
name: name.to_string(),
agents,
};
let council_id = council.id.clone();
let mut councils = self.councils.write().await;
councils.push(council);
Ok(council_id)
}
/// Execute council orchestration
pub async fn execute_council(
&self,
council_id: &str,
task: &str,
) -> Result<String> {
let council = self.councils.read()
.await
.iter()
.find(|c| c.id == council_id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Council not found"))?;
let mut results = Vec::new();
// Execute task with each agent
for agent in &council.agents {
match self.execute_prompt(&agent.id, task).await {
Ok(result) => {
results.push(format!("**{}**: {}", agent.config.name, result));
}
Err(e) => {
results.push(format!("**{}**: Error - {}", agent.config.name, e));
}
}
}
// Synthesize results
Ok(results.join("\n\n---\n\n"))
}
/// Get API key and base URL for provider
fn get_provider_config(&self, provider: &str) -> Result<(String, Option<String>)> {
match provider.to_lowercase().as_str() {
"qwen" | "qwen-plus" | "qwen-max" => {
let api_key = std::env::var("QWEN_API_KEY")
.map_err(|_| anyhow::anyhow!("QWEN_API_KEY not set. Get it from https://platform.qwen.ai"))?;
let base_url = std::env::var("QWEN_BASE_URL").ok();
Ok((api_key, base_url))
}
"openai" | "gpt-4" | "gpt-4o" | "gpt-3.5" => {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| anyhow::anyhow!("OPENAI_API_KEY not set"))?;
Ok((api_key, None))
}
"anthropic" | "claude" | "claude-3" => {
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY not set"))?;
Ok((api_key, None))
}
"ollama" | "local" => {
// Ollama doesn't need API key, uses localhost
Ok(("".to_string(), Some("http://localhost:11434".to_string())))
}
_ => {
// Default to Qwen for QwenClaw
let api_key = std::env::var("QWEN_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.map_err(|_| anyhow::anyhow!("No API key found. Set QWEN_API_KEY or OPENAI_API_KEY"))?;
let base_url = std::env::var("QWEN_BASE_URL").ok();
Ok((api_key, base_url))
}
}
}
/// List all sessions
pub async fn list_sessions(&self) -> Vec<AgentSession> {
let sessions = self.sessions.read().await;
sessions.clone()
}
/// List all councils
pub async fn list_councils(&self) -> Vec<AgentCouncil> {
let councils = self.councils.read().await;
councils.clone()
}
/// Delete a session
pub async fn delete_session(&self, id: &str) -> Result<()> {
let mut sessions = self.sessions.write().await;
sessions.retain(|s| s.id != id);
Ok(())
}
}