181 lines
5.5 KiB
Rust
181 lines
5.5 KiB
Rust
//! Tool registry and dynamic tool resolution
|
|
|
|
use anyhow::Result;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
|
|
/// Tool definition
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Tool {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub parameters: serde_json::Value,
|
|
pub category: String,
|
|
}
|
|
|
|
/// Tool execution result
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolResult {
|
|
pub success: bool,
|
|
pub output: String,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
/// Tool registry for managing available tools
|
|
#[derive(Debug, Clone)]
|
|
pub struct ToolRegistry {
|
|
tools: Arc<RwLock<HashMap<String, Tool>>>,
|
|
}
|
|
|
|
impl ToolRegistry {
|
|
pub fn new() -> Self {
|
|
let mut registry = Self {
|
|
tools: Arc::new(RwLock::new(HashMap::new())),
|
|
};
|
|
|
|
// Register built-in tools
|
|
registry.register_builtin_tools();
|
|
|
|
registry
|
|
}
|
|
|
|
/// Register built-in tools
|
|
fn register_builtin_tools(&mut self) {
|
|
// Calculator tool
|
|
self.register_tool(Tool {
|
|
name: "calculator".to_string(),
|
|
description: "Perform mathematical calculations".to_string(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"expression": {
|
|
"type": "string",
|
|
"description": "Mathematical expression to evaluate"
|
|
}
|
|
},
|
|
"required": ["expression"]
|
|
}),
|
|
category: "utility".to_string(),
|
|
});
|
|
|
|
// Web search tool
|
|
self.register_tool(Tool {
|
|
name: "web_search".to_string(),
|
|
description: "Search the web for information".to_string(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {
|
|
"type": "string",
|
|
"description": "Search query"
|
|
},
|
|
"limit": {
|
|
"type": "integer",
|
|
"description": "Number of results"
|
|
}
|
|
},
|
|
"required": ["query"]
|
|
}),
|
|
category: "research".to_string(),
|
|
});
|
|
|
|
// File operations tool
|
|
self.register_tool(Tool {
|
|
name: "file_operations".to_string(),
|
|
description: "Read, write, and manage files".to_string(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"operation": {
|
|
"type": "string",
|
|
"enum": ["read", "write", "delete", "list"]
|
|
},
|
|
"path": {
|
|
"type": "string",
|
|
"description": "File or directory path"
|
|
},
|
|
"content": {
|
|
"type": "string",
|
|
"description": "Content to write (for write operation)"
|
|
}
|
|
},
|
|
"required": ["operation", "path"]
|
|
}),
|
|
category: "filesystem".to_string(),
|
|
});
|
|
|
|
// Code execution tool
|
|
self.register_tool(Tool {
|
|
name: "code_execution".to_string(),
|
|
description: "Execute code snippets in various languages".to_string(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"language": {
|
|
"type": "string",
|
|
"enum": ["python", "javascript", "rust", "bash"]
|
|
},
|
|
"code": {
|
|
"type": "string",
|
|
"description": "Code to execute"
|
|
}
|
|
},
|
|
"required": ["language", "code"]
|
|
}),
|
|
category: "development".to_string(),
|
|
});
|
|
}
|
|
|
|
/// Register a tool
|
|
pub async fn register_tool(&self, tool: Tool) {
|
|
let mut tools = self.tools.write().await;
|
|
tools.insert(tool.name.clone(), tool);
|
|
}
|
|
|
|
/// Get all tools
|
|
pub async fn get_all_tools(&self) -> Vec<Tool> {
|
|
let tools = self.tools.read().await;
|
|
tools.values().cloned().collect()
|
|
}
|
|
|
|
/// Get tools by category
|
|
pub async fn get_tools_by_category(&self, category: &str) -> Vec<Tool> {
|
|
let tools = self.tools.read().await;
|
|
tools.values()
|
|
.filter(|t| t.category == category)
|
|
.cloned()
|
|
.collect()
|
|
}
|
|
|
|
/// Search for tools by query (simple text search)
|
|
pub async fn search_tools(&self, query: &str, limit: usize) -> Vec<Tool> {
|
|
let tools = self.tools.read().await;
|
|
let query_lower = query.to_lowercase();
|
|
|
|
let mut results: Vec<_> = tools.values()
|
|
.filter(|t| {
|
|
t.name.to_lowercase().contains(&query_lower) ||
|
|
t.description.to_lowercase().contains(&query_lower)
|
|
})
|
|
.cloned()
|
|
.collect();
|
|
|
|
results.truncate(limit);
|
|
results
|
|
}
|
|
|
|
/// Get a specific tool by name
|
|
pub async fn get_tool(&self, name: &str) -> Option<Tool> {
|
|
let tools = self.tools.read().await;
|
|
tools.get(name).cloned()
|
|
}
|
|
}
|
|
|
|
impl Default for ToolRegistry {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|