244 lines
7.5 KiB
Rust
244 lines
7.5 KiB
Rust
//! Agent management and multi-agent orchestration
|
|
|
|
use std::sync::Arc;

use anyhow::{Context, Result};
use rig::{
    agent::Agent,
    completion::{Chat, Completion, Message},
    providers::openai,
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use uuid::Uuid;

use crate::tools::{Tool, ToolRegistry, ToolResult};
|
|
|
|
/// Agent configuration
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct AgentConfig {
|
|
pub name: String,
|
|
pub preamble: String,
|
|
pub model: String,
|
|
pub provider: String,
|
|
pub temperature: f32,
|
|
pub max_turns: u32,
|
|
}
|
|
|
|
/// Agent session
|
|
#[derive(Debug, Clone)]
|
|
pub struct AgentSession {
|
|
pub id: String,
|
|
pub config: AgentConfig,
|
|
pub messages: Vec<Message>,
|
|
}
|
|
|
|
/// Agent Council for multi-agent orchestration
|
|
#[derive(Debug, Clone)]
|
|
pub struct AgentCouncil {
|
|
pub id: String,
|
|
pub name: String,
|
|
pub agents: Vec<AgentSession>,
|
|
}
|
|
|
|
/// Agent manager
|
|
#[derive(Debug, Clone)]
|
|
pub struct AgentManager {
|
|
sessions: Arc<RwLock<Vec<AgentSession>>>,
|
|
councils: Arc<RwLock<Vec<AgentCouncil>>>,
|
|
tool_registry: ToolRegistry,
|
|
}
|
|
|
|
impl AgentManager {
|
|
pub fn new(tool_registry: ToolRegistry) -> Self {
|
|
Self {
|
|
sessions: Arc::new(RwLock::new(Vec::new())),
|
|
councils: Arc::new(RwLock::new(Vec::new())),
|
|
tool_registry,
|
|
}
|
|
}
|
|
|
|
/// Create a new agent session
|
|
pub async fn create_session(&self, config: AgentConfig) -> Result<String> {
|
|
let session = AgentSession {
|
|
id: Uuid::new_v4().to_string(),
|
|
config,
|
|
messages: Vec::new(),
|
|
};
|
|
|
|
let id = session.id.clone();
|
|
let mut sessions = self.sessions.write().await;
|
|
sessions.push(session);
|
|
|
|
Ok(id)
|
|
}
|
|
|
|
/// Get a session by ID
|
|
pub async fn get_session(&self, id: &str) -> Option<AgentSession> {
|
|
let sessions = self.sessions.read().await;
|
|
sessions.iter().find(|s| s.id == id).cloned()
|
|
}
|
|
|
|
/// Execute agent prompt using Rig
|
|
pub async fn execute_prompt(
|
|
&self,
|
|
session_id: &str,
|
|
prompt: &str,
|
|
) -> Result<String> {
|
|
let session = self.get_session(session_id)
|
|
.await
|
|
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
|
|
|
// Get provider config (API key + optional base URL)
|
|
let (api_key, base_url) = self.get_provider_config(&session.config.provider)?;
|
|
|
|
// Create Rig agent with OpenAI-compatible provider
|
|
// Qwen uses OpenAI-compatible API, so we can use the OpenAI client
|
|
let mut client_builder = openai::ClientBuilder::new(&api_key);
|
|
|
|
// Use custom base URL if provided (for Qwen or other compatible APIs)
|
|
if let Some(url) = base_url {
|
|
client_builder = client_builder.base_url(&url);
|
|
}
|
|
|
|
let client = client_builder.build();
|
|
|
|
let agent = client
|
|
.agent(&session.config.model)
|
|
.preamble(&session.config.preamble)
|
|
.build();
|
|
|
|
// Execute prompt
|
|
let response = agent.prompt(prompt).await
|
|
.map_err(|e| anyhow::anyhow!("Rig prompt execution failed: {}", e))?;
|
|
|
|
// Store message
|
|
let mut sessions = self.sessions.write().await;
|
|
if let Some(session) = sessions.iter_mut().find(|s| s.id == session_id) {
|
|
session.messages.push(Message {
|
|
role: "user".to_string(),
|
|
content: prompt.to_string(),
|
|
});
|
|
session.messages.push(Message {
|
|
role: "assistant".to_string(),
|
|
content: response.clone(),
|
|
});
|
|
}
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
/// Create agent council
|
|
pub async fn create_council(
|
|
&self,
|
|
name: &str,
|
|
agent_configs: Vec<AgentConfig>,
|
|
) -> Result<String> {
|
|
let mut agents = Vec::new();
|
|
|
|
for config in agent_configs {
|
|
let session = AgentSession {
|
|
id: Uuid::new_v4().to_string(),
|
|
config,
|
|
messages: Vec::new(),
|
|
};
|
|
agents.push(session);
|
|
}
|
|
|
|
let council = AgentCouncil {
|
|
id: Uuid::new_v4().to_string(),
|
|
name: name.to_string(),
|
|
agents,
|
|
};
|
|
|
|
let council_id = council.id.clone();
|
|
let mut councils = self.councils.write().await;
|
|
councils.push(council);
|
|
|
|
Ok(council_id)
|
|
}
|
|
|
|
/// Execute council orchestration
|
|
pub async fn execute_council(
|
|
&self,
|
|
council_id: &str,
|
|
task: &str,
|
|
) -> Result<String> {
|
|
let council = self.councils.read()
|
|
.await
|
|
.iter()
|
|
.find(|c| c.id == council_id)
|
|
.cloned()
|
|
.ok_or_else(|| anyhow::anyhow!("Council not found"))?;
|
|
|
|
let mut results = Vec::new();
|
|
|
|
// Execute task with each agent
|
|
for agent in &council.agents {
|
|
match self.execute_prompt(&agent.id, task).await {
|
|
Ok(result) => {
|
|
results.push(format!("**{}**: {}", agent.config.name, result));
|
|
}
|
|
Err(e) => {
|
|
results.push(format!("**{}**: Error - {}", agent.config.name, e));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Synthesize results
|
|
Ok(results.join("\n\n---\n\n"))
|
|
}
|
|
|
|
/// Get API key and base URL for provider
|
|
fn get_provider_config(&self, provider: &str) -> Result<(String, Option<String>)> {
|
|
match provider.to_lowercase().as_str() {
|
|
"qwen" | "qwen-plus" | "qwen-max" => {
|
|
let api_key = std::env::var("QWEN_API_KEY")
|
|
.map_err(|_| anyhow::anyhow!("QWEN_API_KEY not set. Get it from https://platform.qwen.ai"))?;
|
|
let base_url = std::env::var("QWEN_BASE_URL").ok();
|
|
Ok((api_key, base_url))
|
|
}
|
|
"openai" | "gpt-4" | "gpt-4o" | "gpt-3.5" => {
|
|
let api_key = std::env::var("OPENAI_API_KEY")
|
|
.map_err(|_| anyhow::anyhow!("OPENAI_API_KEY not set"))?;
|
|
Ok((api_key, None))
|
|
}
|
|
"anthropic" | "claude" | "claude-3" => {
|
|
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
|
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY not set"))?;
|
|
Ok((api_key, None))
|
|
}
|
|
"ollama" | "local" => {
|
|
// Ollama doesn't need API key, uses localhost
|
|
Ok(("".to_string(), Some("http://localhost:11434".to_string())))
|
|
}
|
|
_ => {
|
|
// Default to Qwen for QwenClaw
|
|
let api_key = std::env::var("QWEN_API_KEY")
|
|
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
|
.map_err(|_| anyhow::anyhow!("No API key found. Set QWEN_API_KEY or OPENAI_API_KEY"))?;
|
|
let base_url = std::env::var("QWEN_BASE_URL").ok();
|
|
Ok((api_key, base_url))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// List all sessions
|
|
pub async fn list_sessions(&self) -> Vec<AgentSession> {
|
|
let sessions = self.sessions.read().await;
|
|
sessions.clone()
|
|
}
|
|
|
|
/// List all councils
|
|
pub async fn list_councils(&self) -> Vec<AgentCouncil> {
|
|
let councils = self.councils.read().await;
|
|
councils.clone()
|
|
}
|
|
|
|
/// Delete a session
|
|
pub async fn delete_session(&self, id: &str) -> Result<()> {
|
|
let mut sessions = self.sessions.write().await;
|
|
sessions.retain(|s| s.id != id);
|
|
Ok(())
|
|
}
|
|
}
|