Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
# ABOUTME: Ralph Orchestrator package for AI agent orchestration
# ABOUTME: Implements the Ralph Wiggum technique with multi-tool support

"""Ralph Orchestrator - Simple AI agent orchestration."""

# Package version.
# NOTE(review): the ACP adapter's clientInfo handshake hard-codes "1.2.3"
# separately — keep the two in sync when bumping.
__version__ = "1.2.3"

from .orchestrator import RalphOrchestrator
from .metrics import Metrics, CostTracker, IterationStats
from .error_formatter import ClaudeErrorFormatter, ErrorMessage
from .verbose_logger import VerboseLogger
from .output import DiffStats, DiffFormatter, RalphConsole

# Public API re-exported at package level.
__all__ = [
    "RalphOrchestrator",
    "Metrics",
    "CostTracker",
    "IterationStats",
    "ClaudeErrorFormatter",
    "ErrorMessage",
    "VerboseLogger",
    "DiffStats",
    "DiffFormatter",
    "RalphConsole",
]
|
||||
1168
.venv/lib/python3.11/site-packages/ralph_orchestrator/__main__.py
Normal file
1168
.venv/lib/python3.11/site-packages/ralph_orchestrator/__main__.py
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,26 @@
|
||||
# ABOUTME: Tool adapter interfaces and implementations
# ABOUTME: Provides unified interface for Claude, Q Chat, Gemini, ACP, and other tools

"""Tool adapters for Ralph Orchestrator."""

from .base import ToolAdapter, ToolResponse
from .claude import ClaudeAdapter
from .qchat import QChatAdapter
from .kiro import KiroAdapter
from .gemini import GeminiAdapter
from .acp import ACPAdapter
from .acp_handlers import ACPHandlers, PermissionRequest, PermissionResult, Terminal

# Public adapter API re-exported at package level.
__all__ = [
    "ToolAdapter",
    "ToolResponse",
    "ClaudeAdapter",
    "QChatAdapter",
    "KiroAdapter",
    "GeminiAdapter",
    "ACPAdapter",
    "ACPHandlers",
    "PermissionRequest",
    "PermissionResult",
    "Terminal",
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,984 @@
|
||||
# ABOUTME: ACP Adapter for Agent Client Protocol integration
|
||||
# ABOUTME: Provides subprocess-based communication with ACP-compliant agents like Gemini CLI
|
||||
|
||||
"""ACP (Agent Client Protocol) adapter for Ralph Orchestrator.
|
||||
|
||||
This adapter enables Ralph to use any ACP-compliant agent (like Gemini CLI)
|
||||
as a backend for task execution. It manages the subprocess lifecycle,
|
||||
handles the initialization handshake, and routes session messages.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from .acp_client import ACPClient, ACPClientError
|
||||
from .acp_models import ACPAdapterConfig, ACPSession, UpdatePayload
|
||||
from .acp_handlers import ACPHandlers
|
||||
from ..output.console import RalphConsole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ACP Protocol version this adapter supports (integer per spec)
|
||||
ACP_PROTOCOL_VERSION = 1
|
||||
|
||||
|
||||
class ACPAdapter(ToolAdapter):
|
||||
"""Adapter for ACP-compliant agents like Gemini CLI.
|
||||
|
||||
Manages subprocess lifecycle, initialization handshake, and session
|
||||
message routing for Agent Client Protocol communication.
|
||||
|
||||
Attributes:
|
||||
agent_command: Command to spawn the agent (default: gemini).
|
||||
agent_args: Additional arguments for agent command.
|
||||
timeout: Request timeout in seconds.
|
||||
permission_mode: How to handle permission requests.
|
||||
"""
|
||||
|
||||
    # Maps each canonical tool-call field name to the alias keys that
    # different ACP agents use for the same datum.  Aliases are tried in
    # order and the first non-empty value wins (see _merge_tool_fields).
    # The generic "name"/"id" aliases are only honored inside genuine
    # tool_call updates (allow_name_id=True) to avoid false matches.
    _TOOL_FIELD_ALIASES = {
        "toolName": ("toolName", "tool_name", "name", "tool"),
        "toolCallId": ("toolCallId", "tool_call_id", "id"),
        "arguments": ("arguments", "args", "parameters", "params", "input"),
        "status": ("status",),
        "result": ("result",),
        "error": ("error",),
    }
|
||||
|
||||
    def __init__(
        self,
        agent_command: str = "gemini",
        agent_args: Optional[list[str]] = None,
        timeout: int = 300,
        permission_mode: str = "auto_approve",
        permission_allowlist: Optional[list[str]] = None,
        verbose: bool = False,
    ) -> None:
        """Initialize ACPAdapter.

        Args:
            agent_command: Command to spawn the agent (default: gemini).
            agent_args: Additional command-line arguments.
            timeout: Request timeout in seconds (default: 300).
            permission_mode: Permission handling mode (default: auto_approve).
            permission_allowlist: Patterns for allowlist mode.
            verbose: Enable verbose streaming output (default: False).
        """
        self.agent_command = agent_command
        self.agent_args = agent_args or []
        self.timeout = timeout
        self.permission_mode = permission_mode
        self.permission_allowlist = permission_allowlist or []
        self.verbose = verbose
        self._current_verbose = verbose  # Per-request verbose flag

        # Console for verbose output
        self._console = RalphConsole()

        # State: populated lazily by _initialize() on first use.
        self._client: Optional[ACPClient] = None
        self._session_id: Optional[str] = None
        self._initialized = False
        self._session: Optional[ACPSession] = None

        # Create permission handlers
        self._handlers = ACPHandlers(
            permission_mode=permission_mode,
            permission_allowlist=self.permission_allowlist,
            on_permission_log=self._log_permission,
        )

        # Thread synchronization (guards _shutdown_requested in _signal_handler)
        self._lock = threading.Lock()
        self._shutdown_requested = False

        # Previous signal handlers, saved so they can be restored/chained.
        self._original_sigint = None
        self._original_sigterm = None

        # Call parent init - this will call check_availability()
        super().__init__("acp")

        # Register signal handlers (must come after state above is set up,
        # since the handler reads _lock and the _original_* slots).
        self._register_signal_handlers()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ACPAdapterConfig) -> "ACPAdapter":
|
||||
"""Create ACPAdapter from configuration object.
|
||||
|
||||
Args:
|
||||
config: ACPAdapterConfig with adapter settings.
|
||||
|
||||
Returns:
|
||||
Configured ACPAdapter instance.
|
||||
"""
|
||||
return cls(
|
||||
agent_command=config.agent_command,
|
||||
agent_args=config.agent_args,
|
||||
timeout=config.timeout,
|
||||
permission_mode=config.permission_mode,
|
||||
permission_allowlist=config.permission_allowlist,
|
||||
)
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if the agent command is available.
|
||||
|
||||
Returns:
|
||||
True if agent command exists in PATH, False otherwise.
|
||||
"""
|
||||
return shutil.which(self.agent_command) is not None
|
||||
|
||||
def _register_signal_handlers(self) -> None:
|
||||
"""Register signal handlers for graceful shutdown."""
|
||||
try:
|
||||
self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
|
||||
self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
except ValueError as e:
|
||||
logger.warning("Cannot register signal handlers (not in main thread): %s. Graceful shutdown via Ctrl+C will not work.", e)
|
||||
|
||||
def _restore_signal_handlers(self) -> None:
|
||||
"""Restore original signal handlers."""
|
||||
try:
|
||||
if self._original_sigint is not None:
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if self._original_sigterm is not None:
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Failed to restore signal handlers: %s", e)
|
||||
|
||||
    def _signal_handler(self, signum: int, frame) -> None:
        """Handle shutdown signals.

        Terminates running subprocess synchronously (signal-safe),
        then propagates to original handler (orchestrator).

        Args:
            signum: Signal number.
            frame: Current stack frame.
        """
        # Record the shutdown request under the lock so other threads
        # observe it consistently.
        with self._lock:
            self._shutdown_requested = True

        # Kill subprocess synchronously (signal-safe)
        self.kill_subprocess_sync()

        # Propagate signal to original handler (orchestrator's handler).
        # The saved handler may be a non-callable sentinel such as
        # signal.SIG_DFL/SIG_IGN, hence the callable() check.
        original = self._original_sigint if signum == signal.SIGINT else self._original_sigterm
        if original and callable(original):
            original(signum, frame)
|
||||
|
||||
    def kill_subprocess_sync(self) -> None:
        """Synchronously kill the agent subprocess (signal-safe).

        This method is safe to call from signal handlers.
        Uses non-blocking approach with immediate force kill after 2 seconds.
        """
        # NOTE(review): reaches into the client's private _process handle and
        # mixes `returncode` (asyncio-style attribute) with `poll()`
        # (subprocess.Popen-style method) — confirm the actual type of
        # ACPClient._process supports both.
        if self._client and self._client._process:
            try:
                process = self._client._process
                if process.returncode is None:
                    # Try graceful termination first
                    process.terminate()

                    # Non-blocking poll with timeout
                    import time
                    start = time.time()
                    timeout = 2.0

                    while time.time() - start < timeout:
                        if process.poll() is not None:
                            # Process terminated successfully
                            return
                        time.sleep(0.01)  # Brief sleep to avoid busy-wait

                    # Timeout reached, force kill
                    try:
                        process.kill()
                        # Brief wait to ensure kill completes
                        time.sleep(0.1)
                        process.poll()
                    except Exception as e:
                        logger.debug("Exception during subprocess kill: %s", e)
            except Exception as e:
                # Best-effort: swallow everything so the signal handler
                # itself can never raise.
                logger.debug("Exception during subprocess kill: %s", e)
|
||||
|
||||
    async def _initialize(self) -> None:
        """Initialize ACP connection with agent.

        Performs the ACP initialization handshake:
        1. Start ACPClient subprocess
        2. Send initialize request with protocol version
        3. Receive and validate initialize response
        4. Send session/new request
        5. Store session_id

        Raises:
            ACPClientError: If initialization fails.
        """
        # Idempotent: subsequent calls are no-ops once the session exists.
        if self._initialized:
            return

        # Build effective args, auto-adding ACP flags for known agents
        effective_args = list(self.agent_args)

        # Gemini CLI requires --experimental-acp flag to enter ACP mode
        # Also add --yolo to auto-approve internal tool executions
        # And --allowed-tools to enable native Gemini tools
        agent_basename = os.path.basename(self.agent_command)
        if agent_basename == "gemini":
            if "--experimental-acp" not in effective_args:
                logger.info("Auto-adding --experimental-acp flag for Gemini CLI")
                effective_args.append("--experimental-acp")
            if "--yolo" not in effective_args:
                logger.info("Auto-adding --yolo flag for Gemini CLI tool execution")
                effective_args.append("--yolo")
            # Enable native Gemini tools for ACP mode
            # Note: Excluding write_file and run_shell_command - they have bugs in ACP mode
            # Gemini should fall back to ACP's fs/write_text_file and terminal/create
            if "--allowed-tools" not in effective_args:
                logger.info("Auto-adding --allowed-tools for Gemini CLI native tools")
                effective_args.extend([
                    "--allowed-tools",
                    "list_directory",
                    "read_many_files",
                    "read_file",
                    "web_fetch",
                    "google_web_search",
                ])

        # Create and start client
        self._client = ACPClient(
            command=self.agent_command,
            args=effective_args,
            timeout=self.timeout,
        )

        await self._client.start()

        # Register notification handler for session updates
        self._client.on_notification(self._handle_notification)

        # Register request handler for permission requests
        self._client.on_request(self._handle_request)

        try:
            # Send initialize request (per ACP spec)
            init_future = self._client.send_request(
                "initialize",
                {
                    "protocolVersion": ACP_PROTOCOL_VERSION,
                    "clientCapabilities": {
                        "fs": {
                            "readTextFile": True,
                            "writeTextFile": True,
                        },
                        "terminal": True,
                    },
                    "clientInfo": {
                        "name": "ralph-orchestrator",
                        "title": "Ralph Orchestrator",
                        # NOTE(review): duplicates the package __version__
                        # ("1.2.3") — keep in sync when bumping.
                        "version": "1.2.3",
                    },
                },
            )
            init_response = await asyncio.wait_for(init_future, timeout=self.timeout)

            # Validate response
            if "protocolVersion" not in init_response:
                raise ACPClientError("Invalid initialize response: missing protocolVersion")

            # Create new session (cwd and mcpServers are required per ACP spec)
            session_future = self._client.send_request(
                "session/new",
                {
                    "cwd": os.getcwd(),
                    "mcpServers": [],  # No MCP servers by default
                },
            )
            session_response = await asyncio.wait_for(session_future, timeout=self.timeout)

            # Store session ID
            self._session_id = session_response.get("sessionId")
            if not self._session_id:
                raise ACPClientError("Invalid session/new response: missing sessionId")

            # Create session state tracker
            self._session = ACPSession(session_id=self._session_id)

            self._initialized = True

        except asyncio.TimeoutError:
            # Tear down the half-started subprocess before surfacing failure.
            await self._client.stop()
            raise ACPClientError("Initialization timed out")
        except Exception:
            await self._client.stop()
            raise
|
||||
|
||||
    def _handle_notification(self, method: str, params: dict) -> None:
        """Handle notifications from agent.

        Args:
            method: Notification method name.
            params: Notification parameters.
        """
        # Only session/update notifications are processed, and only once a
        # session tracker exists.
        if method == "session/update" and self._session:
            # Handle both notification formats:
            # Format 1 (flat): {"kind": "agent_message_chunk", "content": "..."}
            # Format 2 (nested): {"update": {"sessionUpdate": "agent_message_chunk", "content": {...}}}
            if "update" in params:
                # Nested format (Gemini)
                update = params["update"]
                kind = update.get("sessionUpdate", "")
                content_obj = update.get("content")
                content = None
                # flat_params is the normalized view handed to UpdatePayload.
                flat_params = {"kind": kind, "content": content}
                # Generic "name"/"id" aliases only count for tool updates.
                allow_name_id = kind in ("tool_call", "tool_call_update")
                # Extract text content if it's an object
                if isinstance(content_obj, dict):
                    if "text" in content_obj:
                        content = content_obj.get("text", "")
                        flat_params["content"] = content
                elif isinstance(content_obj, list):
                    # Content can be a list of entries, each possibly carrying
                    # tool fields directly or under toolCall/tool_call/tool.
                    for entry in content_obj:
                        if not isinstance(entry, dict):
                            continue
                        self._merge_tool_fields(flat_params, entry, allow_name_id=allow_name_id)
                        nested_tool_call = entry.get("toolCall") or entry.get("tool_call")
                        if isinstance(nested_tool_call, dict):
                            self._merge_tool_fields(
                                flat_params,
                                nested_tool_call,
                                allow_name_id=allow_name_id,
                            )
                        nested_tool = entry.get("tool")
                        if isinstance(nested_tool, dict):
                            self._merge_tool_fields(
                                flat_params,
                                nested_tool,
                                allow_name_id=allow_name_id,
                            )
                else:
                    # Scalar content: stringify non-empty values.
                    content = str(content_obj) if content_obj else ""
                    flat_params["content"] = content
                # Merge tool fields from the update itself, then from the
                # content dict and its nested tool dicts (earlier non-empty
                # values win — _merge_tool_fields never overwrites).
                self._merge_tool_fields(flat_params, update, allow_name_id=allow_name_id)
                if isinstance(content_obj, dict):
                    self._merge_tool_fields(flat_params, content_obj, allow_name_id=allow_name_id)
                    nested_tool_call = content_obj.get("toolCall") or content_obj.get("tool_call")
                    if isinstance(nested_tool_call, dict):
                        self._merge_tool_fields(
                            flat_params,
                            nested_tool_call,
                            allow_name_id=allow_name_id,
                        )
                    nested_tool = content_obj.get("tool")
                    if isinstance(nested_tool, dict):
                        self._merge_tool_fields(
                            flat_params,
                            nested_tool,
                            allow_name_id=allow_name_id,
                        )
                nested_tool_call = update.get("toolCall") or update.get("tool_call")
                if isinstance(nested_tool_call, dict):
                    self._merge_tool_fields(
                        flat_params,
                        nested_tool_call,
                        allow_name_id=allow_name_id,
                    )
                nested_tool = update.get("tool")
                if isinstance(nested_tool, dict):
                    self._merge_tool_fields(
                        flat_params,
                        nested_tool,
                        allow_name_id=allow_name_id,
                    )
                payload = UpdatePayload.from_dict(flat_params)
                # Stash both raw views for downstream field extraction.
                payload._raw = update
                payload._raw_flat = flat_params
            else:
                # Flat format
                payload = UpdatePayload.from_dict(params)
                payload._raw = params

            # Stream to console if verbose; always show tool calls
            if self._current_verbose:
                self._stream_update(payload, show_details=True)
            elif payload.kind == "tool_call":
                self._stream_update(payload, show_details=False)

            self._session.process_update(payload)
|
||||
|
||||
def _merge_tool_fields(
|
||||
self,
|
||||
target: dict,
|
||||
source: dict,
|
||||
*,
|
||||
allow_name_id: bool = False,
|
||||
) -> None:
|
||||
"""Merge tool call fields from source into target with alias support."""
|
||||
for canonical, aliases in self._TOOL_FIELD_ALIASES.items():
|
||||
if not allow_name_id and canonical in ("toolName", "toolCallId"):
|
||||
aliases = tuple(
|
||||
alias for alias in aliases if alias not in ("name", "id")
|
||||
)
|
||||
if canonical in target and target[canonical] not in (None, ""):
|
||||
continue
|
||||
for key in aliases:
|
||||
if key in source:
|
||||
value = source[key]
|
||||
if key == "tool" and canonical == "toolName":
|
||||
if isinstance(value, dict):
|
||||
value = value.get("name") or value.get("toolName") or value.get("tool_name")
|
||||
elif not isinstance(value, str):
|
||||
value = None
|
||||
if value is None or value == "":
|
||||
continue
|
||||
target[canonical] = value
|
||||
break
|
||||
|
||||
def _format_agent_label(self) -> str:
|
||||
"""Return the ACP agent command with arguments for display."""
|
||||
if not self.agent_args:
|
||||
return self.agent_command
|
||||
return " ".join([self.agent_command, *self.agent_args])
|
||||
|
||||
def _format_payload_value(self, value: object, limit: int = 200) -> str:
|
||||
"""Format payload values for console output."""
|
||||
if value is None:
|
||||
return ""
|
||||
value_str = str(value)
|
||||
if len(value_str) > limit:
|
||||
return value_str[: limit - 3] + "..."
|
||||
return value_str
|
||||
|
||||
def _format_payload_error(self, error: object) -> str:
|
||||
"""Extract a readable error string from ACP error payloads."""
|
||||
if error is None:
|
||||
return ""
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message") or error.get("error") or error.get("detail")
|
||||
code = error.get("code")
|
||||
data = error.get("data")
|
||||
parts = []
|
||||
if message:
|
||||
parts.append(message)
|
||||
if code is not None:
|
||||
parts.append(f"code={code}")
|
||||
if data and not message:
|
||||
parts.append(str(data))
|
||||
if parts:
|
||||
return self._format_payload_value(" ".join(parts), limit=200)
|
||||
return self._format_payload_value(error, limit=200)
|
||||
|
||||
def _get_raw_payload(self, payload: UpdatePayload) -> dict | None:
|
||||
raw = getattr(payload, "_raw", None)
|
||||
return raw if isinstance(raw, dict) else None
|
||||
|
||||
def _extract_tool_field(self, raw: dict | None, key: str) -> object:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
value = raw.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
for nested_key in ("toolCall", "tool_call", "tool"):
|
||||
nested = raw.get(nested_key)
|
||||
if isinstance(nested, dict) and key in nested:
|
||||
nested_value = nested.get(key)
|
||||
if nested_value not in (None, ""):
|
||||
return nested_value
|
||||
return None
|
||||
|
||||
def _extract_tool_name_from_meta(self, raw: dict | None) -> str | None:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
meta = raw.get("_meta")
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
for key in ("codex", "claudeCode", "agent", "acp"):
|
||||
entry = meta.get(key)
|
||||
if isinstance(entry, dict):
|
||||
for name_key in ("toolName", "tool_name", "name", "tool"):
|
||||
value = entry.get(name_key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
for name_key in ("toolName", "tool_name", "name", "tool"):
|
||||
value = meta.get(name_key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
def _extract_tool_response(self, raw: dict | None) -> object:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
meta = raw.get("_meta")
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
for key in ("codex", "claudeCode", "agent", "acp"):
|
||||
entry = meta.get(key)
|
||||
if isinstance(entry, dict) and "toolResponse" in entry:
|
||||
return entry.get("toolResponse")
|
||||
if "toolResponse" in meta:
|
||||
return meta.get("toolResponse")
|
||||
return None
|
||||
|
||||
    def _stream_update(self, payload: UpdatePayload, show_details: bool = True) -> None:
        """Stream session update to console.

        Args:
            payload: The update payload to stream.
            show_details: Include detailed info (arguments, results, progress).
        """
        kind = payload.kind

        if kind == "agent_message_chunk":
            # Stream agent output text
            if payload.content:
                self._console.print_message(payload.content)

        elif kind == "agent_thought_chunk":
            # Stream agent internal reasoning (dimmed)
            if payload.content:
                if self._console.console:
                    self._console.console.print(
                        f"[dim italic]{payload.content}[/dim italic]",
                        end="",
                    )
                else:
                    print(payload.content, end="")

        elif kind == "tool_call":
            # Show tool call start
            tool_name = payload.tool_name
            raw_update = self._get_raw_payload(payload)
            meta_tool_name = self._extract_tool_name_from_meta(raw_update)
            title = self._extract_tool_field(raw_update, "title")
            # Note: rebinds `kind` to the tool's kind field for this branch.
            kind = self._extract_tool_field(raw_update, "kind")
            raw_input = self._extract_tool_field(raw_update, "rawInput")
            if raw_input is None:
                raw_input = self._extract_tool_field(raw_update, "input")
            # First non-empty of: payload name, meta name, title, kind.
            tool_name = tool_name or meta_tool_name or title or kind or "unknown"
            tool_id = payload.tool_call_id or "unknown"
            self._console.print_separator()
            self._console.print_status(f"TOOL CALL: {tool_name}", style="cyan bold")
            self._console.print_info(f"ID: {tool_id[:12]}...")
            self._console.print_info(f"Agent: {self._format_agent_label()}")
            if show_details:
                if title and title != tool_name:
                    self._console.print_info(f"Title: {title}")
                if kind:
                    self._console.print_info(f"Kind: {kind}")
                if tool_name == "unknown":
                    # Dump the raw update so unidentified calls stay debuggable.
                    raw_str = self._format_payload_value(raw_update, limit=300)
                    if raw_str:
                        self._console.print_info(f"Update: {raw_str}")
                if payload.arguments or raw_input:
                    input_value = payload.arguments or raw_input
                    if isinstance(input_value, dict):
                        self._console.print_info("Arguments:")
                        for key, value in input_value.items():
                            value_str = str(value)
                            if len(value_str) > 100:
                                value_str = value_str[:97] + "..."
                            self._console.print_info(f"  - {key}: {value_str}")
                    else:
                        input_str = self._format_payload_value(input_value, limit=300)
                        if input_str:
                            self._console.print_info(f"Input: {input_str}")

        elif kind == "tool_call_update":
            if not show_details:
                return
            # Show tool call status update
            tool_id = payload.tool_call_id or "unknown"
            status = payload.status or "unknown"
            tool_name = payload.tool_name
            tool_args = None
            tool_call = None
            # Recover name/arguments from the session's tool-call tracker.
            if self._session and payload.tool_call_id:
                tool_call = self._session.get_tool_call(payload.tool_call_id)
                if tool_call:
                    tool_name = tool_name or tool_call.tool_name
                    tool_args = tool_call.arguments or None
            raw_update = self._get_raw_payload(payload)
            meta_tool_name = self._extract_tool_name_from_meta(raw_update)
            title = self._extract_tool_field(raw_update, "title")
            # Note: rebinds `kind` to the tool's kind field for this branch.
            kind = self._extract_tool_field(raw_update, "kind")
            raw_input = self._extract_tool_field(raw_update, "rawInput")
            raw_output = self._extract_tool_field(raw_update, "rawOutput")
            tool_name = tool_name or meta_tool_name or title
            display_name = tool_name or kind or "unknown"
            if display_name == "unknown":
                status_label = f"Tool call {tool_id[:12]}..."
            else:
                status_label = f"Tool {display_name} ({tool_id[:12]}...)"

            if status == "completed":
                self._console.print_success(
                    f"{status_label} completed"
                )
                # Result: payload, then tracker, then rawOutput, then _meta.
                result_value = payload.result
                if result_value is None and tool_call:
                    result_value = tool_call.result
                if result_value is None:
                    result_value = raw_output
                if result_value is None:
                    result_value = self._extract_tool_response(raw_update)
                result_str = self._format_payload_value(result_value)
                if result_str:
                    self._console.print_info(f"Result: {result_str}")
            elif status == "failed":
                self._console.print_error(
                    f"{status_label} failed"
                )
                if display_name == "unknown":
                    raw_str = self._format_payload_value(raw_update, limit=300)
                    if raw_str:
                        self._console.print_info(f"Update: {raw_str}")
                # Error text: explicit error, then result, then rawOutput,
                # then the _meta toolResponse.
                error_str = self._format_payload_error(payload.error)
                if not error_str and payload.result is not None:
                    error_str = self._format_payload_value(payload.result)
                if not error_str and raw_output is not None:
                    error_str = self._format_payload_value(raw_output)
                if not error_str:
                    error_str = self._format_payload_value(
                        self._extract_tool_response(raw_update)
                    )
                if error_str:
                    self._console.print_error(f"Error: {error_str}")
                if tool_args or raw_input:
                    if tool_args is None:
                        tool_args = raw_input
                    self._console.print_info("Arguments:")
                    if isinstance(tool_args, dict):
                        for key, value in tool_args.items():
                            value_str = str(value)
                            if len(value_str) > 100:
                                value_str = value_str[:97] + "..."
                            self._console.print_info(f"  - {key}: {value_str}")
                    else:
                        arg_str = self._format_payload_value(tool_args, limit=300)
                        if arg_str:
                            self._console.print_info(f"  - {arg_str}")
            elif status == "running":
                self._console.print_status(
                    f"{status_label} running",
                    style="yellow",
                )
                progress_value = payload.result or payload.content
                if progress_value is None:
                    progress_value = raw_output
                progress_str = self._format_payload_value(progress_value, limit=200)
                if progress_str:
                    self._console.print_info(f"Progress: {progress_str}")
            else:
                self._console.print_status(
                    f"{status_label} {status}",
                    style="yellow",
                )
            if title and title != display_name:
                self._console.print_info(f"Title: {title}")
            if kind:
                self._console.print_info(f"Kind: {kind}")
|
||||
|
||||
def _handle_request(self, method: str, params: dict) -> dict:
|
||||
"""Handle requests from agent.
|
||||
|
||||
Routes requests to appropriate handlers:
|
||||
- session/request_permission: Permission checks
|
||||
- fs/read_text_file: File read operations
|
||||
- fs/write_text_file: File write operations
|
||||
- terminal/*: Terminal operations
|
||||
|
||||
Args:
|
||||
method: Request method name.
|
||||
params: Request parameters.
|
||||
|
||||
Returns:
|
||||
Response result dict.
|
||||
"""
|
||||
logger.info("ACP REQUEST: method=%s", method)
|
||||
if method == "session/request_permission":
|
||||
# Permission handler already returns ACP-compliant format
|
||||
return self._handle_permission_request(params)
|
||||
|
||||
# File operations - return raw result (client wraps in JSON-RPC)
|
||||
if method == "fs/read_text_file":
|
||||
return self._handlers.handle_read_file(params)
|
||||
if method == "fs/write_text_file":
|
||||
return self._handlers.handle_write_file(params)
|
||||
|
||||
# Terminal operations - return raw result (client wraps in JSON-RPC)
|
||||
if method == "terminal/create":
|
||||
return self._handlers.handle_terminal_create(params)
|
||||
if method == "terminal/output":
|
||||
return self._handlers.handle_terminal_output(params)
|
||||
if method == "terminal/wait_for_exit":
|
||||
return self._handlers.handle_terminal_wait_for_exit(params)
|
||||
if method == "terminal/kill":
|
||||
return self._handlers.handle_terminal_kill(params)
|
||||
if method == "terminal/release":
|
||||
return self._handlers.handle_terminal_release(params)
|
||||
|
||||
# Unknown request - log and return error
|
||||
logger.warning("Unknown ACP request method: %s with params: %s", method, params)
|
||||
return {"error": {"code": -32601, "message": f"Method not found: {method}"}}
|
||||
|
||||
    def _handle_permission_request(self, params: dict) -> dict:
        """Handle permission request from agent.

        Delegates to ACPHandlers which supports multiple modes:
        - auto_approve: Always approve
        - deny_all: Always deny
        - allowlist: Check against configured patterns
        - interactive: Prompt user (if terminal available)

        Args:
            params: Permission request parameters.

        Returns:
            ACP-compliant permission response dict produced by the handler
            (the caller forwards it to the agent unchanged).
        """
        return self._handlers.handle_request_permission(params)
|
||||
|
||||
    def _log_permission(self, message: str) -> None:
        """Log permission decision.

        Passed to ACPHandlers as the on_permission_log callback.

        Args:
            message: Permission decision message.
        """
        logger.info(message)
|
||||
|
||||
    def get_permission_history(self) -> list:
        """Get permission decision history.

        Returns:
            List of (request, result) tuples, as recorded by ACPHandlers.
        """
        return self._handlers.get_history()
|
||||
|
||||
def get_permission_stats(self) -> dict:
|
||||
"""Get permission decision statistics.
|
||||
|
||||
Returns:
|
||||
Dict with approved_count and denied_count.
|
||||
"""
|
||||
return {
|
||||
"approved_count": self._handlers.get_approved_count(),
|
||||
"denied_count": self._handlers.get_denied_count(),
|
||||
}
|
||||
|
||||
async def _execute_prompt(self, prompt: str, **kwargs) -> ToolResponse:
    """Execute a prompt through the ACP agent.

    Sends session/prompt request with messages array and waits for response.
    Session updates (streaming output, thoughts, tool calls) are processed
    through _handle_notification during the request, so the accumulated
    output lives on self._session, not on the response object.

    Args:
        prompt: The prompt to execute.
        **kwargs: Additional arguments (verbose: bool).

    Returns:
        ToolResponse with execution result. On timeout, any partial output
        accumulated in the session so far is still included.
    """
    # Get verbose from kwargs (per-call override) without mutating instance state
    verbose = kwargs.get("verbose", self.verbose)
    # Store for use in _handle_notification during this request
    self._current_verbose = verbose

    # Reset session state for new prompt (preserve session_id)
    if self._session:
        self._session.reset()

    # Print header if verbose
    if verbose:
        self._console.print_header(f"ACP AGENT ({self.agent_command})")
        self._console.print_status("Processing prompt...")

    # Build prompt array per ACP spec (ContentBlock format)
    prompt_blocks = [{"type": "text", "text": prompt}]

    # Send session/prompt request
    try:
        prompt_future = self._client.send_request(
            "session/prompt",
            {
                "sessionId": self._session_id,
                "prompt": prompt_blocks,
            },
        )

        # Wait for response with timeout; streaming updates arrive via
        # notifications while this await is pending.
        response = await asyncio.wait_for(prompt_future, timeout=self.timeout)

        # Check for error stop reason
        stop_reason = response.get("stopReason", "unknown")
        if stop_reason == "error":
            error_obj = response.get("error", {})
            error_msg = error_obj.get("message", "Unknown error from agent")
            if verbose:
                self._console.print_separator()
                self._console.print_error(f"Agent error: {error_msg}")
            # Partial session output is preserved even on agent error.
            return ToolResponse(
                success=False,
                output=self._session.output if self._session else "",
                error=error_msg,
                metadata={
                    "tool": "acp",
                    "agent": self.agent_command,
                    "session_id": self._session_id,
                    "stop_reason": stop_reason,
                },
            )

        # Build successful response
        output = self._session.output if self._session else ""
        if verbose:
            self._console.print_separator()
            tool_count = len(self._session.tool_calls) if self._session else 0
            self._console.print_success(f"Agent completed (tools: {tool_count})")
        return ToolResponse(
            success=True,
            output=output,
            metadata={
                "tool": "acp",
                "agent": self.agent_command,
                "session_id": self._session_id,
                "stop_reason": stop_reason,
                "tool_calls_count": len(self._session.tool_calls) if self._session else 0,
                "has_thoughts": bool(self._session.thoughts) if self._session else False,
            },
        )

    except asyncio.TimeoutError:
        # NOTE(review): the agent-side request is not cancelled here —
        # presumably the caller tears down the session; confirm.
        if verbose:
            self._console.print_separator()
            self._console.print_error(f"Timeout after {self.timeout}s")
        return ToolResponse(
            success=False,
            output=self._session.output if self._session else "",
            error=f"Prompt execution timed out after {self.timeout} seconds",
            metadata={
                "tool": "acp",
                "agent": self.agent_command,
                "session_id": self._session_id,
            },
        )
|
||||
|
||||
async def _shutdown(self) -> None:
    """Shutdown the ACP connection.

    Kills any terminals the agent left running, stops the subprocess
    client, and resets adapter state so the next execute() re-initializes.
    """
    # Kill all running terminals first
    # NOTE(review): reaches into ACPHandlers' private _terminals dict;
    # a public "kill all terminals" API on ACPHandlers would be cleaner.
    if self._handlers:
        for terminal_id in list(self._handlers._terminals.keys()):
            try:
                self._handlers.handle_terminal_kill({"terminalId": terminal_id})
            except Exception as e:
                # Best-effort: one stuck terminal must not block shutdown.
                logger.warning("Failed to kill terminal %s: %s", terminal_id, e)

    if self._client:
        await self._client.stop()
        self._client = None

    # Clear session state so a subsequent call performs a fresh initialize.
    self._initialized = False
    self._session_id = None
    self._session = None
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
    """Execute the prompt synchronously.

    Thin wrapper that drives aexecute() on a dedicated event loop via
    asyncio.run(); any exception is converted into a failed ToolResponse.

    Args:
        prompt: The prompt to execute.
        **kwargs: Additional arguments.

    Returns:
        ToolResponse with execution result.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error=f"ACP adapter not available: {self.agent_command} not found",
        )

    try:
        # asyncio.run creates and tears down its own event loop.
        return asyncio.run(self.aexecute(prompt, **kwargs))
    except Exception as e:
        return ToolResponse(success=False, output="", error=str(e))
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
    """Execute the prompt asynchronously.

    Lazily initializes the ACP connection on first use, enhances the
    prompt with orchestration instructions, then runs it through the
    agent. Never raises: all failures are returned as ToolResponse.

    Args:
        prompt: The prompt to execute.
        **kwargs: Additional arguments.

    Returns:
        ToolResponse with execution result.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error=f"ACP adapter not available: {self.agent_command} not found",
        )

    try:
        # Initialize if needed (spawns the agent subprocess, opens a session)
        if not self._initialized:
            await self._initialize()

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Execute prompt
        return await self._execute_prompt(enhanced_prompt, **kwargs)

    except ACPClientError as e:
        # Transport/protocol failures from the ACP client layer.
        return ToolResponse(
            success=False,
            output="",
            error=f"ACP error: {e}",
        )
    except Exception as e:
        # Catch-all boundary: the adapter contract is to return, not raise.
        return ToolResponse(
            success=False,
            output="",
            error=str(e),
        )
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
    """Estimate execution cost.

    ACP doesn't provide billing information, so returns 0.

    Args:
        prompt: The prompt to estimate.

    Returns:
        Always 0.0 (no billing info from ACP).
    """
    del prompt  # unused: ACP exposes no billing data
    return 0.0
|
||||
|
||||
def __del__(self) -> None:
    """Cleanup on deletion.

    May run during interpreter shutdown, so everything here is defensive:
    restore signal handlers first, then synchronously kill the agent
    subprocess if a client is still attached.
    """
    self._restore_signal_handlers()

    # Best-effort cleanup
    if self._client:
        try:
            self.kill_subprocess_sync()
        except Exception as e:
            # Swallow deliberately: raising in __del__ only produces noise.
            logger.debug("Exception during cleanup in __del__: %s", e)
|
||||
@@ -0,0 +1,352 @@
|
||||
# ABOUTME: ACPClient manages subprocess lifecycle for ACP agents
|
||||
# ABOUTME: Handles async message routing, request tracking, and graceful shutdown
|
||||
|
||||
"""ACPClient subprocess manager for ACP (Agent Client Protocol)."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from .acp_protocol import ACPProtocol, MessageType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default asyncio StreamReader line limit (in bytes) for ACP agent stdout/stderr.
|
||||
# Some ACP agents can emit large single-line JSON-RPC frames (e.g., big tool payloads).
|
||||
# The asyncio default (~64KiB) can raise LimitOverrunError and break the ACP session.
|
||||
DEFAULT_ACP_STREAM_LIMIT = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
class ACPClientError(Exception):
    """Raised when an ACPClient operation fails (transport or protocol)."""
|
||||
|
||||
|
||||
class ACPClient:
    """Manages subprocess lifecycle and async message routing for ACP agents.

    Spawns an agent subprocess, handles JSON-RPC message serialization,
    routes responses to pending requests, and invokes callbacks for
    notifications and incoming requests.

    Attributes:
        command: The command to spawn the agent.
        args: Additional command-line arguments.
        timeout: Request timeout in seconds.
        stream_limit: asyncio stream line limit in bytes for agent stdout.
    """

    def __init__(
        self,
        command: str,
        args: Optional[list[str]] = None,
        timeout: int = 300,
        stream_limit: Optional[int] = None,
    ) -> None:
        """Initialize ACPClient.

        Args:
            command: The command to spawn the agent (e.g., "gemini").
            args: Additional command-line arguments.
            timeout: Request timeout in seconds (default: 300).
            stream_limit: Explicit stream line limit override; if omitted,
                the RALPH_ACP_STREAM_LIMIT env var is consulted, then the
                module default.
        """
        self.command = command
        self.args = args or []
        self.timeout = timeout
        # Limit used by asyncio for readline()/readuntil(); must be > max ACP frame line length.
        env_limit = os.environ.get("RALPH_ACP_STREAM_LIMIT")
        if stream_limit is not None:
            self.stream_limit = stream_limit
        elif env_limit:
            try:
                self.stream_limit = int(env_limit)
            except ValueError:
                logger.warning(
                    "Invalid RALPH_ACP_STREAM_LIMIT=%r; using default %d",
                    env_limit,
                    DEFAULT_ACP_STREAM_LIMIT,
                )
                self.stream_limit = DEFAULT_ACP_STREAM_LIMIT
        else:
            self.stream_limit = DEFAULT_ACP_STREAM_LIMIT

        self._protocol = ACPProtocol()
        self._process: Optional[asyncio.subprocess.Process] = None
        self._read_task: Optional[asyncio.Task] = None
        self._write_lock = asyncio.Lock()

        # Pending requests: id -> Future
        self._pending_requests: dict[int, asyncio.Future] = {}

        # Strong references to in-flight send tasks (see send_request).
        # The event loop keeps only weak references to tasks, so without
        # this set a scheduled _do_send could be garbage-collected before
        # it runs, silently dropping the outgoing request.
        self._send_tasks: set[asyncio.Task] = set()

        # Notification handlers
        self._notification_handlers: list[Callable[[str, dict], None]] = []

        # Request handlers (for incoming requests from agent)
        self._request_handlers: list[Callable[[str, dict], Any]] = []

    @property
    def is_running(self) -> bool:
        """Check if subprocess is running.

        Returns:
            True if subprocess is running, False otherwise.
        """
        return self._process is not None and self._process.returncode is None

    async def start(self) -> None:
        """Start the agent subprocess.

        Spawns the subprocess with stdin/stdout/stderr pipes and starts
        the read loop task.

        Raises:
            RuntimeError: If already running.
            FileNotFoundError: If command not found.
        """
        if self.is_running:
            raise RuntimeError("ACPClient is already running")

        cmd = [self.command] + self.args

        self._process = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            limit=self.stream_limit,
        )

        # Start the read loop
        self._read_task = asyncio.create_task(self._read_loop())

    async def stop(self) -> None:
        """Stop the agent subprocess.

        Terminates the subprocess gracefully with 2 second timeout, then
        kills if necessary. Cancels the read loop task and all pending
        requests.
        """
        if not self.is_running:
            return

        # Cancel read loop first so it cannot race the teardown below.
        if self._read_task and not self._read_task.done():
            self._read_task.cancel()
            try:
                await asyncio.wait_for(self._read_task, timeout=0.5)
            except asyncio.CancelledError:
                logger.debug("Read task cancelled during shutdown")
            except asyncio.TimeoutError:
                logger.warning("Read task cancellation timed out")

        # Terminate subprocess with 2 second timeout
        if self._process:
            try:
                self._process.terminate()
                try:
                    await asyncio.wait_for(self._process.wait(), timeout=2.0)
                except asyncio.TimeoutError:
                    # Force kill if graceful termination fails
                    logger.warning("Process did not terminate gracefully, killing")
                    self._process.kill()
                    try:
                        await asyncio.wait_for(self._process.wait(), timeout=0.5)
                    except asyncio.TimeoutError:
                        logger.error("Process did not die after kill signal")
            except ProcessLookupError:
                logger.debug("Process already terminated")

        self._process = None
        self._read_task = None

        # Cancel all pending requests
        for future in self._pending_requests.values():
            if not future.done():
                future.cancel()
        self._pending_requests.clear()

    async def _read_loop(self) -> None:
        """Continuously read stdout and route messages.

        Reads newline-delimited JSON-RPC messages from subprocess stdout.
        """
        if not self._process or not self._process.stdout:
            return

        try:
            while self.is_running:
                line = await self._process.stdout.readline()
                if not line:
                    # EOF: agent closed stdout (usually process exit).
                    break

                message_str = line.decode().strip()
                if message_str:
                    await self._handle_message(message_str)
        except asyncio.CancelledError:
            pass  # Expected during shutdown
        except Exception as e:
            logger.error("ACP read loop failed: %s", e, exc_info=True)
        finally:
            # Cancel all pending requests when read loop exits (subprocess died or cancelled)
            for future in self._pending_requests.values():
                if not future.done():
                    future.set_exception(ACPClientError("Agent subprocess terminated"))
            self._pending_requests.clear()

    async def _handle_message(self, message_str: str) -> None:
        """Handle a received JSON-RPC message.

        Routes message to appropriate handler based on type: responses and
        errors resolve pending request futures; notifications fan out to
        all notification handlers; incoming requests are answered by the
        first registered request handler.

        Args:
            message_str: Raw JSON string.
        """
        parsed = self._protocol.parse_message(message_str)
        msg_type = parsed.get("type")

        if msg_type == MessageType.RESPONSE:
            # Route to pending request
            request_id = parsed.get("id")
            if request_id in self._pending_requests:
                future = self._pending_requests.pop(request_id)
                if not future.done():
                    future.set_result(parsed.get("result"))

        elif msg_type == MessageType.ERROR:
            # Route error to pending request
            request_id = parsed.get("id")
            if request_id in self._pending_requests:
                future = self._pending_requests.pop(request_id)
                if not future.done():
                    error = parsed.get("error", {})
                    error_msg = error.get("message", "Unknown error")
                    future.set_exception(ACPClientError(error_msg))

        elif msg_type == MessageType.NOTIFICATION:
            # Invoke notification handlers; one failing handler must not
            # prevent the others from seeing the notification.
            method = parsed.get("method", "")
            params = parsed.get("params", {})
            for handler in self._notification_handlers:
                try:
                    handler(method, params)
                except Exception as e:
                    logger.error(
                        "Notification handler failed for method=%s: %s",
                        method, e, exc_info=True,
                    )

        elif msg_type == MessageType.REQUEST:
            # Invoke request handlers and send response
            request_id = parsed.get("id")
            method = parsed.get("method", "")
            params = parsed.get("params", {})

            for handler in self._request_handlers:
                try:
                    if asyncio.iscoroutinefunction(handler):
                        result = await handler(method, params)
                    else:
                        result = handler(method, params)

                    # Check if handler returned an error (dict with "error" key)
                    if isinstance(result, dict) and "error" in result:
                        error_info = result["error"]
                        error_code = error_info.get("code", -32603) if isinstance(error_info, dict) else -32603
                        error_msg = error_info.get("message", str(error_info)) if isinstance(error_info, dict) else str(error_info)
                        response = self._protocol.create_error_response(request_id, error_code, error_msg)
                    else:
                        response = self._protocol.create_response(request_id, result)
                    await self._write_message(response)
                    break  # Only first handler responds
                except Exception as e:
                    # Send error response
                    error_response = self._protocol.create_error_response(
                        request_id, -32603, str(e)
                    )
                    await self._write_message(error_response)
                    break

    async def _write_message(self, message: str) -> None:
        """Write a JSON-RPC message to subprocess stdin.

        Serialized by a lock so concurrent writers cannot interleave frames.

        Args:
            message: JSON string to write.

        Raises:
            RuntimeError: If not running.
        """
        if not self.is_running or not self._process or not self._process.stdin:
            raise RuntimeError("ACPClient is not running")

        async with self._write_lock:
            self._process.stdin.write((message + "\n").encode())
            await self._process.stdin.drain()

    async def _do_send(self, request_id: int, message: str) -> None:
        """Helper to send message and handle write errors.

        On write failure, the pending future for this request is failed
        instead of letting the exception escape the background task.

        Args:
            request_id: The request ID.
            message: The JSON-RPC message to send.
        """
        try:
            await self._write_message(message)
        except Exception as e:
            future = self._pending_requests.pop(request_id, None)
            if future and not future.done():
                future.set_exception(ACPClientError(f"Failed to send request: {e}"))

    def send_request(
        self, method: str, params: dict[str, Any]
    ) -> asyncio.Future[Any]:
        """Send a JSON-RPC request and return Future for response.

        Args:
            method: The RPC method name.
            params: The request parameters.

        Returns:
            Future that resolves with the response result.
        """
        request_id, message = self._protocol.create_request(method, params)

        # Create future for response
        loop = asyncio.get_running_loop()
        future: asyncio.Future[Any] = loop.create_future()
        self._pending_requests[request_id] = future

        # Schedule write with error handling. Keep a strong reference to
        # the task: asyncio holds tasks only weakly, so an unreferenced
        # fire-and-forget task may be garbage-collected before running.
        task = asyncio.create_task(self._do_send(request_id, message))
        self._send_tasks.add(task)
        task.add_done_callback(self._send_tasks.discard)

        return future

    async def send_notification(
        self, method: str, params: dict[str, Any]
    ) -> None:
        """Send a JSON-RPC notification (no response expected).

        Args:
            method: The notification method name.
            params: The notification parameters.
        """
        message = self._protocol.create_notification(method, params)
        await self._write_message(message)

    def on_notification(
        self, handler: Callable[[str, dict], None]
    ) -> None:
        """Register a notification handler.

        Args:
            handler: Callback invoked with (method, params) for each notification.
        """
        self._notification_handlers.append(handler)

    def on_request(
        self, handler: Callable[[str, dict], Any]
    ) -> None:
        """Register a request handler for incoming requests from agent.

        Handler should return the response result. Can be sync or async.
        Only the first registered handler's result is sent back.

        Args:
            handler: Callback invoked with (method, params), returns result.
        """
        self._request_handlers.append(handler)
|
||||
@@ -0,0 +1,889 @@
|
||||
# ABOUTME: ACP handlers for permission requests and file/terminal operations
|
||||
# ABOUTME: Provides permission_mode handling (auto_approve, deny_all, allowlist, interactive)
|
||||
# ABOUTME: Implements fs/read_text_file and fs/write_text_file handlers with security
|
||||
# ABOUTME: Implements terminal/* handlers for command execution
|
||||
|
||||
"""ACP Handlers for permission requests and agent-to-host operations.
|
||||
|
||||
This module provides the ACPHandlers class which manages permission requests
|
||||
from ACP-compliant agents and handles file operations. It supports:
|
||||
|
||||
Permission modes:
|
||||
- auto_approve: Approve all requests automatically
|
||||
- deny_all: Deny all requests
|
||||
- allowlist: Only approve requests matching configured patterns
|
||||
- interactive: Prompt user for each request (requires terminal)
|
||||
|
||||
File operations:
|
||||
- fs/read_text_file: Read file content with security validation
|
||||
- fs/write_text_file: Write file content with security validation
|
||||
|
||||
Terminal operations:
|
||||
- terminal/create: Create a new terminal with command
|
||||
- terminal/output: Read output from a terminal
|
||||
- terminal/wait_for_exit: Wait for terminal process to exit
|
||||
- terminal/kill: Kill a terminal process
|
||||
- terminal/release: Release terminal resources
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Terminal:
    """Represents a terminal subprocess.

    Attributes:
        id: Unique identifier for the terminal.
        process: The subprocess.Popen instance.
        output_buffer: Accumulated output from stdout/stderr.
    """

    id: str
    process: subprocess.Popen
    output_buffer: str = ""

    @property
    def is_running(self) -> bool:
        """Check if the process is still running."""
        # poll() returns None while the child has not exited.
        return self.process.poll() is None

    @property
    def exit_code(self) -> Optional[int]:
        """Get the exit code if process has exited."""
        return self.process.poll()

    def read_output(self) -> str:
        """Read any available output without blocking.

        Drains stdout then stderr with zero-timeout select() polls.
        NOTE(review): select() on pipe objects is POSIX-only, and it
        inspects the underlying fd, not the buffered reader — data already
        buffered by Python may not be reported as "ready"; confirm on the
        target platforms.
        NOTE(review): chunks are concatenated onto a str buffer, so this
        assumes the Popen streams are text-mode — verify terminal creation
        passes text=True/encoding.

        Returns:
            New output since last read.
        """
        import select

        new_output = ""

        # Try to read from stdout and stderr
        for stream in [self.process.stdout, self.process.stderr]:
            if stream is None:
                continue

            try:
                # Non-blocking read using select
                while True:
                    ready, _, _ = select.select([stream], [], [], 0)
                    if not ready:
                        break
                    chunk = stream.read(4096)
                    if chunk:
                        new_output += chunk
                    else:
                        # Empty read means EOF on this stream.
                        break
            except (OSError, IOError) as e:
                logger.debug("Error reading terminal output: %s", e)
                break

        self.output_buffer += new_output
        return new_output

    def kill(self) -> None:
        """Kill the subprocess.

        Tries SIGTERM first with a 1 second grace period, then escalates
        to SIGKILL and reaps the process.
        """
        if self.is_running:
            self.process.terminate()
            try:
                self.process.wait(timeout=1.0)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()

    def wait(self, timeout: Optional[float] = None) -> int:
        """Wait for the process to exit.

        Args:
            timeout: Maximum time to wait in seconds.

        Returns:
            Exit code of the process.

        Raises:
            subprocess.TimeoutExpired: If timeout is reached.
        """
        return self.process.wait(timeout=timeout)
|
||||
|
||||
|
||||
@dataclass
class PermissionRequest:
    """A single permission request received from an agent.

    Attributes:
        operation: The operation being requested (e.g., 'fs/read_text_file').
        path: Optional path for filesystem operations.
        command: Optional command for terminal operations.
        arguments: Full arguments dict from the request.
    """

    operation: str
    path: Optional[str] = None
    command: Optional[str] = None
    arguments: dict = field(default_factory=dict)

    @classmethod
    def from_params(cls, params: dict) -> "PermissionRequest":
        """Build a PermissionRequest from raw request parameters.

        The full params dict is retained in `arguments` so callers can
        inspect fields beyond the well-known ones.

        Args:
            params: Permission request parameters from agent.

        Returns:
            Parsed PermissionRequest instance.
        """
        operation = params.get("operation", "")
        path = params.get("path")
        command = params.get("command")
        return cls(operation=operation, path=path, command=command, arguments=params)
|
||||
|
||||
|
||||
@dataclass
class PermissionResult:
    """Outcome of a permission decision.

    Attributes:
        approved: Whether the request was approved.
        reason: Optional reason for the decision.
        mode: Permission mode that made the decision.
    """

    approved: bool
    reason: Optional[str] = None
    mode: str = "unknown"

    def to_dict(self) -> dict:
        """Serialize for the ACP response format.

        Returns:
            Dict with only the 'approved' key; reason/mode are internal.
        """
        return {"approved": self.approved}
|
||||
|
||||
|
||||
class ACPHandlers:
|
||||
"""Handles ACP permission requests with configurable modes.
|
||||
|
||||
Supports four permission modes:
|
||||
- auto_approve: Always approve (useful for trusted environments)
|
||||
- deny_all: Always deny (useful for testing)
|
||||
- allowlist: Only approve operations matching configured patterns
|
||||
- interactive: Prompt user for each request
|
||||
|
||||
Attributes:
|
||||
permission_mode: Current permission mode.
|
||||
allowlist: List of allowed operation patterns.
|
||||
on_permission_log: Optional callback for logging decisions.
|
||||
"""
|
||||
|
||||
# Valid permission modes
|
||||
VALID_MODES = ("auto_approve", "deny_all", "allowlist", "interactive")
|
||||
|
||||
def __init__(
    self,
    permission_mode: str = "auto_approve",
    permission_allowlist: Optional[list[str]] = None,
    on_permission_log: Optional[Callable[[str], None]] = None,
) -> None:
    """Initialize ACPHandlers.

    Args:
        permission_mode: Permission handling mode (default: auto_approve).
        permission_allowlist: Allowed operation patterns for allowlist mode.
        on_permission_log: Optional callback for logging permission decisions.

    Raises:
        ValueError: If permission_mode is not one of VALID_MODES.
    """
    if permission_mode not in self.VALID_MODES:
        raise ValueError(
            f"Invalid permission_mode: {permission_mode}. "
            f"Must be one of: {', '.join(self.VALID_MODES)}"
        )

    self.permission_mode = permission_mode
    self.allowlist = permission_allowlist or []
    self.on_permission_log = on_permission_log

    # Permission decision history, kept for debugging and statistics.
    self._history: list = []

    # Active terminals keyed by terminal id.
    self._terminals: dict = {}
|
||||
|
||||
def handle_request_permission(self, params: dict) -> dict:
    """Handle a permission request from an agent.

    Returns ACP-compliant response with nested outcome structure; the
    agent performs the tool execution itself after receiving permission.

    Args:
        params: Permission request parameters including options list.

    Returns:
        Dict with outcome.outcome (selected/cancelled) and optionId.
    """
    request = PermissionRequest.from_params(params)
    result = self._evaluate_permission(request)

    # Log and remember every decision for later inspection.
    self._log_decision(request, result)
    self._history.append((request, result))

    if not result.approved:
        # Denied: report a cancelled outcome to the agent.
        return {"outcome": {"outcome": "cancelled"}}

    options = params.get("options", [])
    # Prefer the first explicit "allow" option offered by the agent.
    selected = next(
        (opt.get("id") for opt in options if opt.get("type") == "allow"),
        None,
    )
    if not selected:
        # Fall back to the first offered option, then a sensible default.
        selected = options[0].get("id", "proceed_once") if options else "proceed_once"

    # Raw result (client wraps in JSON-RPC response)
    return {
        "outcome": {
            "outcome": "selected",
            "optionId": selected,
        }
    }
|
||||
|
||||
def _evaluate_permission(self, request: "PermissionRequest") -> "PermissionResult":
    """Evaluate a permission request based on current mode.

    Args:
        request: The permission request to evaluate.

    Returns:
        PermissionResult with decision and reason.
    """
    mode = self.permission_mode

    # Unconditional modes need no request inspection.
    if mode == "auto_approve":
        return PermissionResult(approved=True, reason="auto_approve mode", mode="auto_approve")
    if mode == "deny_all":
        return PermissionResult(approved=False, reason="deny_all mode", mode="deny_all")

    # Request-dependent modes delegate to their evaluators.
    if mode == "allowlist":
        return self._evaluate_allowlist(request)
    if mode == "interactive":
        return self._evaluate_interactive(request)

    # Defensive fallback; __init__ validation should make this unreachable.
    return PermissionResult(approved=False, reason="unknown mode", mode="unknown")
|
||||
|
||||
def _evaluate_allowlist(self, request: "PermissionRequest") -> "PermissionResult":
    """Evaluate permission against allowlist patterns.

    Patterns can be:
    - Exact match: 'fs/read_text_file'
    - Glob pattern: 'fs/*' (matches any fs operation)
    - Regex pattern: '/^terminal\\/.*$/' (surrounded by slashes)

    Args:
        request: The permission request to evaluate.

    Returns:
        PermissionResult with decision.
    """
    operation = request.operation

    # First matching pattern wins; None means nothing matched.
    matched = next(
        (p for p in self.allowlist if self._matches_pattern(operation, p)),
        None,
    )
    if matched is not None:
        return PermissionResult(
            approved=True,
            reason=f"matches allowlist pattern: {matched}",
            mode="allowlist",
        )
    return PermissionResult(
        approved=False,
        reason="no matching allowlist pattern",
        mode="allowlist",
    )
|
||||
|
||||
def _matches_pattern(self, operation: str, pattern: str) -> bool:
|
||||
"""Check if an operation matches a pattern.
|
||||
|
||||
Args:
|
||||
operation: The operation name to check.
|
||||
pattern: Pattern to match against.
|
||||
|
||||
Returns:
|
||||
True if operation matches pattern.
|
||||
"""
|
||||
# Check for regex pattern (surrounded by slashes)
|
||||
if pattern.startswith("/") and pattern.endswith("/"):
|
||||
try:
|
||||
regex_pattern = pattern[1:-1]
|
||||
return bool(re.match(regex_pattern, operation))
|
||||
except re.error as e:
|
||||
logger.warning("Invalid regex pattern '%s' in permission allowlist: %s", pattern, e)
|
||||
return False
|
||||
|
||||
# Check for glob pattern
|
||||
if "*" in pattern or "?" in pattern:
|
||||
return fnmatch.fnmatch(operation, pattern)
|
||||
|
||||
# Exact match
|
||||
return operation == pattern
|
||||
|
||||
def _evaluate_interactive(self, request: "PermissionRequest") -> "PermissionResult":
    """Evaluate permission interactively by prompting the user.

    Falls back to denial when no terminal is attached, so headless runs
    fail closed instead of hanging on input().

    Args:
        request: The permission request to evaluate.

    Returns:
        PermissionResult with the user's decision.
    """
    if not sys.stdin.isatty():
        return PermissionResult(
            approved=False,
            reason="no terminal available for interactive mode",
            mode="interactive",
        )

    try:
        # Prompt goes to stderr so it never pollutes piped stdout.
        print(self._format_interactive_prompt(request), file=sys.stderr)
        answer = input("[y/N]: ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        return PermissionResult(
            approved=False,
            reason="input interrupted",
            mode="interactive",
        )

    if answer in ("y", "yes"):
        return PermissionResult(
            approved=True,
            reason="user approved",
            mode="interactive",
        )
    return PermissionResult(
        approved=False,
        reason="user denied",
        mode="interactive",
    )
|
||||
|
||||
def _format_interactive_prompt(self, request: PermissionRequest) -> str:
|
||||
"""Format an interactive permission prompt.
|
||||
|
||||
Args:
|
||||
request: The permission request to display.
|
||||
|
||||
Returns:
|
||||
Formatted prompt string.
|
||||
"""
|
||||
lines = [
|
||||
"",
|
||||
"=" * 60,
|
||||
f"Permission Request: {request.operation}",
|
||||
"=" * 60,
|
||||
]
|
||||
|
||||
if request.path:
|
||||
lines.append(f" Path: {request.path}")
|
||||
if request.command:
|
||||
lines.append(f" Command: {request.command}")
|
||||
|
||||
# Add other arguments
|
||||
for key, value in request.arguments.items():
|
||||
if key not in ("operation", "path", "command"):
|
||||
lines.append(f" {key}: {value}")
|
||||
|
||||
lines.extend([
|
||||
"=" * 60,
|
||||
"Approve this operation?",
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _log_decision(
|
||||
self, request: PermissionRequest, result: PermissionResult
|
||||
) -> None:
|
||||
"""Log a permission decision.
|
||||
|
||||
Args:
|
||||
request: The permission request.
|
||||
result: The permission decision.
|
||||
"""
|
||||
if self.on_permission_log:
|
||||
status = "APPROVED" if result.approved else "DENIED"
|
||||
message = (
|
||||
f"Permission {status}: {request.operation} "
|
||||
f"[mode={result.mode}, reason={result.reason}]"
|
||||
)
|
||||
self.on_permission_log(message)
|
||||
|
||||
def get_history(self) -> list[tuple[PermissionRequest, PermissionResult]]:
    """Return a shallow copy of the (request, result) decision log.

    Returns:
        List of (request, result) tuples; mutating it does not affect
        the adapter's internal history.
    """
    return list(self._history)
|
||||
|
||||
def clear_history(self) -> None:
    """Forget every recorded permission decision."""
    del self._history[:]
|
||||
|
||||
def get_approved_count(self) -> int:
    """Count history entries whose result was approved.

    Returns:
        Number of approved permission requests.
    """
    approved = [result for _, result in self._history if result.approved]
    return len(approved)
|
||||
|
||||
def get_denied_count(self) -> int:
    """Count history entries whose result was denied.

    Returns:
        Number of denied permission requests.
    """
    total = 0
    for _, result in self._history:
        if not result.approved:
            total += 1
    return total
|
||||
|
||||
# =========================================================================
|
||||
# File Operation Handlers
|
||||
# =========================================================================
|
||||
|
||||
def handle_read_file(self, params: dict) -> dict:
    """Handle fs/read_text_file request from agent.

    Reads UTF-8 file content. Only absolute paths are accepted, and
    symlinks are resolved before any filesystem access.

    Args:
        params: Request parameters with 'path' key.

    Returns:
        Dict with 'content' on success, or 'error' on failure.
    """
    def _err(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    path_str = params.get("path")
    if not path_str:
        return _err(-32602, "Missing required parameter: path")

    try:
        path = Path(path_str)

        # Security: relative paths are rejected outright.
        if not path.is_absolute():
            return _err(-32602, f"Path must be absolute: {path_str}")

        # Normalize and follow symlinks before touching the filesystem.
        target = path.resolve()

        # A missing file is not an error: report non-existence so agents
        # can probe for a file without triggering a failure.
        if not target.exists():
            return {"content": None, "exists": False}

        if not target.is_file():
            return _err(-32002, f"Path is not a file: {path_str}")

        return {"content": target.read_text(encoding="utf-8")}

    except PermissionError:
        return _err(-32003, f"Permission denied: {path_str}")
    except UnicodeDecodeError:
        return _err(-32004, f"File is not valid UTF-8 text: {path_str}")
    except OSError as e:
        return _err(-32000, f"Failed to read file: {e}")
|
||||
|
||||
def handle_write_file(self, params: dict) -> dict:
    """Handle fs/write_text_file request from agent.

    Writes UTF-8 content to an absolute path, creating missing parent
    directories as needed.

    Args:
        params: Request parameters with 'path' and 'content' keys.

    Returns:
        Dict with 'success: True' on success, or 'error' on failure.
    """
    def _err(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    path_str = params.get("path")
    content = params.get("content")

    if not path_str:
        return _err(-32602, "Missing required parameter: path")
    # None means missing; an empty string is a legitimate write.
    if content is None:
        return _err(-32602, "Missing required parameter: content")

    try:
        path = Path(path_str)

        # Security: relative paths are rejected outright.
        if not path.is_absolute():
            return _err(-32602, f"Path must be absolute: {path_str}")

        target = path.resolve()

        # Cannot overwrite a directory with a file.
        if target.exists() and target.is_dir():
            return _err(-32002, f"Path is a directory: {path_str}")

        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding="utf-8")
        return {"success": True}

    except PermissionError:
        return _err(-32003, f"Permission denied: {path_str}")
    except OSError as e:
        return _err(-32000, f"Failed to write file: {e}")
|
||||
|
||||
# =========================================================================
|
||||
# Terminal Operation Handlers
|
||||
# =========================================================================
|
||||
|
||||
def handle_terminal_create(self, params: dict) -> dict:
    """Handle terminal/create request from agent.

    Creates a new terminal subprocess for command execution and registers
    it under a fresh UUID so later output/wait/kill/release requests can
    reference it.

    Args:
        params: Request parameters with 'command' (list of strings) and
            optional 'cwd' (working directory).

    Returns:
        Dict with 'terminalId' on success, or 'error' on failure.
    """
    command = params.get("command")

    # Validate the command argument: present, a list, and non-empty.
    if command is None:
        return {
            "error": {
                "code": -32602,
                "message": "Missing required parameter: command",
            }
        }

    if not isinstance(command, list):
        return {
            "error": {
                "code": -32602,
                "message": "command must be a list of strings",
            }
        }

    if len(command) == 0:
        return {
            "error": {
                "code": -32602,
                "message": "command list cannot be empty",
            }
        }

    cwd = params.get("cwd")

    try:
        # Create subprocess with pipes for stdout/stderr. The command is an
        # argv list (no shell involved), stdin is closed so the child cannot
        # block waiting for input, and bufsize=0 keeps output unbuffered for
        # incremental reads.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            stdin=subprocess.DEVNULL,
            cwd=cwd,
            text=True,
            bufsize=0,
        )

        # Generate unique terminal ID
        terminal_id = str(uuid.uuid4())

        # Create terminal instance and register it for later lookups.
        terminal = Terminal(id=terminal_id, process=process)
        self._terminals[terminal_id] = terminal

        return {"terminalId": terminal_id}

    except FileNotFoundError:
        # Executable missing from PATH / given location.
        return {
            "error": {
                "code": -32001,
                "message": f"Command not found: {command[0]}",
            }
        }
    except PermissionError:
        return {
            "error": {
                "code": -32003,
                "message": f"Permission denied executing: {command[0]}",
            }
        }
    except OSError as e:
        return {
            "error": {
                "code": -32000,
                "message": f"Failed to create terminal: {e}",
            }
        }
|
||||
|
||||
def handle_terminal_output(self, params: dict) -> dict:
    """Handle terminal/output request from agent.

    Drains any newly available output into the terminal's buffer and
    reports it together with whether the process has finished.

    Args:
        params: Request parameters with 'terminalId'.

    Returns:
        Dict with 'output' and 'done' on success, or 'error' on failure.
    """
    terminal_id = params.get("terminalId")
    if not terminal_id:
        return {
            "error": {
                "code": -32602,
                "message": "Missing required parameter: terminalId",
            }
        }

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return {
            "error": {
                "code": -32001,
                "message": f"Terminal not found: {terminal_id}",
            }
        }

    # Pull pending output into the buffer before reporting it.
    terminal.read_output()
    return {"output": terminal.output_buffer, "done": not terminal.is_running}
|
||||
|
||||
def handle_terminal_wait_for_exit(self, params: dict) -> dict:
    """Handle terminal/wait_for_exit request from agent.

    Blocks until the terminal process exits (or the optional timeout
    elapses), then reports its exit code.

    Args:
        params: Request parameters with 'terminalId' and optional 'timeout'.

    Returns:
        Dict with 'exitCode' on success, or 'error' on failure/timeout.
    """
    terminal_id = params.get("terminalId")
    if not terminal_id:
        return {
            "error": {
                "code": -32602,
                "message": "Missing required parameter: terminalId",
            }
        }

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return {
            "error": {
                "code": -32001,
                "message": f"Terminal not found: {terminal_id}",
            }
        }

    timeout = params.get("timeout")
    try:
        exit_code = terminal.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        return {
            "error": {
                "code": -32000,
                "message": f"Wait timed out after {timeout}s",
            }
        }

    # Capture anything the process wrote just before exiting.
    terminal.read_output()
    return {"exitCode": exit_code}
|
||||
|
||||
def handle_terminal_kill(self, params: dict) -> dict:
    """Handle terminal/kill request from agent.

    Kills the process behind a terminal; the terminal entry itself stays
    registered until released.

    Args:
        params: Request parameters with 'terminalId'.

    Returns:
        Dict with 'success: True' on success, or 'error' on failure.
    """
    terminal_id = params.get("terminalId")
    if not terminal_id:
        return {
            "error": {
                "code": -32602,
                "message": "Missing required parameter: terminalId",
            }
        }

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return {
            "error": {
                "code": -32001,
                "message": f"Terminal not found: {terminal_id}",
            }
        }

    terminal.kill()
    return {"success": True}
|
||||
|
||||
def handle_terminal_release(self, params: dict) -> dict:
    """Handle terminal/release request from agent.

    Releases the terminal's registry slot, killing the process first if
    it is still running.

    Args:
        params: Request parameters with 'terminalId'.

    Returns:
        Dict with 'success: True' on success, or 'error' on failure.
    """
    terminal_id = params.get("terminalId")
    if not terminal_id:
        return {
            "error": {
                "code": -32602,
                "message": "Missing required parameter: terminalId",
            }
        }

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return {
            "error": {
                "code": -32001,
                "message": f"Terminal not found: {terminal_id}",
            }
        }

    if terminal.is_running:
        # Best effort: a failed kill should not prevent releasing the slot.
        try:
            terminal.kill()
        except Exception as e:
            logger.warning("Failed to kill terminal %s during release: %s", terminal_id, e)

    self._terminals.pop(terminal_id)
    return {"success": True}
|
||||
@@ -0,0 +1,536 @@
|
||||
# ABOUTME: Typed dataclasses for ACP (Agent Client Protocol) messages
|
||||
# ABOUTME: Defines request/response types, session state, and configuration
|
||||
|
||||
"""Typed data models for ACP messages and session state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ralph_orchestrator.main import AdapterConfig
|
||||
|
||||
# Get logger for this module
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateKind(str, Enum):
    """Types of session updates.

    The str mixin lets members compare equal to raw wire strings, so
    payload `kind` fields can be matched against these members directly.
    """

    AGENT_MESSAGE_CHUNK = "agent_message_chunk"  # incremental agent output text
    AGENT_THOUGHT_CHUNK = "agent_thought_chunk"  # internal reasoning (verbose mode)
    TOOL_CALL = "tool_call"  # agent requesting a tool execution
    TOOL_CALL_UPDATE = "tool_call_update"  # status update for an earlier tool call
    PLAN = "plan"  # agent's execution plan
|
||||
|
||||
|
||||
class ToolCallStatus(str, Enum):
    """Status of a tool call.

    The str mixin lets members compare equal to raw wire strings.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class PermissionMode(str, Enum):
    """Permission modes for ACP operations.

    The str mixin lets members compare equal to raw configuration strings.
    """

    AUTO_APPROVE = "auto_approve"  # approve every request
    DENY_ALL = "deny_all"  # deny every request
    ALLOWLIST = "allowlist"  # check requests against a configured allowlist
    INTERACTIVE = "interactive"  # prompt the user for each request
|
||||
|
||||
|
||||
@dataclass
class ACPRequest:
    """JSON-RPC 2.0 request message.

    Attributes:
        id: Request identifier for matching response.
        method: The RPC method to invoke.
        params: Method parameters.
    """

    id: int
    method: str
    params: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPRequest":
        """Build an ACPRequest from a decoded JSON-RPC dict.

        Args:
            data: Dict with id, method, and optional params.

        Returns:
            ACPRequest instance.

        Raises:
            KeyError: If id or method is missing.
        """
        return cls(data["id"], data["method"], data.get("params", {}))
|
||||
|
||||
|
||||
@dataclass
class ACPNotification:
    """JSON-RPC 2.0 notification (no response expected).

    Attributes:
        method: The notification method.
        params: Method parameters.
    """

    method: str
    params: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPNotification":
        """Build an ACPNotification from a decoded JSON-RPC dict.

        Args:
            data: Dict with method and optional params.

        Returns:
            ACPNotification instance.
        """
        return cls(method=data["method"], params=data.get("params", {}))
|
||||
|
||||
|
||||
@dataclass
class ACPResponse:
    """JSON-RPC 2.0 success response.

    Attributes:
        id: Request identifier this response matches.
        result: The response result data.
    """

    id: int
    result: Any

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPResponse":
        """Build an ACPResponse from a decoded JSON-RPC dict.

        Args:
            data: Dict with id and result.

        Returns:
            ACPResponse instance (result is None when absent).
        """
        return cls(data["id"], data.get("result"))
|
||||
|
||||
|
||||
@dataclass
class ACPErrorObject:
    """JSON-RPC 2.0 error object.

    Attributes:
        code: Error code (negative integers for standard errors).
        message: Human-readable error message.
        data: Optional additional error data.
    """

    code: int
    message: str
    data: Any = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPErrorObject":
        """Build an ACPErrorObject from a decoded JSON-RPC error dict.

        Args:
            data: Dict with code, message, and optional data.

        Returns:
            ACPErrorObject instance.
        """
        return cls(data["code"], data["message"], data.get("data"))
|
||||
|
||||
|
||||
@dataclass
class ACPError:
    """JSON-RPC 2.0 error response.

    Attributes:
        id: Request identifier this error matches.
        error: The error object with code and message.
    """

    id: int
    error: ACPErrorObject

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPError":
        """Build an ACPError from a decoded JSON-RPC error response dict.

        Args:
            data: Dict with id and error object.

        Returns:
            ACPError instance.
        """
        err = ACPErrorObject.from_dict(data["error"])
        return cls(id=data["id"], error=err)
|
||||
|
||||
|
||||
@dataclass
class UpdatePayload:
    """Payload for session/update notifications.

    Handles different update kinds:
    - agent_message_chunk: Text output from agent
    - agent_thought_chunk: Internal reasoning (verbose mode)
    - tool_call: Agent requesting tool execution
    - tool_call_update: Status update for tool call
    - plan: Agent's execution plan

    Attributes:
        kind: The update type (see UpdateKind enum for valid values).
        content: Text content (for message/thought chunks).
        tool_name: Name of tool being called.
        tool_call_id: Unique identifier for tool call.
        arguments: Tool call arguments.
        status: Tool call status (see ToolCallStatus enum for valid values).
        result: Tool call result data.
        error: Tool call error message.
    """

    kind: str  # Valid values: UpdateKind enum members
    content: Optional[str] = None
    tool_name: Optional[str] = None
    tool_call_id: Optional[str] = None
    arguments: Optional[dict[str, Any]] = None
    status: Optional[str] = None
    result: Optional[Any] = None
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "UpdatePayload":
        """Parse UpdatePayload from dict.

        Handles camelCase to snake_case conversion for ACP fields.

        Args:
            data: Dict with kind and kind-specific fields.

        Returns:
            UpdatePayload instance.
        """
        kind = data["kind"]
        # Canonical ACP field names are camelCase.
        tool_name = data.get("toolName")
        tool_call_id = data.get("toolCallId")
        arguments = data.get("arguments")
        # For tool-call kinds, fall back to snake_case and shorter aliases
        # (presumably emitted by some agents instead of the canonical
        # camelCase names — the fallback order below is deliberate).
        if kind in (UpdateKind.TOOL_CALL, UpdateKind.TOOL_CALL_UPDATE):
            if tool_name is None:
                tool_name = data.get("tool_name") or data.get("name")
            # "tool" is accepted only when it is a plain string.
            if tool_name is None and isinstance(data.get("tool"), str):
                tool_name = data.get("tool")
            if tool_call_id is None:
                tool_call_id = data.get("tool_call_id") or data.get("id")
            if kind == UpdateKind.TOOL_CALL and arguments is None:
                arguments = data.get("args") or data.get("parameters") or data.get("params")

        return cls(
            kind=kind,
            content=data.get("content"),
            tool_name=tool_name,
            tool_call_id=tool_call_id,
            arguments=arguments,
            status=data.get("status"),
            result=data.get("result"),
            error=data.get("error"),
        )
|
||||
|
||||
|
||||
@dataclass
class SessionUpdate:
    """Wrapper for session/update notification.

    Attributes:
        method: Should be "session/update".
        payload: The update payload.
    """

    method: str
    payload: UpdatePayload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionUpdate":
        """Build a SessionUpdate from a notification dict.

        Args:
            data: Dict with method and params.

        Returns:
            SessionUpdate instance.
        """
        payload = UpdatePayload.from_dict(data["params"])
        return cls(method=data["method"], payload=payload)
|
||||
|
||||
|
||||
@dataclass
class ToolCall:
    """Tracks a tool execution within a session.

    Attributes:
        tool_call_id: Unique identifier for this call.
        tool_name: Name of the tool being called.
        arguments: Arguments passed to the tool.
        status: Current status (pending/running/completed/failed).
        result: Result data if completed.
        error: Error message if failed.
    """

    tool_call_id: str
    tool_name: str
    arguments: dict[str, Any]
    status: str = "pending"
    result: Optional[Any] = None
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolCall":
        """Build a ToolCall from a decoded update dict.

        Args:
            data: Dict with toolCallId, toolName, arguments.

        Returns:
            ToolCall instance.
        """
        return cls(
            data["toolCallId"],
            data["toolName"],
            data.get("arguments", {}),
            data.get("status", "pending"),
            data.get("result"),
            data.get("error"),
        )
|
||||
|
||||
|
||||
@dataclass
class ACPSession:
    """Accumulates session state during prompt execution.

    Tracks output chunks, thoughts, and tool calls for building
    the final response.

    Attributes:
        session_id: Unique session identifier.
        output: Accumulated agent output text.
        thoughts: Accumulated agent thoughts (verbose).
        tool_calls: List of tool calls in this session.
    """

    session_id: str
    output: str = ""
    thoughts: str = ""
    tool_calls: list[ToolCall] = field(default_factory=list)

    def append_output(self, text: str) -> None:
        """Append a chunk of agent output text."""
        self.output = self.output + text

    def append_thought(self, text: str) -> None:
        """Append a chunk of agent thought text."""
        self.thoughts = self.thoughts + text

    def add_tool_call(self, tool_call: ToolCall) -> None:
        """Start tracking a new tool call."""
        self.tool_calls.append(tool_call)

    def get_tool_call(self, tool_call_id: str) -> Optional[ToolCall]:
        """Return the tracked tool call with the given ID, or None."""
        return next(
            (tc for tc in self.tool_calls if tc.tool_call_id == tool_call_id),
            None,
        )

    def process_update(self, payload: UpdatePayload) -> None:
        """Route one session update payload to the matching accumulator.

        Args:
            payload: The update payload to process.
        """
        kind = payload.kind
        if kind == "agent_message_chunk" and payload.content:
            self.append_output(payload.content)
        elif kind == "agent_thought_chunk" and payload.content:
            self.append_thought(payload.content)
        elif kind == "tool_call":
            # Missing identifiers default to empty strings/dicts.
            self.add_tool_call(
                ToolCall(
                    tool_call_id=payload.tool_call_id or "",
                    tool_name=payload.tool_name or "",
                    arguments=payload.arguments or {},
                )
            )
        elif kind == "tool_call_update" and payload.tool_call_id:
            tc = self.get_tool_call(payload.tool_call_id)
            if tc is not None:
                if payload.status:
                    tc.status = payload.status
                if payload.result is not None:
                    tc.result = payload.result
                if payload.error:
                    tc.error = payload.error

    def reset(self) -> None:
        """Clear accumulated data for a new prompt, keeping session_id."""
        self.output = ""
        self.thoughts = ""
        self.tool_calls = []
|
||||
|
||||
|
||||
@dataclass
class ACPAdapterConfig:
    """Configuration for the ACP adapter.

    Attributes:
        agent_command: Command to spawn the agent (default: gemini).
        agent_args: Additional arguments for agent command.
        timeout: Request timeout in seconds.
        permission_mode: How to handle permission requests.
            - auto_approve: Approve all requests.
            - deny_all: Deny all requests.
            - allowlist: Check against permission_allowlist.
            - interactive: Prompt user for each request.
        permission_allowlist: Patterns to allow in allowlist mode.
    """

    agent_command: str = "gemini"
    agent_args: list[str] = field(default_factory=list)
    timeout: int = 300
    permission_mode: str = "auto_approve"
    permission_allowlist: list[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPAdapterConfig":
        """Build an ACPAdapterConfig from a dict, defaulting missing keys.

        Args:
            data: Configuration dict.

        Returns:
            ACPAdapterConfig instance.
        """
        return cls(
            agent_command=data.get("agent_command", "gemini"),
            agent_args=data.get("agent_args", []),
            timeout=data.get("timeout", 300),
            permission_mode=data.get("permission_mode", "auto_approve"),
            permission_allowlist=data.get("permission_allowlist", []),
        )

    @classmethod
    def from_adapter_config(cls, adapter_config: "AdapterConfig") -> "ACPAdapterConfig":
        """Create ACPAdapterConfig from AdapterConfig with env var overrides.

        Extracts ACP-specific settings from AdapterConfig.tool_permissions
        and applies environment variable overrides on top.

        Environment Variables:
            RALPH_ACP_AGENT: Override agent_command
            RALPH_ACP_PERMISSION_MODE: Override permission_mode
            RALPH_ACP_TIMEOUT: Override timeout (integer)

        Args:
            adapter_config: General adapter configuration.

        Returns:
            ACPAdapterConfig with ACP-specific settings.
        """
        opts = adapter_config.tool_permissions or {}

        # Base values come from tool_permissions; the timeout default is
        # inherited from the general adapter config.
        cfg = cls(
            agent_command=opts.get("agent_command", "gemini"),
            agent_args=opts.get("agent_args", []),
            timeout=opts.get("timeout", adapter_config.timeout),
            permission_mode=opts.get("permission_mode", "auto_approve"),
            permission_allowlist=opts.get("permission_allowlist", []),
        )

        # Environment overrides take precedence over tool_permissions.
        env_agent = os.environ.get("RALPH_ACP_AGENT")
        if env_agent:
            cfg.agent_command = env_agent

        env_mode = os.environ.get("RALPH_ACP_PERMISSION_MODE")
        if env_mode:
            valid_modes = {"auto_approve", "deny_all", "allowlist", "interactive"}
            if env_mode in valid_modes:
                cfg.permission_mode = env_mode
            else:
                # Invalid override: keep the configured mode and warn.
                _logger.warning(
                    "Invalid RALPH_ACP_PERMISSION_MODE value '%s'. Valid modes: %s. Using default: %s",
                    env_mode,
                    ", ".join(valid_modes),
                    cfg.permission_mode,
                )

        env_timeout = os.environ.get("RALPH_ACP_TIMEOUT")
        if env_timeout:
            try:
                cfg.timeout = int(env_timeout)
            except ValueError:
                _logger.warning(
                    "Invalid RALPH_ACP_TIMEOUT value '%s' - must be integer. Using default: %d",
                    env_timeout,
                    cfg.timeout,
                )

        return cfg
|
||||
@@ -0,0 +1,214 @@
|
||||
# ABOUTME: JSON-RPC 2.0 protocol handler for ACP (Agent Client Protocol)
|
||||
# ABOUTME: Handles message serialization, parsing, and protocol state
|
||||
|
||||
"""JSON-RPC 2.0 protocol handling for ACP."""
|
||||
|
||||
import json
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MessageType(Enum):
    """Types of JSON-RPC 2.0 messages, as classified by the protocol parser."""

    REQUEST = auto()  # Has id and method
    NOTIFICATION = auto()  # Has method but no id
    RESPONSE = auto()  # Has id and result
    ERROR = auto()  # Has id and error
    PARSE_ERROR = auto()  # Failed to parse JSON
    INVALID = auto()  # Invalid JSON-RPC message
|
||||
|
||||
|
||||
class ACPErrorCodes:
    """Standard JSON-RPC 2.0 and ACP-specific error codes."""

    # Standard JSON-RPC 2.0 error codes
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603

    # ACP-specific error codes
    # NOTE(review): handlers elsewhere in this package appear to emit
    # -32001 for "terminal/command not found" and -32003 for OS permission
    # errors — confirm these symbolic names match actual usage before
    # relying on them for dispatch.
    PERMISSION_DENIED = -32001
    FILE_NOT_FOUND = -32002
    FILE_ACCESS_ERROR = -32003
    TERMINAL_ERROR = -32004
|
||||
|
||||
|
||||
class ACPProtocol:
|
||||
"""JSON-RPC 2.0 protocol handler for ACP.
|
||||
|
||||
Handles serialization and deserialization of JSON-RPC messages
|
||||
for the Agent Client Protocol.
|
||||
|
||||
Attributes:
|
||||
_request_id: Auto-incrementing request ID counter.
|
||||
"""
|
||||
|
||||
JSONRPC_VERSION = "2.0"
|
||||
|
||||
def __init__(self) -> None:
    """Initialize protocol handler with request ID counter at 0."""
    # IDs are assigned by pre-incrementing, so the first request gets id 1.
    self._request_id: int = 0
|
||||
|
||||
def create_request(self, method: str, params: dict[str, Any]) -> tuple[int, str]:
    """Create a JSON-RPC 2.0 request message.

    Args:
        method: The RPC method name (e.g., "session/prompt").
        params: The request parameters.

    Returns:
        Tuple of (request_id, json_string) for tracking and sending.
    """
    # Monotonically increasing ids let responses be matched to requests.
    self._request_id += 1
    new_id = self._request_id

    wire = json.dumps({
        "jsonrpc": self.JSONRPC_VERSION,
        "id": new_id,
        "method": method,
        "params": params,
    })
    return new_id, wire
|
||||
|
||||
def create_notification(self, method: str, params: dict[str, Any]) -> str:
    """Create a JSON-RPC 2.0 notification message (no id, no response expected).

    Args:
        method: The RPC method name.
        params: The notification parameters.

    Returns:
        JSON string of the notification.
    """
    return json.dumps({
        "jsonrpc": self.JSONRPC_VERSION,
        "method": method,
        "params": params,
    })
|
||||
|
||||
def parse_message(self, data: str) -> dict[str, Any]:
|
||||
"""Parse an incoming JSON-RPC 2.0 message.
|
||||
|
||||
Determines the message type and validates structure.
|
||||
|
||||
Args:
|
||||
data: Raw JSON string to parse.
|
||||
|
||||
Returns:
|
||||
Dict with 'type' key indicating MessageType and parsed fields.
|
||||
On error, includes 'error' key with description.
|
||||
"""
|
||||
# Try to parse JSON
|
||||
try:
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
"type": MessageType.PARSE_ERROR,
|
||||
"error": f"JSON parse error: {e}",
|
||||
}
|
||||
|
||||
# Validate jsonrpc version field
|
||||
if message.get("jsonrpc") != self.JSONRPC_VERSION:
|
||||
return {
|
||||
"type": MessageType.INVALID,
|
||||
"error": f"Invalid or missing jsonrpc field. Expected '2.0', got '{message.get('jsonrpc')}'",
|
||||
}
|
||||
|
||||
# Determine message type based on fields
|
||||
has_id = "id" in message
|
||||
has_method = "method" in message
|
||||
has_result = "result" in message
|
||||
has_error = "error" in message
|
||||
|
||||
if has_error and has_id:
|
||||
# Error response
|
||||
return {
|
||||
"type": MessageType.ERROR,
|
||||
"id": message["id"],
|
||||
"error": message["error"],
|
||||
}
|
||||
elif has_result and has_id:
|
||||
# Success response
|
||||
return {
|
||||
"type": MessageType.RESPONSE,
|
||||
"id": message["id"],
|
||||
"result": message["result"],
|
||||
}
|
||||
elif has_method and has_id:
|
||||
# Request (has id, expects response)
|
||||
return {
|
||||
"type": MessageType.REQUEST,
|
||||
"id": message["id"],
|
||||
"method": message["method"],
|
||||
"params": message.get("params", {}),
|
||||
}
|
||||
elif has_method and not has_id:
|
||||
# Notification (no id, no response expected)
|
||||
return {
|
||||
"type": MessageType.NOTIFICATION,
|
||||
"method": message["method"],
|
||||
"params": message.get("params", {}),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"type": MessageType.INVALID,
|
||||
"error": "Invalid JSON-RPC message structure",
|
||||
}
|
||||
|
||||
def create_response(self, request_id: int, result: Any) -> str:
|
||||
"""Create a JSON-RPC 2.0 success response.
|
||||
|
||||
Args:
|
||||
request_id: The ID from the original request.
|
||||
result: The result data to return.
|
||||
|
||||
Returns:
|
||||
JSON string of the response.
|
||||
"""
|
||||
message = {
|
||||
"jsonrpc": self.JSONRPC_VERSION,
|
||||
"id": request_id,
|
||||
"result": result,
|
||||
}
|
||||
|
||||
return json.dumps(message)
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
request_id: int,
|
||||
code: int,
|
||||
message: str,
|
||||
data: Any = None,
|
||||
) -> str:
|
||||
"""Create a JSON-RPC 2.0 error response.
|
||||
|
||||
Args:
|
||||
request_id: The ID from the original request.
|
||||
code: The error code (use ACPErrorCodes constants).
|
||||
message: Human-readable error message.
|
||||
data: Optional additional error data.
|
||||
|
||||
Returns:
|
||||
JSON string of the error response.
|
||||
"""
|
||||
error_obj: dict[str, Any] = {
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
error_obj["data"] = data
|
||||
|
||||
response = {
|
||||
"jsonrpc": self.JSONRPC_VERSION,
|
||||
"id": request_id,
|
||||
"error": error_obj,
|
||||
}
|
||||
|
||||
return json.dumps(response)
|
||||
@@ -0,0 +1,189 @@
|
||||
# ABOUTME: Abstract base class for tool adapters
|
||||
# ABOUTME: Defines the interface all tool adapters must implement
|
||||
|
||||
"""Base adapter interface for AI tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
|
||||
@dataclass
class ToolResponse:
    """Response from a tool execution.

    Attributes:
        success: True when the tool ran to completion without error.
        output: Captured output/response text from the tool.
        error: Human-readable error description when success is False.
        tokens_used: Total tokens consumed, when the tool reports usage.
        cost: Estimated cost in USD, when computable.
        metadata: Free-form extra data; normalized to an empty dict.
    """

    success: bool
    output: str
    error: Optional[str] = None
    tokens_used: Optional[int] = None
    cost: Optional[float] = None
    # Annotated Optional because the field's default really is None (a
    # mutable {} default is not allowed); __post_init__ normalizes it.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Guarantee callers can always treat metadata as a dict.
        if self.metadata is None:
            self.metadata = {}
|
||||
|
||||
|
||||
class ToolAdapter(ABC):
    """Abstract base class for tool adapters.

    Subclasses integrate one concrete AI tool (Claude, Q Chat, Gemini, ...)
    by implementing check_availability() and execute(). This base class
    supplies async and file-based execution helpers plus prompt enhancement
    with Ralph orchestration instructions.
    """

    def __init__(self, name: str, config=None):
        """Initialize the adapter.

        Args:
            name: Tool identifier (e.g. "claude").
            config: Optional configuration object. When omitted, a minimal
                anonymous stand-in exposing the expected attributes
                (enabled/timeout/max_retries/args/env) is created.
        """
        self.name = name
        self.config = config or type('Config', (), {
            'enabled': True, 'timeout': 300, 'max_retries': 3,
            'args': [], 'env': {}
        })()
        # Optional sentinel string that signals task completion when the
        # tool emits it (see _enhance_prompt_with_instructions).
        self.completion_promise: Optional[str] = None
        # Probe availability once at construction time.
        self.available = self.check_availability()

    @abstractmethod
    def check_availability(self) -> bool:
        """Check if the tool is available and properly configured."""
        pass

    @abstractmethod
    def execute(self, prompt: str, **kwargs) -> ToolResponse:
        """Execute the tool with the given prompt."""
        pass

    async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
        """Async execute the tool with the given prompt.

        Default implementation runs sync execute in a thread pool.
        Subclasses can override for native async support.
        """
        # get_running_loop() is the non-deprecated way to obtain the loop
        # from inside a coroutine (asyncio.get_event_loop() here has been
        # deprecated since Python 3.10 and behaves identically).
        loop = asyncio.get_running_loop()

        # Bind prompt/kwargs into a zero-argument callable for run_in_executor
        def execute_with_args():
            return self.execute(prompt, **kwargs)

        return await loop.run_in_executor(None, execute_with_args)

    def execute_with_file(self, prompt_file: Path, **kwargs) -> ToolResponse:
        """Execute the tool with a prompt file.

        Returns a failed ToolResponse (rather than raising) when the file
        does not exist.
        """
        if not prompt_file.exists():
            return ToolResponse(
                success=False,
                output="",
                error=f"Prompt file {prompt_file} not found"
            )

        # Read as UTF-8 for consistency with aexecute_with_file (the old
        # bare open() used the platform-default encoding).
        prompt = prompt_file.read_text(encoding='utf-8')

        return self.execute(prompt, **kwargs)

    async def aexecute_with_file(self, prompt_file: Path, **kwargs) -> ToolResponse:
        """Async execute the tool with a prompt file."""
        if not prompt_file.exists():
            return ToolResponse(
                success=False,
                output="",
                error=f"Prompt file {prompt_file} not found"
            )

        # Use asyncio.to_thread to avoid blocking the event loop with file I/O
        prompt = await asyncio.to_thread(prompt_file.read_text, encoding='utf-8')

        return await self.aexecute(prompt, **kwargs)

    def estimate_cost(self, prompt: str) -> float:
        """Estimate the cost of executing this prompt.

        Default implementation returns 0.0 — subclasses can override.
        """
        return 0.0

    def _enhance_prompt_with_instructions(self, prompt: str, completion_promise: Optional[str] = None) -> str:
        """Enhance prompt with orchestration context and instructions.

        Args:
            prompt: The original prompt
            completion_promise: Optional string that signals task completion.
                If not provided, uses self.completion_promise.

        Returns:
            Enhanced prompt with orchestration instructions
        """
        # Resolve completion promise: argument takes precedence, then instance state
        promise = completion_promise or self.completion_promise

        # Markers whose presence means the instructions were already injected
        # (e.g. by a previous iteration writing back to the prompt file).
        instruction_markers = [
            "ORCHESTRATION CONTEXT:",
            "IMPORTANT INSTRUCTIONS:",
            "Implement only ONE small, focused task"
        ]

        # If any marker exists, assume instructions are already present
        instructions_present = any(marker in prompt for marker in instruction_markers)

        enhanced_prompt = prompt

        if not instructions_present:
            # Add orchestration context and instructions
            orchestration_instructions = """
ORCHESTRATION CONTEXT:
You are running within the Ralph Orchestrator loop. This system will call you repeatedly
for multiple iterations until the overall task is complete. Each iteration is a separate
execution where you should make incremental progress.

The final output must be well-tested, documented, and production ready.

IMPORTANT INSTRUCTIONS:
1. Implement only ONE small, focused task from this prompt per iteration.
   - Each iteration is independent - focus on a single atomic change
   - The orchestrator will handle calling you again for the next task
   - Mark subtasks complete as you finish them
   - You must commit your changes after each iteration, for checkpointing.
2. Use the .agent/workspace/ directory for any temporary files or workspaces if not already instructed in the prompt.
3. Follow this workflow for implementing features:
   - Explore: Research and understand the codebase
   - Plan: Design your implementation approach
   - Implement: Use Test-Driven Development (TDD) - write tests first, then code
   - Commit: Commit your changes with clear messages
4. When you complete a subtask, document it in the prompt file so the next iteration knows what's done.
5. For maximum efficiency, whenever you need to perform multiple independent operations, invoke all relevant tools simultaneously rather than sequentially.
6. If you create any temporary new files, scripts, or helper files for iteration, clean up these files by removing them at the end of the task.

## Agent Scratchpad
Before starting your work, check if .agent/scratchpad.md exists in the current working directory.
If it does, read it to understand what was accomplished in previous iterations and continue from there.

At the end of your iteration, update .agent/scratchpad.md with:
- What you accomplished this iteration
- What remains to be done
- Any important context or decisions made
- Current blockers or issues (if any)

Do NOT restart from scratch if the scratchpad shows previous progress. Continue where the previous iteration left off.

Create the .agent/ directory if it doesn't exist.

---
ORIGINAL PROMPT:

"""
            enhanced_prompt = orchestration_instructions + prompt

        # Inject completion promise if requested and not present
        if promise:
            # Check if promise is already in the text (simple check)
            # We check both the raw promise and the section header to be safe,
            # but mainly we want to avoid duplicating the specific instruction.
            if promise not in enhanced_prompt:
                promise_section = f"""

## Completion Promise
When you have completed the task, output this exact line:
{promise}
"""
                enhanced_prompt += promise_section

        return enhanced_prompt

    def __str__(self) -> str:
        """Human-readable summary: tool name plus availability flag."""
        return f"{self.name} (available: {self.available})"
|
||||
@@ -0,0 +1,616 @@
|
||||
# ABOUTME: Claude SDK adapter implementation
|
||||
# ABOUTME: Provides integration with Anthropic's Claude via Python SDK
|
||||
# ABOUTME: Supports inheriting user's Claude Code settings (MCP servers, CLAUDE.md, etc.)
|
||||
|
||||
"""Claude SDK adapter for Ralph Orchestrator."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
from typing import Optional
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from ..error_formatter import ClaudeErrorFormatter
|
||||
from ..output.console import RalphConsole
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, query
|
||||
CLAUDE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Fallback to old package name for backwards compatibility
|
||||
try:
|
||||
from claude_code_sdk import ClaudeCodeOptions as ClaudeAgentOptions, query
|
||||
CLAUDE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
CLAUDE_SDK_AVAILABLE = False
|
||||
query = None
|
||||
ClaudeAgentOptions = None
|
||||
|
||||
|
||||
class ClaudeAdapter(ToolAdapter):
|
||||
"""Adapter for Claude using the Python SDK."""
|
||||
|
||||
# Default max buffer size: 10MB (handles large screenshots from chrome-devtools-mcp)
|
||||
DEFAULT_MAX_BUFFER_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# Default model: Claude Opus 4.5 (most intelligent model)
|
||||
DEFAULT_MODEL = "claude-opus-4-5-20251101"
|
||||
|
||||
# Model pricing (per million tokens)
|
||||
MODEL_PRICING = {
|
||||
"claude-opus-4-5-20251101": {"input": 5.0, "output": 25.0},
|
||||
"claude-sonnet-4-5-20250929": {"input": 3.0, "output": 15.0},
|
||||
"claude-haiku-4-5-20251001": {"input": 1.0, "output": 5.0},
|
||||
# Legacy models
|
||||
"claude-3-opus": {"input": 15.0, "output": 75.0},
|
||||
"claude-3-sonnet": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-haiku": {"input": 0.25, "output": 1.25},
|
||||
}
|
||||
|
||||
def __init__(self, verbose: bool = False, max_buffer_size: int = None,
|
||||
inherit_user_settings: bool = True, cli_path: str = None,
|
||||
model: str = None):
|
||||
super().__init__("claude")
|
||||
self.sdk_available = CLAUDE_SDK_AVAILABLE
|
||||
self._system_prompt = None
|
||||
self._allowed_tools = None
|
||||
self._disallowed_tools = None
|
||||
self._enable_all_tools = False
|
||||
self._enable_web_search = True # Enable WebSearch by default
|
||||
self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
|
||||
self.verbose = verbose
|
||||
# Enable loading user's Claude Code settings (including MCP servers) by default
|
||||
self._inherit_user_settings = inherit_user_settings
|
||||
# Optional path to user's Claude Code CLI (uses bundled CLI if not specified)
|
||||
self._cli_path = cli_path
|
||||
self._model = model or self.DEFAULT_MODEL
|
||||
self._subprocess_pid: Optional[int] = None
|
||||
self._console = RalphConsole()
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if Claude SDK is available and properly configured."""
|
||||
# Claude Code SDK works without API key - it uses the local environment
|
||||
return CLAUDE_SDK_AVAILABLE
|
||||
|
||||
def configure(self,
|
||||
system_prompt: Optional[str] = None,
|
||||
allowed_tools: Optional[list] = None,
|
||||
disallowed_tools: Optional[list] = None,
|
||||
enable_all_tools: bool = False,
|
||||
enable_web_search: bool = True,
|
||||
inherit_user_settings: Optional[bool] = None,
|
||||
cli_path: Optional[str] = None,
|
||||
model: Optional[str] = None):
|
||||
"""Configure the Claude adapter with custom options.
|
||||
|
||||
Args:
|
||||
system_prompt: Custom system prompt for Claude
|
||||
allowed_tools: List of allowed tools for Claude to use (if None and enable_all_tools=True, all tools are enabled)
|
||||
disallowed_tools: List of disallowed tools
|
||||
enable_all_tools: If True and allowed_tools is None, enables all native Claude tools
|
||||
enable_web_search: If True, explicitly enables WebSearch tool (default: True)
|
||||
inherit_user_settings: If True, load user's Claude Code settings including MCP servers (default: True)
|
||||
cli_path: Path to user's Claude Code CLI (uses bundled CLI if not specified)
|
||||
model: Model to use (default: claude-opus-4-5-20251101)
|
||||
"""
|
||||
self._system_prompt = system_prompt
|
||||
self._allowed_tools = allowed_tools
|
||||
self._disallowed_tools = disallowed_tools
|
||||
self._enable_all_tools = enable_all_tools
|
||||
self._enable_web_search = enable_web_search
|
||||
|
||||
# Update user settings inheritance if specified
|
||||
if inherit_user_settings is not None:
|
||||
self._inherit_user_settings = inherit_user_settings
|
||||
|
||||
# Update CLI path if specified
|
||||
if cli_path is not None:
|
||||
self._cli_path = cli_path
|
||||
|
||||
# Update model if specified
|
||||
if model is not None:
|
||||
self._model = model
|
||||
|
||||
# If web search is enabled and we have an allowed tools list, add WebSearch to it
|
||||
if enable_web_search and allowed_tools is not None and 'WebSearch' not in allowed_tools:
|
||||
self._allowed_tools = allowed_tools + ['WebSearch']
|
||||
|
||||
    def execute(self, prompt: str, **kwargs) -> ToolResponse:
        """Execute Claude with the given prompt synchronously.

        This is a blocking wrapper around the async implementation.

        Args:
            prompt: The task prompt to send to Claude.
            **kwargs: Forwarded to aexecute() (e.g. model, allowed_tools,
                permission_mode, iteration, ...).

        Returns:
            ToolResponse from the run, or a failure response whose error
            text was produced by ClaudeErrorFormatter.
        """
        try:
            # Create new event loop if needed
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No loop is running in this thread: create one and drive
                # the coroutine to completion on it.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                return loop.run_until_complete(self.aexecute(prompt, **kwargs))
            else:
                # If loop is already running, schedule as task
                # (run_until_complete would raise inside a running loop, so
                # run the coroutine on a fresh loop in a worker thread).
                import concurrent.futures
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    future = executor.submit(asyncio.run, self.aexecute(prompt, **kwargs))
                    return future.result()
        except Exception as e:
            # Use error formatter for user-friendly error messages
            error_msg = ClaudeErrorFormatter.format_error_from_exception(
                iteration=kwargs.get('iteration', 0),
                exception=e
            )
            return ToolResponse(
                success=False,
                output="",
                error=str(error_msg)
            )
|
||||
|
||||
    async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
        """Execute Claude with the given prompt asynchronously.

        Builds ClaudeAgentOptions from adapter state (each field can be
        overridden per call via **kwargs), streams messages from the SDK's
        query(), collects assistant text into the returned output, and
        extracts token usage from the final ResultMessage.

        Args:
            prompt: Task prompt; orchestration instructions are prepended
                via _enhance_prompt_with_instructions().
            **kwargs: Per-call overrides: system_prompt, allowed_tools,
                disallowed_tools, enable_all_tools, enable_web_search,
                permission_mode, cwd, max_buffer_size,
                inherit_user_settings, cli_path, model, prompt_file,
                iteration.

        Returns:
            ToolResponse: success with collected output/tokens/cost, or a
            failure response with a formatted error message.
        """
        if not self.available:
            logger.debug("Claude SDK not available")
            return ToolResponse(
                success=False,
                output="",
                error="Claude SDK is not available"
            )

        try:
            # Get configuration from kwargs or use defaults
            prompt_file = kwargs.get('prompt_file', 'PROMPT.md')

            # Build options for Claude Code
            options_dict = {}

            # Set system prompt with orchestration context
            system_prompt = kwargs.get('system_prompt', self._system_prompt)
            if not system_prompt:
                # Create a default system prompt with orchestration context
                enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
                system_prompt = (
                    f"You are helping complete a task. "
                    f"The task is described in the file '{prompt_file}'. "
                    f"Please edit this file directly to add your solution and progress updates."
                )
                # Use the enhanced prompt as the main prompt
                prompt = enhanced_prompt
            else:
                # If custom system prompt provided, still enhance the main prompt
                prompt = self._enhance_prompt_with_instructions(prompt)
            options_dict['system_prompt'] = system_prompt

            # Set tool restrictions if provided
            # If enable_all_tools is True and no allowed_tools specified, don't set any restrictions
            enable_all_tools = kwargs.get('enable_all_tools', self._enable_all_tools)
            enable_web_search = kwargs.get('enable_web_search', self._enable_web_search)
            allowed_tools = kwargs.get('allowed_tools', self._allowed_tools)
            disallowed_tools = kwargs.get('disallowed_tools', self._disallowed_tools)

            # Add WebSearch to allowed tools if web search is enabled
            if enable_web_search and allowed_tools is not None and 'WebSearch' not in allowed_tools:
                allowed_tools = allowed_tools + ['WebSearch']

            # Only set tool restrictions if we're not enabling all tools or if specific tools are provided
            if not enable_all_tools or allowed_tools:
                if allowed_tools:
                    options_dict['allowed_tools'] = allowed_tools

                if disallowed_tools:
                    options_dict['disallowed_tools'] = disallowed_tools

            # If enable_all_tools is True and no allowed_tools, Claude will have access to all native tools
            if enable_all_tools and not allowed_tools:
                if self.verbose:
                    logger.debug("Enabling all native Claude tools (including WebSearch)")

            # Set permission mode - default to bypassPermissions for smoother operation
            permission_mode = kwargs.get('permission_mode', 'bypassPermissions')
            options_dict['permission_mode'] = permission_mode
            if self.verbose:
                logger.debug(f"Permission mode: {permission_mode}")

            # Set current working directory to ensure files are created in the right place
            import os
            cwd = kwargs.get('cwd', os.getcwd())
            options_dict['cwd'] = cwd
            if self.verbose:
                logger.debug(f"Working directory: {cwd}")

            # Set max buffer size for handling large responses (e.g., screenshots)
            max_buffer_size = kwargs.get('max_buffer_size', self._max_buffer_size)
            options_dict['max_buffer_size'] = max_buffer_size
            if self.verbose:
                logger.debug(f"Max buffer size: {max_buffer_size} bytes")

            # Configure setting sources to inherit user's Claude Code configuration
            # This enables MCP servers, CLAUDE.md files, and other user settings
            inherit_user_settings = kwargs.get('inherit_user_settings', self._inherit_user_settings)
            if inherit_user_settings:
                # Load user, project, and local settings (includes MCP servers)
                options_dict['setting_sources'] = ['user', 'project', 'local']
                if self.verbose:
                    logger.debug("Inheriting user's Claude Code settings (MCP servers, CLAUDE.md, etc.)")

            # Optional: use user's installed Claude Code CLI instead of bundled
            cli_path = kwargs.get('cli_path', self._cli_path)
            if cli_path:
                options_dict['cli_path'] = cli_path
                if self.verbose:
                    logger.debug(f"Using custom Claude CLI: {cli_path}")

            # Set model - defaults to Opus 4.5
            model = kwargs.get('model', self._model)
            options_dict['model'] = model
            if self.verbose:
                logger.debug(f"Using model: {model}")

            # Create options
            options = ClaudeAgentOptions(**options_dict)

            # Log request details if verbose
            if self.verbose:
                logger.debug("Claude SDK Request:")
                logger.debug(f"  Prompt length: {len(prompt)} characters")
                logger.debug(f"  System prompt: {system_prompt}")
                if allowed_tools:
                    logger.debug(f"  Allowed tools: {allowed_tools}")
                if disallowed_tools:
                    logger.debug(f"  Disallowed tools: {disallowed_tools}")

            # Collect all response chunks
            output_chunks = []
            tokens_used = 0
            chunk_count = 0

            # Use one-shot query for simpler execution
            if self.verbose:
                logger.debug("Starting Claude SDK query...")
                self._console.print_header("CLAUDE PROCESSING")

            # Streaming loop: dispatch on the SDK message class name so this
            # code does not need to import the concrete message types.
            async for message in query(prompt=prompt, options=options):
                chunk_count += 1
                msg_type = type(message).__name__

                if self.verbose:
                    print(f"\n[DEBUG: Received {msg_type}]", flush=True)
                    logger.debug(f"Received message type: {msg_type}")

                # Handle different message types
                if msg_type == 'AssistantMessage':
                    # Extract content from AssistantMessage
                    if hasattr(message, 'content') and message.content:
                        for content_block in message.content:
                            block_type = type(content_block).__name__

                            if hasattr(content_block, 'text'):
                                # TextBlock
                                text = content_block.text
                                output_chunks.append(text)

                                if self.verbose and text:
                                    self._console.print_message(text)
                                    logger.debug(f"Received assistant text: {len(text)} characters")

                            elif block_type == 'ToolUseBlock':
                                if self.verbose:
                                    tool_name = getattr(content_block, 'name', 'unknown')
                                    tool_id = getattr(content_block, 'id', 'unknown')
                                    tool_input = getattr(content_block, 'input', {})

                                    self._console.print_separator()
                                    self._console.print_status(f"TOOL USE: {tool_name}", style="cyan bold")
                                    self._console.print_info(f"ID: {tool_id[:12]}...")

                                    if tool_input:
                                        self._console.print_info("Input Parameters:")
                                        for key, value in tool_input.items():
                                            value_str = str(value)
                                            if len(value_str) > 100:
                                                value_str = value_str[:97] + "..."
                                            self._console.print_info(f"  - {key}: {value_str}")

                                    logger.debug(f"Tool use detected: {tool_name} (id: {tool_id[:8]}...)")
                                    if hasattr(content_block, 'input'):
                                        logger.debug(f"  Tool input: {content_block.input}")

                            else:
                                if self.verbose:
                                    logger.debug(f"Unknown content block type: {block_type}")

                elif msg_type == 'ResultMessage':
                    # ResultMessage contains final result and usage stats
                    if hasattr(message, 'result'):
                        # Don't append result - it's usually a duplicate of assistant message
                        if self.verbose:
                            logger.debug(f"Result message received: {len(str(message.result))} characters")

                    # Extract token usage from ResultMessage
                    if hasattr(message, 'usage'):
                        usage = message.usage
                        if isinstance(usage, dict):
                            tokens_used = usage.get('input_tokens', 0) + usage.get('output_tokens', 0)
                        else:
                            tokens_used = getattr(usage, 'total_tokens', 0)
                        if self.verbose:
                            logger.debug(f"Token usage: {tokens_used} tokens")

                elif msg_type == 'SystemMessage':
                    # SystemMessage is initialization data, skip it
                    if self.verbose:
                        logger.debug("System initialization message received")

                elif msg_type == 'UserMessage':
                    if self.verbose:
                        logger.debug("User message (tool result) received")

                        if hasattr(message, 'content'):
                            content = message.content
                            if isinstance(content, list):
                                for content_item in content:
                                    if hasattr(content_item, '__class__'):
                                        item_type = content_item.__class__.__name__
                                        if item_type == 'ToolResultBlock':
                                            tool_use_id = getattr(content_item, 'tool_use_id', 'unknown')
                                            result_content = getattr(content_item, 'content', None)
                                            is_error = getattr(content_item, 'is_error', False)

                                            self._console.print_separator()
                                            self._console.print_status("TOOL RESULT", style="yellow bold")
                                            self._console.print_info(f"For Tool ID: {tool_use_id[:12]}...")

                                            if is_error:
                                                self._console.print_error("Status: ERROR")
                                            else:
                                                self._console.print_success("Status: Success")

                                            if result_content:
                                                self._console.print_info("Output:")
                                                if isinstance(result_content, str):
                                                    if len(result_content) > 500:
                                                        self._console.print_message(f"  {result_content[:497]}...")
                                                    else:
                                                        self._console.print_message(f"  {result_content}")
                                                elif isinstance(result_content, list):
                                                    for item in result_content[:3]:
                                                        self._console.print_info(f"  - {item}")
                                                    if len(result_content) > 3:
                                                        self._console.print_info(f"  ... and {len(result_content) - 3} more items")

                elif msg_type == 'ToolResultMessage':
                    if self.verbose:
                        logger.debug("Tool result message received")

                    self._console.print_separator()
                    self._console.print_status("TOOL RESULT MESSAGE", style="yellow bold")

                    if hasattr(message, 'tool_use_id'):
                        self._console.print_info(f"Tool ID: {message.tool_use_id[:12]}...")

                    if hasattr(message, 'content'):
                        content = message.content
                        if content:
                            self._console.print_info("Content:")
                            if isinstance(content, str):
                                if len(content) > 500:
                                    self._console.print_message(f"  {content[:497]}...")
                                else:
                                    self._console.print_message(f"  {content}")
                            elif isinstance(content, list):
                                for item in content[:3]:
                                    self._console.print_info(f"  - {item}")
                                if len(content) > 3:
                                    self._console.print_info(f"  ... and {len(content) - 3} more items")

                    if hasattr(message, 'is_error') and message.is_error:
                        self._console.print_error("Error: True")

                elif hasattr(message, 'text'):
                    chunk_text = message.text
                    output_chunks.append(chunk_text)
                    if self.verbose:
                        self._console.print_message(chunk_text)
                        logger.debug(f"Received text chunk {chunk_count}: {len(chunk_text)} characters")

                elif isinstance(message, str):
                    output_chunks.append(message)
                    if self.verbose:
                        self._console.print_message(message)
                        logger.debug(f"Received string chunk {chunk_count}: {len(message)} characters")

                else:
                    if self.verbose:
                        logger.debug(f"Unknown message type {msg_type}: {message}")

            # Combine output
            output = ''.join(output_chunks)

            # End streaming section if verbose
            if self.verbose:
                self._console.print_separator()

            # Always log the output we're about to return
            logger.debug(f"Claude adapter returning {len(output)} characters of output")
            if output:
                logger.debug(f"Output preview: {output[:200]}...")

            # Calculate cost if we have token count (using model-specific pricing)
            cost = self._calculate_cost(tokens_used, model) if tokens_used > 0 else None

            # Log response details if verbose
            if self.verbose:
                logger.debug("Claude SDK Response:")
                logger.debug(f"  Output length: {len(output)} characters")
                logger.debug(f"  Chunks received: {chunk_count}")
                if tokens_used > 0:
                    logger.debug(f"  Tokens used: {tokens_used}")
                if cost:
                    logger.debug(f"  Estimated cost: ${cost:.4f}")
                logger.debug(f"Response preview: {output[:500]}..." if len(output) > 500 else f"Response: {output}")

            return ToolResponse(
                success=True,
                output=output,
                tokens_used=tokens_used if tokens_used > 0 else None,
                cost=cost,
                metadata={"model": model}
            )

        except asyncio.TimeoutError as e:
            # Use error formatter for user-friendly timeout message
            error_msg = ClaudeErrorFormatter.format_error_from_exception(
                iteration=kwargs.get('iteration', 0),
                exception=e
            )
            logger.warning(f"Claude SDK request timed out: {error_msg.message}")
            return ToolResponse(
                success=False,
                output="",
                error=str(error_msg)
            )
        except Exception as e:
            # Check if this is a user-initiated cancellation (SIGINT = exit code -2)
            error_str = str(e)
            if "exit code -2" in error_str or "exit code: -2" in error_str:
                logger.debug("Claude execution cancelled by user")
                return ToolResponse(
                    success=False,
                    output="",
                    error="Execution cancelled by user"
                )

            # Use error formatter for user-friendly error messages
            error_msg = ClaudeErrorFormatter.format_error_from_exception(
                iteration=kwargs.get('iteration', 0),
                exception=e
            )

            # Log additional debugging information for "Command failed" errors
            if "Command failed with exit code 1" in error_str:
                logger.error(f"Claude CLI command failed - this may indicate:")
                logger.error("  1. Claude CLI not properly installed or configured")
                logger.error("  2. Missing API key or authentication issues")
                logger.error("  3. Network connectivity problems")
                logger.error("  4. Insufficient permissions")
                logger.error(f"  Full error: {error_str}")

                # Try to provide more specific guidance
                try:
                    import subprocess
                    result = subprocess.run(['claude', '--version'], capture_output=True, text=True, timeout=5)
                    if result.returncode == 0:
                        logger.error(f"  Claude CLI version: {result.stdout.strip()}")
                    else:
                        logger.error(f"  Claude CLI check failed: {result.stderr}")
                except Exception as check_e:
                    logger.error(f"  Could not check Claude CLI: {check_e}")

            logger.error(f"Claude SDK error: {error_msg.message}", exc_info=True)
            return ToolResponse(
                success=False,
                output="",
                error=str(error_msg)
            )
|
||||
|
||||
def _calculate_cost(self, tokens: Optional[int], model: str = None) -> Optional[float]:
|
||||
"""Calculate estimated cost based on tokens and model.
|
||||
|
||||
Args:
|
||||
tokens: Total tokens used (input + output combined)
|
||||
model: Model ID used for the request
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD, or None if tokens is None/0
|
||||
"""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
model = model or self._model
|
||||
|
||||
# Get model pricing or use default
|
||||
if model in self.MODEL_PRICING:
|
||||
pricing = self.MODEL_PRICING[model]
|
||||
else:
|
||||
# Fallback to Opus 4.5 pricing for unknown models
|
||||
pricing = self.MODEL_PRICING[self.DEFAULT_MODEL]
|
||||
|
||||
# Estimate input/output split (typically ~30% input, ~70% output for agent work)
|
||||
# This is an approximation since we don't always get separate counts
|
||||
input_tokens = int(tokens * 0.3)
|
||||
output_tokens = int(tokens * 0.7)
|
||||
|
||||
input_cost = (input_tokens / 1_000_000) * pricing["input"]
|
||||
output_cost = (output_tokens / 1_000_000) * pricing["output"]
|
||||
|
||||
return input_cost + output_cost
|
||||
|
||||
def estimate_cost(self, prompt: str, model: str = None) -> float:
    """Roughly estimate the USD cost of running *prompt*.

    Uses the common ~4 characters per token heuristic, then defers to
    _calculate_cost for the actual pricing.

    Args:
        prompt: The prompt text to estimate cost for.
        model: Model ID to use for pricing (defaults to configured model).

    Returns:
        Estimated cost in USD (0.0 when no cost can be computed).
    """
    approx_tokens = int(len(prompt) / 4)
    cost = self._calculate_cost(approx_tokens, model)
    return cost if cost else 0.0
|
||||
|
||||
def kill_subprocess_sync(self) -> None:
    """
    Kill subprocess synchronously (safe to call from signal handler).

    This method uses os.kill() which is signal-safe and can be called
    from the signal handler context. It immediately terminates the subprocess,
    which unblocks any I/O operations waiting on it.

    Note: SIGKILL is attempted unconditionally after a ~10ms grace period
    (guarded by ProcessLookupError in case SIGTERM already worked), and
    the stored PID is always cleared via the finally block.
    """
    if self._subprocess_pid:
        try:
            # Try SIGTERM first for graceful shutdown
            os.kill(self._subprocess_pid, signal.SIGTERM)
            # Small delay to allow graceful shutdown - keep minimal for signal handler
            import time
            try:
                time.sleep(0.01)
            except Exception:
                pass  # Ignore errors during sleep in signal handler
            # Then SIGKILL if still alive (more forceful)
            try:
                os.kill(self._subprocess_pid, signal.SIGKILL)
            except ProcessLookupError:
                pass  # Already dead from SIGTERM
        except ProcessLookupError:
            pass  # Already dead
        except (PermissionError, OSError):
            pass  # Best effort - process might be owned by another user
        finally:
            # Always drop the PID so later cleanup paths don't signal a
            # recycled PID belonging to some unrelated process.
            self._subprocess_pid = None
|
||||
|
||||
async def _cleanup_transport(self) -> None:
|
||||
"""Clean up transport and kill subprocess with timeout protection."""
|
||||
# Kill subprocess first (if not already killed by signal handler)
|
||||
if self._subprocess_pid:
|
||||
try:
|
||||
# Try SIGTERM first
|
||||
os.kill(self._subprocess_pid, signal.SIGTERM)
|
||||
# Wait with timeout to avoid hanging
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.sleep(0.01), timeout=0.05)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Continue even if sleep times out
|
||||
# Force kill if still alive
|
||||
try:
|
||||
os.kill(self._subprocess_pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass # Already dead
|
||||
except ProcessLookupError:
|
||||
pass # Already terminated
|
||||
except (PermissionError, OSError):
|
||||
pass # Best effort cleanup
|
||||
finally:
|
||||
self._subprocess_pid = None
|
||||
@@ -0,0 +1,120 @@
|
||||
# ABOUTME: Gemini CLI adapter implementation
|
||||
# ABOUTME: Provides fallback integration with Google's Gemini AI
|
||||
|
||||
"""Gemini CLI adapter for Ralph Orchestrator."""
|
||||
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
|
||||
|
||||
class GeminiAdapter(ToolAdapter):
    """Adapter for the Google Gemini CLI tool.

    Shells out to the `gemini` binary for prompt execution, with rough
    cost accounting against Gemini's 1M-token free tier.
    """

    def __init__(self):
        # Name of the CLI binary; resolved via PATH at execution time.
        self.command = "gemini"
        super().__init__("gemini")

    def check_availability(self) -> bool:
        """Return True if the Gemini CLI answers `--version` within 5 seconds."""
        try:
            result = subprocess.run(
                [self.command, "--version"],
                capture_output=True,
                timeout=5,
                text=True
            )
            return result.returncode == 0
        except (subprocess.TimeoutExpired, FileNotFoundError):
            return False

    def execute(self, prompt: str, **kwargs) -> ToolResponse:
        """Execute Gemini with the given prompt.

        Args:
            prompt: Prompt text; orchestration instructions are prepended.
            **kwargs: Optional `model`, `output_format`, and `timeout`
                (seconds, default 300).

        Returns:
            ToolResponse carrying stdout on success, or error details
            on failure/timeout.
        """
        if not self.available:
            return ToolResponse(
                success=False,
                output="",
                error="Gemini CLI is not available"
            )

        try:
            # Enhance prompt with orchestration instructions
            enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

            # Build the CLI invocation.
            cmd = [self.command]

            if kwargs.get("model"):
                cmd.extend(["--model", kwargs["model"]])

            # `-p` passes the prompt for a one-shot run.
            cmd.extend(["-p", enhanced_prompt])

            if kwargs.get("output_format"):
                cmd.extend(["--output", kwargs["output_format"]])

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=kwargs.get("timeout", 300)  # 5 minute default
            )

            if result.returncode == 0:
                # Token usage, if reported at all, arrives on stderr.
                tokens = self._extract_token_count(result.stderr)

                return ToolResponse(
                    success=True,
                    output=result.stdout,
                    tokens_used=tokens,
                    cost=self._calculate_cost(tokens),
                    metadata={"model": kwargs.get("model", "gemini-2.5-pro")}
                )
            else:
                return ToolResponse(
                    success=False,
                    output=result.stdout,
                    error=result.stderr or "Gemini command failed"
                )

        except subprocess.TimeoutExpired:
            return ToolResponse(
                success=False,
                output="",
                error="Gemini command timed out"
            )
        except Exception as e:
            return ToolResponse(
                success=False,
                output="",
                error=str(e)
            )

    def _extract_token_count(self, stderr: str) -> Optional[int]:
        """Extract token count from Gemini output.

        Not implemented yet: parsing depends on Gemini's output format,
        so this currently always returns None.
        """
        return None

    def _calculate_cost(self, tokens: Optional[int]) -> Optional[float]:
        """Estimate USD cost for *tokens*, honoring the free tier.

        Returns None when tokens is None or zero.
        """
        if not tokens:
            return None

        # Gemini has free tier up to 1M tokens
        if tokens < 1_000_000:
            return 0.0

        # After free tier: $0.001 per 1K tokens (approximate)
        excess_tokens = tokens - 1_000_000
        cost_per_1k = 0.001
        return (excess_tokens / 1000) * cost_per_1k

    def estimate_cost(self, prompt: str) -> float:
        """Estimate cost for the prompt using the ~4 chars/token heuristic."""
        # Convert to int so _calculate_cost receives the integer token
        # count its signature declares (previously a float leaked through).
        estimated_tokens = int(len(prompt) / 4)
        return self._calculate_cost(estimated_tokens) or 0.0
|
||||
@@ -0,0 +1,562 @@
|
||||
# ABOUTME: Kiro CLI adapter implementation
|
||||
# ABOUTME: Provides integration with kiro-cli chat command for AI interactions
|
||||
|
||||
"""Kiro CLI adapter for Ralph Orchestrator."""
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
try:
|
||||
import fcntl # Unix-only
|
||||
except ModuleNotFoundError:
|
||||
fcntl = None
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from ..logging_config import RalphLogger
|
||||
|
||||
# Get logger for this module
|
||||
logger = RalphLogger.get_logger(RalphLogger.ADAPTER_KIRO)
|
||||
|
||||
|
||||
class KiroAdapter(ToolAdapter):
|
||||
"""Adapter for Kiro CLI tool."""
|
||||
|
||||
def __init__(self):
    """Initialize the Kiro adapter from RALPH_KIRO_* environment variables."""
    env = os.getenv

    # CLI command and execution defaults, all overridable via environment.
    self.command = env("RALPH_KIRO_COMMAND", "kiro-cli")
    self.default_timeout = int(env("RALPH_KIRO_TIMEOUT", "600"))
    self.default_prompt_file = env("RALPH_KIRO_PROMPT_FILE", "PROMPT.md")
    self.trust_all_tools = env("RALPH_KIRO_TRUST_TOOLS", "true").lower() == "true"
    self.no_interactive = env("RALPH_KIRO_NO_INTERACTIVE", "true").lower() == "true"

    # These must exist before super().__init__() runs.
    self._original_sigint = None
    self._original_sigterm = None

    super().__init__("kiro")

    self.current_process = None
    self.shutdown_requested = False

    # Guards current_process / shutdown_requested across threads.
    self._lock = threading.Lock()

    # Propagate SIGINT/SIGTERM to any running subprocess.
    self._register_signal_handlers()

    logger.info(f"Kiro adapter initialized - Command: {self.command}, "
                f"Default timeout: {self.default_timeout}s, "
                f"Trust tools: {self.trust_all_tools}")
|
||||
|
||||
def _register_signal_handlers(self):
    """Install SIGINT/SIGTERM handlers, remembering the previous ones."""
    previous_int = signal.signal(signal.SIGINT, self._signal_handler)
    previous_term = signal.signal(signal.SIGTERM, self._signal_handler)
    self._original_sigint = previous_int
    self._original_sigterm = previous_term
    logger.debug("Signal handlers registered for SIGINT and SIGTERM")
|
||||
|
||||
def _restore_signal_handlers(self):
|
||||
"""Restore original signal handlers."""
|
||||
if hasattr(self, "_original_sigint") and self._original_sigint is not None:
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if hasattr(self, "_original_sigterm") and self._original_sigterm is not None:
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
|
||||
def _signal_handler(self, signum, frame):
    """Record a shutdown request and terminate any in-flight Kiro subprocess."""
    with self._lock:
        self.shutdown_requested = True
        proc = self.current_process

    # Nothing to do when no subprocess is alive.
    if not proc or proc.poll() is not None:
        return

    logger.warning(f"Received signal {signum}, terminating Kiro process...")
    try:
        proc.terminate()
        proc.wait(timeout=3)
        logger.debug("Process terminated gracefully")
    except subprocess.TimeoutExpired:
        logger.warning("Force killing Kiro process...")
        proc.kill()
        try:
            proc.wait(timeout=2)
        except subprocess.TimeoutExpired:
            logger.warning("Process may still be running after force kill")
|
||||
|
||||
def check_availability(self) -> bool:
    """Return True when the configured Kiro command (or the `q` fallback) exists."""
    if self._check_command(self.command):
        logger.debug(f"Kiro command '{self.command}' available")
        return True

    # The CLI previously shipped as `q`; fall back only when the default
    # command name is in use.
    if self.command == "kiro-cli" and self._check_command("q"):
        logger.info("kiro-cli not found, falling back to 'q'")
        self.command = "q"
        return True

    logger.warning(f"Kiro command '{self.command}' (and fallback) not found")
    return False
|
||||
|
||||
def _check_command(self, cmd: str) -> bool:
|
||||
"""Check if a specific command exists."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["which", cmd],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
text=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
    """Execute kiro-cli chat with the given prompt.

    Runs the CLI as a subprocess with non-blocking pipes, streams its
    output to stderr when verbose, and enforces the configured timeout
    and any shutdown signal.

    Args:
        prompt: Task prompt; orchestration instructions are prepended.
        **kwargs: Optional `verbose` (default True), `prompt_file`
            (default self.default_prompt_file), and `timeout` in seconds
            (default self.default_timeout).

    Returns:
        ToolResponse with collected stdout on success, or failure details
        on shutdown, timeout, or non-zero exit.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error="Kiro CLI is not available"
        )

    try:
        verbose = kwargs.get("verbose", True)
        prompt_file = kwargs.get("prompt_file", self.default_prompt_file)

        logger.info(f"Executing Kiro chat - Prompt file: {prompt_file}, Verbose: {verbose}")

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Tell the agent explicitly to edit the prompt file in place.
        effective_prompt = (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

        # Build command
        cmd = [self.command, "chat"]
        if self.no_interactive:
            cmd.append("--no-interactive")
        if self.trust_all_tools:
            cmd.append("--trust-all-tools")
        cmd.append(effective_prompt)

        logger.debug(f"Command constructed: {' '.join(cmd)}")

        timeout = kwargs.get("timeout", self.default_timeout)

        if verbose:
            logger.info(f"Starting {self.command} chat command...")
            logger.info(f"Command: {' '.join(cmd)}")
            logger.info(f"Working directory: {os.getcwd()}")
            logger.info(f"Timeout: {timeout} seconds")
            print("-" * 60, file=sys.stderr)

        # Popen (not run) so output can be streamed in real time.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=os.getcwd(),
            bufsize=0,  # Unbuffered to prevent deadlock
            universal_newlines=True
        )

        # Publish the process so the signal handler can terminate it.
        with self._lock:
            self.current_process = process

        # Non-blocking pipes so the poll loop below never stalls on read.
        self._make_non_blocking(process.stdout)
        self._make_non_blocking(process.stderr)

        stdout_lines = []
        stderr_lines = []

        start_time = time.time()
        last_output_time = start_time
        last_progress_time = start_time

        while True:
            # Check for shutdown signal first with lock
            with self._lock:
                shutdown = self.shutdown_requested

            if shutdown:
                if verbose:
                    print("Shutdown requested, terminating process...", file=sys.stderr)
                process.terminate()
                try:
                    process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait(timeout=2)

                with self._lock:
                    self.current_process = None

                return ToolResponse(
                    success=False,
                    output="".join(stdout_lines),
                    error="Process terminated due to shutdown signal"
                )

            elapsed_time = time.time() - start_time

            # BUGFIX: progress used to be printed on every ~10ms poll
            # iteration when verbose, and the `int(elapsed) % 30 == 0`
            # gate fired ~100 times during each 30-second boundary second.
            # Rate-limit all progress reporting to once per 30 seconds.
            if time.time() - last_progress_time >= 30:
                last_progress_time = time.time()
                logger.debug(f"Kiro still running... elapsed: {elapsed_time:.1f}s / {timeout}s")
                if verbose:
                    print(f"Kiro still running... elapsed: {elapsed_time:.1f}s / {timeout}s", file=sys.stderr)

                # Flag a potential hang when nothing has been read lately.
                time_since_output = time.time() - last_output_time
                if time_since_output > 60:
                    logger.info(f"No output received for {time_since_output:.1f}s, Kiro might be stuck")

            if elapsed_time > timeout:
                logger.warning(f"Command timed out after {elapsed_time:.2f} seconds")
                if verbose:
                    print(f"Command timed out after {elapsed_time:.2f} seconds", file=sys.stderr)

                # Try to terminate gracefully first
                process.terminate()
                try:
                    process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    logger.warning("Graceful termination failed, force killing process")
                    if verbose:
                        print("Graceful termination failed, force killing process", file=sys.stderr)
                    process.kill()
                    try:
                        process.wait(timeout=2)
                    except subprocess.TimeoutExpired:
                        logger.warning("Process may still be running after kill")
                        if verbose:
                            print("Warning: Process may still be running after kill", file=sys.stderr)

                # Try to capture any remaining output after termination
                try:
                    remaining_stdout = process.stdout.read()
                    remaining_stderr = process.stderr.read()
                    if remaining_stdout:
                        stdout_lines.append(remaining_stdout)
                    if remaining_stderr:
                        stderr_lines.append(remaining_stderr)
                except Exception as e:
                    logger.warning(f"Could not read remaining output after timeout: {e}")
                    if verbose:
                        print(f"Warning: Could not read remaining output after timeout: {e}", file=sys.stderr)

                with self._lock:
                    self.current_process = None

                return ToolResponse(
                    success=False,
                    output="".join(stdout_lines),
                    error=f"Kiro command timed out after {elapsed_time:.2f} seconds"
                )

            # Check if process is still running
            if process.poll() is not None:
                # Process finished, read remaining output
                remaining_stdout = process.stdout.read()
                remaining_stderr = process.stderr.read()

                if remaining_stdout:
                    stdout_lines.append(remaining_stdout)
                    if verbose:
                        print(f"{remaining_stdout}", end="", file=sys.stderr)

                if remaining_stderr:
                    stderr_lines.append(remaining_stderr)
                    if verbose:
                        print(f"{remaining_stderr}", end="", file=sys.stderr)

                break

            # Read whatever is available without blocking.
            try:
                stdout_data = self._read_available(process.stdout)
                if stdout_data:
                    stdout_lines.append(stdout_data)
                    last_output_time = time.time()
                    if verbose:
                        print(stdout_data, end="", file=sys.stderr)

                stderr_data = self._read_available(process.stderr)
                if stderr_data:
                    stderr_lines.append(stderr_data)
                    last_output_time = time.time()
                    if verbose:
                        print(stderr_data, end="", file=sys.stderr)

                # Small sleep to prevent busy waiting
                time.sleep(0.01)

            except BlockingIOError:
                # Non-blocking read returns BlockingIOError when no data available
                pass

        # Get final return code
        returncode = process.poll()

        execution_time = time.time() - start_time
        logger.info(f"Process completed - Return code: {returncode}, Execution time: {execution_time:.2f}s")

        if verbose:
            print("-" * 60, file=sys.stderr)
            print(f"Process completed with return code: {returncode}", file=sys.stderr)
            print(f"Total execution time: {execution_time:.2f} seconds", file=sys.stderr)

        with self._lock:
            self.current_process = None

        # Combine output
        full_stdout = "".join(stdout_lines)
        full_stderr = "".join(stderr_lines)

        if returncode == 0:
            logger.debug(f"Kiro chat succeeded - Output length: {len(full_stdout)} chars")
            return ToolResponse(
                success=True,
                output=full_stdout,
                metadata={
                    "tool": "kiro chat",
                    "execution_time": execution_time,
                    "verbose": verbose,
                    "return_code": returncode
                }
            )
        else:
            logger.warning(f"Kiro chat failed - Return code: {returncode}, Error: {full_stderr[:200]}")
            return ToolResponse(
                success=False,
                output=full_stdout,
                error=full_stderr or f"Kiro command failed with code {returncode}"
            )

    except Exception as e:
        logger.exception(f"Exception during Kiro chat execution: {str(e)}")
        if verbose:
            print(f"Exception occurred: {str(e)}", file=sys.stderr)
        # Clean up process reference on exception
        with self._lock:
            self.current_process = None
        return ToolResponse(
            success=False,
            output="",
            error=str(e)
        )
|
||||
|
||||
def _make_non_blocking(self, pipe):
|
||||
"""Make a pipe non-blocking to prevent deadlock."""
|
||||
if not pipe or fcntl is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fd = pipe.fileno()
|
||||
if isinstance(fd, int) and fd >= 0:
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
except (AttributeError, ValueError, OSError):
|
||||
# In tests or when pipe doesn't support fileno()
|
||||
pass
|
||||
|
||||
|
||||
def _read_available(self, pipe):
|
||||
"""Read available data from a non-blocking pipe."""
|
||||
if not pipe:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Try to read up to 4KB at a time
|
||||
data = pipe.read(4096)
|
||||
# Ensure we always return a string, not None
|
||||
if data is None:
|
||||
return ""
|
||||
return data if data else ""
|
||||
except (IOError, OSError):
|
||||
# Would block or no data available
|
||||
return ""
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
    """Native async execution using asyncio subprocess.

    Mirrors execute() but without incremental streaming: all output is
    gathered via communicate() under an overall timeout.

    Args:
        prompt: Task prompt; orchestration instructions are prepended.
        **kwargs: Optional `verbose` (default True), `prompt_file`, and
            `timeout` in seconds (default self.default_timeout).

    Returns:
        ToolResponse with stdout on success, or failure details on
        timeout / error.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error="Kiro CLI is not available"
        )

    try:
        verbose = kwargs.get("verbose", True)
        prompt_file = kwargs.get("prompt_file", self.default_prompt_file)
        timeout = kwargs.get("timeout", self.default_timeout)

        logger.info(f"Executing Kiro chat async - Prompt file: {prompt_file}, Timeout: {timeout}s")

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Tell the agent explicitly to edit the prompt file in place.
        effective_prompt = (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

        # CONSISTENCY FIX: honor the RALPH_KIRO_NO_INTERACTIVE /
        # RALPH_KIRO_TRUST_TOOLS configuration like execute() does instead
        # of unconditionally hard-coding both flags.
        cmd = [self.command, "chat"]
        if self.no_interactive:
            cmd.append("--no-interactive")
        if self.trust_all_tools:
            cmd.append("--trust-all-tools")
        cmd.append(effective_prompt)

        logger.debug(f"Starting async Kiro chat command: {' '.join(cmd)}")
        if verbose:
            print(f"Starting {self.command} chat command (async)...", file=sys.stderr)
            print(f"Command: {' '.join(cmd)}", file=sys.stderr)
            print("-" * 60, file=sys.stderr)

        # Create async subprocess
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=os.getcwd()
        )

        # Publish the process so the signal handler can terminate it.
        with self._lock:
            self.current_process = process

        try:
            # Wait for completion with timeout
            stdout_data, stderr_data = await asyncio.wait_for(
                process.communicate(),
                timeout=timeout
            )

            stdout = stdout_data.decode("utf-8") if stdout_data else ""
            stderr = stderr_data.decode("utf-8") if stderr_data else ""

            if verbose and stdout:
                print(stdout, file=sys.stderr)
            if verbose and stderr:
                print(stderr, file=sys.stderr)

            if process.returncode == 0:
                logger.debug(f"Async Kiro chat succeeded - Output length: {len(stdout)} chars")
                return ToolResponse(
                    success=True,
                    output=stdout,
                    metadata={
                        "tool": "kiro chat",
                        "verbose": verbose,
                        "async": True,
                        "return_code": process.returncode
                    }
                )
            else:
                logger.warning(f"Async Kiro chat failed - Return code: {process.returncode}")
                return ToolResponse(
                    success=False,
                    output=stdout,
                    error=stderr or f"Kiro chat failed with code {process.returncode}"
                )

        except asyncio.TimeoutError:
            logger.warning(f"Async Kiro chat timed out after {timeout} seconds")
            if verbose:
                print(f"Async Kiro chat timed out after {timeout} seconds", file=sys.stderr)

            # Graceful terminate, then force kill if it will not die.
            try:
                process.terminate()
                await asyncio.wait_for(process.wait(), timeout=3)
            except (asyncio.TimeoutError, ProcessLookupError):
                try:
                    process.kill()
                    await process.wait()
                except ProcessLookupError:
                    pass

            return ToolResponse(
                success=False,
                output="",
                error=f"Kiro command timed out after {timeout} seconds"
            )

        finally:
            # Clean up process reference
            with self._lock:
                self.current_process = None

    except Exception as e:
        logger.exception(f"Async execution error: {str(e)}")
        if kwargs.get("verbose"):
            print(f"Async execution error: {str(e)}", file=sys.stderr)
        return ToolResponse(
            success=False,
            output="",
            error=str(e)
        )
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
    """Kiro chat cost estimation (if applicable)."""
    # No pricing model exists for Kiro chat here, so report zero cost.
    return 0.0
|
||||
|
||||
def __del__(self):
    """Best-effort cleanup when the adapter is garbage-collected.

    Attributes may be missing if __init__ failed partway through, or may
    already be torn down during interpreter shutdown, so every access
    below is guarded with hasattr/getattr.
    """
    # Restore original signal handlers
    self._restore_signal_handlers()

    # Ensure any running process is terminated
    if hasattr(self, "_lock"):
        with self._lock:
            process = self.current_process if hasattr(self, "current_process") else None
    else:
        # Lock never got created; fall back to an unsynchronized read.
        process = getattr(self, "current_process", None)

    if process:
        try:
            if hasattr(process, "poll"):
                # Sync process
                if process.poll() is None:
                    process.terminate()
                    process.wait(timeout=1)
            else:
                # Async process - can't do much in __del__
                pass
        except Exception as e:
            # Best-effort cleanup during interpreter shutdown
            logger.debug(f"Cleanup warning in __del__: {type(e).__name__}: {e}")
|
||||
@@ -0,0 +1,579 @@
|
||||
# ABOUTME: Q Chat adapter implementation for q CLI tool (DEPRECATED)
|
||||
# ABOUTME: Provides integration with q chat command for AI interactions
|
||||
# ABOUTME: NOTE: Consider using KiroAdapter instead - Q CLI rebranded to Kiro CLI
|
||||
|
||||
"""Q Chat adapter for Ralph Orchestrator.
|
||||
|
||||
.. deprecated::
|
||||
The QChatAdapter is deprecated in favor of KiroAdapter.
|
||||
Amazon Q Developer CLI has been rebranded to Kiro CLI (v1.20+).
|
||||
Use `-a kiro` instead of `-a qchat` or `-a q` for new projects.
|
||||
|
||||
Migration guide:
|
||||
- Config paths changed: ~/.aws/amazonq/ -> ~/.kiro/
|
||||
- MCP servers: ~/.kiro/settings/mcp.json
|
||||
- Project files: .kiro/ folder
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
import warnings
|
||||
try:
|
||||
import fcntl # Unix-only
|
||||
except ModuleNotFoundError:
|
||||
fcntl = None
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from ..logging_config import RalphLogger
|
||||
|
||||
# Get logger for this module
|
||||
logger = RalphLogger.get_logger(RalphLogger.ADAPTER_QCHAT)
|
||||
|
||||
|
||||
class QChatAdapter(ToolAdapter):
|
||||
"""Adapter for Q Chat CLI tool.
|
||||
|
||||
.. deprecated::
|
||||
Use :class:`KiroAdapter` instead. The Amazon Q Developer CLI has been
|
||||
rebranded to Kiro CLI. This adapter is maintained for backwards
|
||||
compatibility only.
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Initialize the deprecated Q Chat adapter from RALPH_QCHAT_* env vars."""
    # Steer users toward the Kiro adapter; the Q CLI was rebranded.
    warnings.warn(
        "QChatAdapter is deprecated. Use KiroAdapter instead. "
        "Amazon Q Developer CLI has been rebranded to Kiro CLI (v1.20+). "
        "Run with '-a kiro' instead of '-a qchat'.",
        DeprecationWarning,
        stacklevel=2
    )

    env = os.getenv

    # CLI command and execution defaults, all overridable via environment.
    self.command = env("RALPH_QCHAT_COMMAND", "q")
    self.default_timeout = int(env("RALPH_QCHAT_TIMEOUT", "600"))
    self.default_prompt_file = env("RALPH_QCHAT_PROMPT_FILE", "PROMPT.md")
    self.trust_all_tools = env("RALPH_QCHAT_TRUST_TOOLS", "true").lower() == "true"
    self.no_interactive = env("RALPH_QCHAT_NO_INTERACTIVE", "true").lower() == "true"

    # These must exist before super().__init__() runs.
    self._original_sigint = None
    self._original_sigterm = None

    super().__init__("qchat")

    self.current_process = None
    self.shutdown_requested = False

    # Guards current_process / shutdown_requested across threads.
    self._lock = threading.Lock()

    # Propagate SIGINT/SIGTERM to any running subprocess.
    self._register_signal_handlers()

    logger.info(f"Q Chat adapter initialized - Command: {self.command}, "
                f"Default timeout: {self.default_timeout}s, "
                f"Trust tools: {self.trust_all_tools}")
|
||||
|
||||
def _register_signal_handlers(self):
    """Install shutdown handlers for SIGINT/SIGTERM, keeping the originals."""
    handler = self._signal_handler
    self._original_sigint = signal.signal(signal.SIGINT, handler)
    self._original_sigterm = signal.signal(signal.SIGTERM, handler)
    logger.debug("Signal handlers registered for SIGINT and SIGTERM")
|
||||
|
||||
def _restore_signal_handlers(self):
|
||||
"""Restore original signal handlers."""
|
||||
if hasattr(self, '_original_sigint') and self._original_sigint is not None:
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if hasattr(self, '_original_sigterm') and self._original_sigterm is not None:
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
|
||||
def _signal_handler(self, signum, frame):
    """Record a shutdown request and terminate any in-flight q chat subprocess."""
    with self._lock:
        self.shutdown_requested = True
        proc = self.current_process

    # Nothing to do when no subprocess is alive.
    if not proc or proc.poll() is not None:
        return

    logger.warning(f"Received signal {signum}, terminating q chat process...")
    try:
        proc.terminate()
        proc.wait(timeout=3)
        logger.debug("Process terminated gracefully")
    except subprocess.TimeoutExpired:
        logger.warning("Force killing q chat process...")
        proc.kill()
        try:
            proc.wait(timeout=2)
        except subprocess.TimeoutExpired:
            logger.warning("Process may still be running after force kill")
|
||||
|
||||
def check_availability(self) -> bool:
    """Return True if the configured q CLI executable is on PATH.

    Uses shutil.which rather than spawning a `which` subprocess: no
    process overhead, and it works on platforms without a `which`
    binary (e.g. Windows).
    """
    import shutil  # local import keeps the module's import block unchanged
    available = shutil.which(self.command) is not None
    logger.debug(f"Q command '{self.command}' availability check: {available}")
    return available
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
    """Execute q chat with the given prompt.

    Runs `q chat` as a subprocess with non-blocking pipes, streams its
    output while polling for shutdown and timeout, and returns a
    ToolResponse containing the collected output.

    Args:
        prompt: Task text to hand to q chat.
        **kwargs: Optional overrides:
            verbose (bool): mirror subprocess output to stderr (default True).
            prompt_file (str): file q chat is told to edit
                (default self.default_prompt_file).
            timeout (float): maximum run time in seconds
                (default self.default_timeout).

    Returns:
        ToolResponse: success flag, combined stdout, and error/metadata.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error="q CLI is not available"
        )

    try:
        # Get verbose flag from kwargs
        verbose = kwargs.get('verbose', True)

        # Get the prompt file path from kwargs if available
        prompt_file = kwargs.get('prompt_file', self.default_prompt_file)

        logger.info(f"Executing Q chat - Prompt file: {prompt_file}, Verbose: {verbose}")

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Construct a more effective prompt for q chat
        # Tell it explicitly to edit the prompt file
        effective_prompt = (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

        # Build command - q chat works with files by adding them to context
        # We pass the prompt through stdin and tell it to trust file operations
        cmd = [self.command, "chat"]

        if self.no_interactive:
            cmd.append("--no-interactive")

        if self.trust_all_tools:
            cmd.append("--trust-all-tools")

        cmd.append(effective_prompt)

        logger.debug(f"Command constructed: {' '.join(cmd)}")

        timeout = kwargs.get("timeout", self.default_timeout)

        if verbose:
            logger.info("Starting q chat command...")
            logger.info(f"Command: {' '.join(cmd)}")
            logger.info(f"Working directory: {os.getcwd()}")
            logger.info(f"Timeout: {timeout} seconds")
            print("-" * 60, file=sys.stderr)

        # Use Popen for real-time output streaming
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=os.getcwd(),
            bufsize=0,  # Unbuffered to prevent deadlock
            universal_newlines=True
        )

        # Set process reference with lock (so _signal_handler can see it)
        with self._lock:
            self.current_process = process

        # Make pipes non-blocking to prevent deadlock
        self._make_non_blocking(process.stdout)
        self._make_non_blocking(process.stderr)

        # Collect output while streaming
        stdout_lines = []
        stderr_lines = []

        start_time = time.time()
        last_output_time = start_time

        # Polling loop: exits by shutdown, timeout, or the child finishing.
        while True:
            # Check for shutdown signal first with lock
            with self._lock:
                shutdown = self.shutdown_requested

            if shutdown:
                if verbose:
                    print("Shutdown requested, terminating q chat process...", file=sys.stderr)
                process.terminate()
                try:
                    process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait(timeout=2)

                # Clean up process reference with lock
                with self._lock:
                    self.current_process = None

                return ToolResponse(
                    success=False,
                    output="".join(stdout_lines),
                    error="Process terminated due to shutdown signal"
                )

            # Check for timeout
            elapsed_time = time.time() - start_time

            # Log progress every 30 seconds
            if int(elapsed_time) % 30 == 0 and int(elapsed_time) > 0:
                logger.debug(f"Q chat still running... elapsed: {elapsed_time:.1f}s / {timeout}s")

                # Check if the process seems stuck (no output for a while)
                time_since_output = time.time() - last_output_time
                if time_since_output > 60:
                    logger.info(f"No output received for {time_since_output:.1f}s, Q might be stuck")

                if verbose:
                    print(f"Q chat still running... elapsed: {elapsed_time:.1f}s / {timeout}s", file=sys.stderr)

            if elapsed_time > timeout:
                logger.warning(f"Command timed out after {elapsed_time:.2f} seconds")
                if verbose:
                    print(f"Command timed out after {elapsed_time:.2f} seconds", file=sys.stderr)

                # Try to terminate gracefully first
                process.terminate()
                try:
                    # Wait a bit for graceful termination
                    process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    logger.warning("Graceful termination failed, force killing process")
                    if verbose:
                        print("Graceful termination failed, force killing process", file=sys.stderr)
                    process.kill()
                    # Wait for force kill to complete
                    try:
                        process.wait(timeout=2)
                    except subprocess.TimeoutExpired:
                        logger.warning("Process may still be running after kill")
                        if verbose:
                            print("Warning: Process may still be running after kill", file=sys.stderr)

                # Try to capture any remaining output after termination
                try:
                    remaining_stdout = process.stdout.read()
                    remaining_stderr = process.stderr.read()
                    if remaining_stdout:
                        stdout_lines.append(remaining_stdout)
                    if remaining_stderr:
                        stderr_lines.append(remaining_stderr)
                except Exception as e:
                    logger.warning(f"Could not read remaining output after timeout: {e}")
                    if verbose:
                        print(f"Warning: Could not read remaining output after timeout: {e}", file=sys.stderr)

                # Clean up process reference with lock
                with self._lock:
                    self.current_process = None

                return ToolResponse(
                    success=False,
                    output="".join(stdout_lines),
                    error=f"q chat command timed out after {elapsed_time:.2f} seconds"
                )

            # Check if process is still running
            if process.poll() is not None:
                # Process finished, read remaining output
                remaining_stdout = process.stdout.read()
                remaining_stderr = process.stderr.read()

                if remaining_stdout:
                    stdout_lines.append(remaining_stdout)
                    if verbose:
                        print(f"{remaining_stdout}", end='', file=sys.stderr)

                if remaining_stderr:
                    stderr_lines.append(remaining_stderr)
                    if verbose:
                        print(f"{remaining_stderr}", end='', file=sys.stderr)

                break

            # Read available data without blocking
            try:
                # Read stdout
                stdout_data = self._read_available(process.stdout)
                if stdout_data:
                    stdout_lines.append(stdout_data)
                    last_output_time = time.time()
                    if verbose:
                        print(stdout_data, end='', file=sys.stderr)

                # Read stderr
                stderr_data = self._read_available(process.stderr)
                if stderr_data:
                    stderr_lines.append(stderr_data)
                    last_output_time = time.time()
                    if verbose:
                        print(stderr_data, end='', file=sys.stderr)

                # Small sleep to prevent busy waiting
                time.sleep(0.01)

            except BlockingIOError:
                # Non-blocking read returns BlockingIOError when no data available
                pass

        # Get final return code
        returncode = process.poll()

        execution_time = time.time() - start_time
        logger.info(f"Process completed - Return code: {returncode}, Execution time: {execution_time:.2f}s")

        if verbose:
            print("-" * 60, file=sys.stderr)
            print(f"Process completed with return code: {returncode}", file=sys.stderr)
            print(f"Total execution time: {execution_time:.2f} seconds", file=sys.stderr)

        # Clean up process reference with lock
        with self._lock:
            self.current_process = None

        # Combine output
        full_stdout = "".join(stdout_lines)
        full_stderr = "".join(stderr_lines)

        if returncode == 0:
            logger.debug(f"Q chat succeeded - Output length: {len(full_stdout)} chars")
            return ToolResponse(
                success=True,
                output=full_stdout,
                metadata={
                    "tool": "q chat",
                    "execution_time": execution_time,
                    "verbose": verbose,
                    "return_code": returncode
                }
            )
        else:
            logger.warning(f"Q chat failed - Return code: {returncode}, Error: {full_stderr[:200]}")
            return ToolResponse(
                success=False,
                output=full_stdout,
                error=full_stderr or f"q chat command failed with code {returncode}"
            )

    except Exception as e:
        logger.exception(f"Exception during Q chat execution: {str(e)}")
        if verbose:
            print(f"Exception occurred: {str(e)}", file=sys.stderr)
        # Clean up process reference on exception (matches async version's finally block)
        with self._lock:
            self.current_process = None
        return ToolResponse(
            success=False,
            output="",
            error=str(e)
        )
|
||||
|
||||
def _make_non_blocking(self, pipe):
    """Switch *pipe* into non-blocking mode so reads never stall the loop.

    Silently does nothing when there is no pipe, when fcntl is unavailable
    (non-POSIX platforms), or when the pipe has no usable file descriptor
    (e.g. test doubles).
    """
    if not pipe or fcntl is None:
        return

    try:
        fd = pipe.fileno()
        # Only touch real, valid descriptors.
        if not isinstance(fd, int) or fd < 0:
            return
        current_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
        fcntl.fcntl(fd, fcntl.F_SETFL, current_flags | os.O_NONBLOCK)
    except (AttributeError, ValueError, OSError):
        # Mock pipes without fileno(), or descriptor-level failures:
        # best-effort, just skip.
        pass
|
||||
|
||||
|
||||
def _read_available(self, pipe):
    """Return whatever is currently readable from *pipe*, or "".

    Reads at most 4KB per call; a missing pipe, a would-block condition,
    or an empty/None read all yield the empty string so callers can
    concatenate the result unconditionally.
    """
    if not pipe:
        return ""

    try:
        chunk = pipe.read(4096)  # up to 4KB per poll
    except (IOError, OSError):
        # No data available right now on the non-blocking pipe.
        return ""

    # Normalize None and other falsy results to "".
    return chunk or ""
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
    """Native async execution using asyncio subprocess.

    Async counterpart of execute(): spawns `q chat` via
    asyncio.create_subprocess_exec, waits for completion with a timeout,
    and returns a ToolResponse. Unlike the sync path, output is collected
    in one shot via communicate() rather than streamed incrementally.

    Args:
        prompt: Task text to hand to q chat.
        **kwargs: Optional overrides:
            verbose (bool): mirror subprocess output to stderr (default True).
            prompt_file (str): file q chat is told to edit
                (default self.default_prompt_file).
            timeout (float): maximum run time in seconds
                (default self.default_timeout).

    Returns:
        ToolResponse: success flag, decoded stdout, and error/metadata.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error="q CLI is not available"
        )

    try:
        verbose = kwargs.get('verbose', True)
        prompt_file = kwargs.get('prompt_file', self.default_prompt_file)
        timeout = kwargs.get('timeout', self.default_timeout)

        logger.info(f"Executing Q chat async - Prompt file: {prompt_file}, Timeout: {timeout}s")

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Construct effective prompt
        effective_prompt = (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

        # Build command (flags are unconditional here, unlike the sync path
        # which honors self.no_interactive / self.trust_all_tools)
        cmd = [
            self.command,
            "chat",
            "--no-interactive",
            "--trust-all-tools",
            effective_prompt
        ]

        logger.debug(f"Starting async Q chat command: {' '.join(cmd)}")
        if verbose:
            print("Starting q chat command (async)...", file=sys.stderr)
            print(f"Command: {' '.join(cmd)}", file=sys.stderr)
            print("-" * 60, file=sys.stderr)

        # Create async subprocess
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=os.getcwd()
        )

        # Set process reference with lock (shared with the signal handler)
        with self._lock:
            self.current_process = process

        try:
            # Wait for completion with timeout
            stdout_data, stderr_data = await asyncio.wait_for(
                process.communicate(),
                timeout=timeout
            )

            # Decode output (communicate() yields bytes)
            stdout = stdout_data.decode('utf-8') if stdout_data else ""
            stderr = stderr_data.decode('utf-8') if stderr_data else ""

            if verbose and stdout:
                print(stdout, file=sys.stderr)
            if verbose and stderr:
                print(stderr, file=sys.stderr)

            # Check return code
            if process.returncode == 0:
                logger.debug(f"Async Q chat succeeded - Output length: {len(stdout)} chars")
                return ToolResponse(
                    success=True,
                    output=stdout,
                    metadata={
                        "tool": "q chat",
                        "verbose": verbose,
                        "async": True,
                        "return_code": process.returncode
                    }
                )
            else:
                logger.warning(f"Async Q chat failed - Return code: {process.returncode}")
                return ToolResponse(
                    success=False,
                    output=stdout,
                    error=stderr or f"q chat failed with code {process.returncode}"
                )

        except asyncio.TimeoutError:
            # Timeout occurred
            logger.warning(f"Async q chat timed out after {timeout} seconds")
            if verbose:
                print(f"Async q chat timed out after {timeout} seconds", file=sys.stderr)

            # Try to terminate process: graceful terminate first, then kill.
            try:
                process.terminate()
                await asyncio.wait_for(process.wait(), timeout=3)
            except (asyncio.TimeoutError, ProcessLookupError):
                try:
                    process.kill()
                    await process.wait()
                except ProcessLookupError:
                    # Process already gone.
                    pass

            return ToolResponse(
                success=False,
                output="",
                error=f"q chat command timed out after {timeout} seconds"
            )

        finally:
            # Clean up process reference
            with self._lock:
                self.current_process = None

    except Exception as e:
        logger.exception(f"Async execution error: {str(e)}")
        if kwargs.get('verbose'):
            print(f"Async execution error: {str(e)}", file=sys.stderr)
        return ToolResponse(
            success=False,
            output="",
            error=str(e)
        )
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
    """Estimate the dollar cost of running *prompt* through q chat.

    No per-request pricing model is wired in for q chat, so the estimate
    is always zero; revisit once actual pricing is known.
    """
    return 0.0
|
||||
|
||||
def __del__(self):
    """Cleanup on deletion.

    Best-effort finalizer: restores original signal handlers and tries to
    terminate any still-running subprocess. All attribute access is
    defensive because __del__ may run on a partially-initialized instance
    or during interpreter shutdown.
    """
    # Restore original signal handlers
    self._restore_signal_handlers()

    # Ensure any running process is terminated
    if hasattr(self, '_lock'):
        with self._lock:
            process = self.current_process if hasattr(self, 'current_process') else None
    else:
        # Instance never finished __init__; fall back to an unlocked read.
        process = getattr(self, 'current_process', None)

    if process:
        try:
            if hasattr(process, 'poll'):
                # Sync process
                if process.poll() is None:
                    process.terminate()
                    process.wait(timeout=1)
            else:
                # Async process - can't do much in __del__
                pass
        except Exception as e:
            # Best-effort cleanup during interpreter shutdown
            # Log at debug level since __del__ is unreliable
            logger.debug(f"Cleanup warning in __del__: {type(e).__name__}: {e}")
|
||||
@@ -0,0 +1,472 @@
|
||||
# ABOUTME: Advanced async logging with rotation, thread safety, and security features
|
||||
# ABOUTME: Provides dual interface (async + sync) with unicode sanitization
|
||||
|
||||
"""
|
||||
Advanced async logging with rotation for Ralph Orchestrator.
|
||||
|
||||
Features:
|
||||
- Automatic log rotation at 10MB with 3 backups
|
||||
- Thread-safe rotation with threading.Lock
|
||||
- Unicode sanitization for encoding errors
|
||||
- Security-aware logging (masks sensitive data)
|
||||
- Dual interface: async methods + sync wrappers
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ralph_orchestrator.security import SecurityValidator
|
||||
|
||||
|
||||
def async_method_warning(func):
    """Decorator to warn when async methods are called without await.

    Wraps the coroutine returned by *func* in a proxy that emits a
    RuntimeWarning from its finalizer if it is garbage-collected without
    ever being awaited — the classic "log call silently did nothing" bug.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        coro = func(self, *args, **kwargs)

        class WarningCoroutine:
            """Wrapper that warns when garbage collected without being awaited."""

            def __init__(self, coro, method_name):
                self._coro = coro
                self._method_name = method_name
                self._warned = False       # avoid duplicate warnings from repeated __del__
                self._awaited = False      # set once the caller actually awaits us

            def __await__(self):
                self._awaited = True
                return self._coro.__await__()

            def __del__(self):
                if not self._warned and not self._awaited:
                    warnings.warn(
                        # BUG FIX: the middle segment was previously a plain
                        # (non-f) string, so "{self._method_name}" appeared
                        # literally in the warning text instead of the name.
                        f"AsyncFileLogger.{self._method_name}() was called without await. "
                        f"The message was not logged. Use 'await logger.{self._method_name}(...)' "
                        f"or 'logger.{self._method_name}_sync(...)' instead.",
                        RuntimeWarning,
                        stacklevel=3,
                    )
                    self._warned = True

            def close(self):
                """Support close() method for compatibility."""
                pass

        return WarningCoroutine(coro, func.__name__)

    return wrapper
|
||||
|
||||
|
||||
class AsyncFileLogger:
    """Async file logger with timestamps, rotation, and security features.

    Offers both awaitable methods (log_info, log_error, ...) and
    thread-safe synchronous wrappers (log_info_sync, info, error, ...).
    Messages are unicode-sanitized and passed through
    SecurityValidator.mask_sensitive_data before being written.
    """

    # Log rotation constants
    MAX_LOG_SIZE_BYTES = 10 * 1024 * 1024  # rotate once the file exceeds 10MB
    MAX_BACKUP_FILES = 3                   # keep at most .log.1 .. .log.3

    # Default values for log parsing
    DEFAULT_RECENT_LINES_COUNT = 3

    def __init__(self, log_file: str, verbose: bool = False) -> None:
        """
        Initialize async logger.

        Args:
            log_file: Path to log file
            verbose: If True, also print to console

        Raises:
            ValueError: If log_file is None or empty
        """
        if not log_file:
            raise ValueError("log_file cannot be None or empty")

        # Convert to string for validation if it's a Path object
        log_file_str = str(log_file)
        if not log_file_str or not log_file_str.strip():
            raise ValueError("log_file cannot be empty")

        self.log_file = Path(log_file)
        self.verbose = verbose
        self._lock = asyncio.Lock()  # For async methods (single event loop)
        self._thread_lock = threading.Lock()  # For sync methods (multi-threaded)
        self._rotation_lock = threading.Lock()  # Thread safety for file rotation

        # Emergency shutdown flag for graceful signal handling
        self._emergency_shutdown = False
        # Threading event for immediate signal-safe notification
        self._emergency_event = threading.Event()
        # Track logging failures when both file and stderr fail
        self._logging_failures_count = 0

        # Ensure log directory exists
        self.log_file.parent.mkdir(parents=True, exist_ok=True)

        # Rotate log if needed on startup (single-threaded, safe)
        self._rotate_if_needed()

    def emergency_shutdown(self) -> None:
        """
        Signal emergency shutdown to make logging operations non-blocking.

        This method is signal-safe and can be called from signal handlers.
        After calling this, all logging operations will be skipped to allow
        rapid shutdown without blocking on file I/O.
        """
        self._emergency_shutdown = True
        self._emergency_event.set()

    def is_shutdown(self) -> bool:
        """Check if emergency shutdown has been triggered."""
        return self._emergency_shutdown

    async def log(self, level: str, message: str) -> None:
        """
        Log a message with timestamp.

        Args:
            level: Log level (INFO, SUCCESS, ERROR, WARNING)
            message: Message to log
        """
        # Skip logging during emergency shutdown
        if self._emergency_shutdown:
            return

        # Sanitize the message to handle problematic unicode
        sanitized_message = self._sanitize_unicode(message)

        # Mask sensitive data to prevent security vulnerabilities
        secure_message = SecurityValidator.mask_sensitive_data(sanitized_message)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_line = f"{timestamp} [{level}] {secure_message}\n"

        async with self._lock:
            # Double-check shutdown flag after acquiring lock
            if self._emergency_shutdown:
                return

            # Write to file (blocking I/O off the event loop via to_thread)
            await asyncio.to_thread(self._write_to_file, log_line)

            # Print to console if verbose
            if self.verbose:
                print(log_line.rstrip())

    def _sanitize_unicode(self, message: str) -> str:
        """
        Sanitize unicode message to prevent encoding errors.

        Args:
            message: Original message

        Returns:
            Sanitized message safe for UTF-8 encoding
        """
        try:
            # Test if the message can be encoded as UTF-8
            message.encode("utf-8")
            return message
        except UnicodeEncodeError:
            # If encoding fails, replace problematic characters
            try:
                # Try to encode with errors='replace' first
                return message.encode("utf-8", errors="replace").decode("utf-8")
            except Exception:
                # If that still fails, use more aggressive sanitization:
                # keep each encodable character, replace the rest with "[?]"
                sanitized = []
                for char in message:
                    try:
                        char.encode("utf-8")
                        sanitized.append(char)
                    except UnicodeEncodeError:
                        # Replace problematic character with a placeholder
                        sanitized.append("[?]")
                return "".join(sanitized)
        except Exception:
            # For any other unexpected errors, return a safe fallback
            return "[Unicode encoding error]"

    def _write_to_file(self, line: str) -> None:
        """Synchronous file write (called via to_thread)."""
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(line)

        # Check if rotation is needed (thread-safe)
        self._rotate_if_needed_thread_safe()

    def _rotate_if_needed_thread_safe(self) -> None:
        """Thread-safe version of _rotate_if_needed() to prevent race conditions."""
        with self._rotation_lock:
            self._rotate_if_needed()

    def _rotate_if_needed(self) -> None:
        """
        Rotate log file if it exceeds max size.

        Note: This method should only be called:
        - During __init__ (single-threaded, safe before logger is shared)
        - From within _rotate_if_needed_thread_safe() when rotation lock is held

        NOTE(review): backup names are built with with_suffix(".log.N"),
        which REPLACES the file's existing suffix — this assumes the log
        file is named "*.log"; a different suffix would be swapped out
        rather than appended. Confirm against callers.
        """
        if not self.log_file.exists():
            return

        # Double-check file size with lock held
        try:
            file_size = self.log_file.stat().st_size
        except (OSError, IOError):
            # File might have been moved or deleted by another thread
            return

        if file_size > self.MAX_LOG_SIZE_BYTES:
            # Create a temporary file to ensure atomic rotation
            temp_backup = self.log_file.with_suffix(".log.tmp")

            try:
                # Atomically move current log to temporary backup
                shutil.move(str(self.log_file), str(temp_backup))

                # Rotate backups in reverse order (.log.2 -> .log.3, .log.1 -> .log.2)
                for i in range(self.MAX_BACKUP_FILES - 1, 0, -1):
                    old_backup = self.log_file.with_suffix(f".log.{i}")
                    new_backup = self.log_file.with_suffix(f".log.{i + 1}")
                    if old_backup.exists():
                        if new_backup.exists():
                            new_backup.unlink()
                        shutil.move(str(old_backup), str(new_backup))

                # Move temporary backup to .1
                backup = self.log_file.with_suffix(".log.1")
                if backup.exists():
                    backup.unlink()
                shutil.move(str(temp_backup), str(backup))

                # Clean up any backups beyond MAX_BACKUP_FILES
                i = self.MAX_BACKUP_FILES + 1
                while True:
                    old_backup = self.log_file.with_suffix(f".log.{i}")
                    if old_backup.exists():
                        old_backup.unlink()
                        i += 1
                    else:
                        break

            except (OSError, IOError):
                # If rotation fails, try to restore from temporary backup
                if temp_backup.exists() and not self.log_file.exists():
                    try:
                        shutil.move(str(temp_backup), str(self.log_file))
                    except (OSError, IOError):
                        # If we can't restore, at least remove the temp file
                        if temp_backup.exists():
                            temp_backup.unlink()

    async def log_info(self, message: str) -> None:
        """Log info message."""
        await self.log("INFO", message)

    async def log_success(self, message: str) -> None:
        """Log success message."""
        await self.log("SUCCESS", message)

    async def log_error(self, message: str) -> None:
        """Log error message."""
        await self.log("ERROR", message)

    async def log_warning(self, message: str) -> None:
        """Log warning message."""
        await self.log("WARNING", message)

    def __del__(self):
        """Best-effort destructor guard.

        Currently performs no cleanup: it only detects interpreter
        shutdown (sys.meta_path is None) and returns early so finalizer
        code cannot crash during teardown.
        """
        try:
            # Check if Python is shutting down
            import sys

            if sys.meta_path is None:
                # Python is shutting down, skip cleanup to avoid errors
                return
        except Exception:
            # Silently ignore any errors during destructor to avoid crashes
            pass

    # Synchronous wrapper methods for compatibility
    def _log_sync_direct(self, level: str, message: str) -> None:
        """
        Log a message synchronously using threading.Lock (thread-safe).

        This bypasses asyncio entirely for true multi-threaded safety.
        Includes defensive error handling with stderr fallback.
        """
        if self._emergency_shutdown:
            return

        # Sanitize and secure the message
        sanitized_message = self._sanitize_unicode(message)
        secure_message = SecurityValidator.mask_sensitive_data(sanitized_message)

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_line = f"{timestamp} [{level}] {secure_message}\n"

        try:
            with self._thread_lock:
                # Re-check after acquiring the lock in case shutdown raced us
                if self._emergency_shutdown:
                    return
                self._write_to_file(log_line)
                if self.verbose:
                    print(log_line.rstrip())
        except (PermissionError, OSError, IOError, FileNotFoundError) as e:
            # Fallback to stderr when file I/O fails
            # Truncate message to 200 chars for stderr output
            truncated_message = secure_message[:200]
            try:
                print(
                    f"[LOGGING ERROR] {type(e).__name__}: {e}\n"
                    f"Original message (truncated): {truncated_message}",
                    file=sys.stderr
                )
            except Exception:
                # If stderr also fails, track the failure count for diagnostics
                # Cannot log or print - just increment counter
                self._logging_failures_count += 1

    def log_info_sync(self, message: str) -> None:
        """Log info message synchronously (thread-safe)."""
        self._log_sync_direct("INFO", message)

    def log_success_sync(self, message: str) -> None:
        """Log success message synchronously (thread-safe)."""
        self._log_sync_direct("SUCCESS", message)

    def log_error_sync(self, message: str) -> None:
        """Log error message synchronously (thread-safe)."""
        self._log_sync_direct("ERROR", message)

    def log_warning_sync(self, message: str) -> None:
        """Log warning message synchronously (thread-safe)."""
        self._log_sync_direct("WARNING", message)

    # Standard logging interface methods for compatibility
    def info(self, message: str) -> None:
        """Standard logging interface - log info message synchronously."""
        self.log_info_sync(message)

    def debug(self, message: str) -> None:
        """Standard logging interface - log debug message synchronously (maps to info)."""
        self.log_info_sync(message)

    def warning(self, message: str) -> None:
        """Standard logging interface - log warning message synchronously."""
        self.log_warning_sync(message)

    def error(self, message: str) -> None:
        """Standard logging interface - log error message synchronously."""
        self.log_error_sync(message)

    def critical(self, message: str) -> None:
        """Standard logging interface - log critical message synchronously (maps to error)."""
        self.log_error_sync(message)

    def get_stats(self) -> dict[str, int | str | None]:
        """
        Get statistics from log file.

        Returns:
            Dict with success_count, error_count, start_time
        """
        if not self.log_file.exists():
            return {"success_count": 0, "error_count": 0, "start_time": None}

        success_count = 0
        error_count = 0
        start_time = None

        with open(self.log_file, encoding="utf-8") as f:
            lines = f.readlines()

        if lines:
            # Extract start time from first line
            first_line = lines[0]
            start_time = first_line.split(" [")[0] if " [" in first_line else None

            # Count successes and errors by matching iteration log phrasing
            for line in lines:
                if "Iteration" in line and "completed successfully" in line:
                    success_count += 1
                elif "Iteration" in line and "failed" in line:
                    error_count += 1

        return {
            "success_count": success_count,
            "error_count": error_count,
            "start_time": start_time,
        }

    def get_recent_lines(self, count: Optional[int] = None) -> list[str]:
        """
        Get recent log lines.

        Args:
            count: Number of recent lines to return (default: DEFAULT_RECENT_LINES_COUNT)

        Returns:
            List of recent log lines
        """
        if count is None:
            count = self.DEFAULT_RECENT_LINES_COUNT

        if not self.log_file.exists():
            return []

        with open(self.log_file, encoding="utf-8") as f:
            lines = f.readlines()
        return [line.rstrip() for line in lines[-count:]]

    def count_pattern(self, pattern: str) -> int:
        """
        Count occurrences of pattern in log file.

        Args:
            pattern: Pattern to search for

        Returns:
            Number of occurrences
        """
        if not self.log_file.exists():
            return 0

        with open(self.log_file, encoding="utf-8") as f:
            content = f.read()

        return content.count(pattern)

    def get_start_time(self) -> Optional[str]:
        """
        Get start time from first log entry.

        Returns:
            Start time string or None if no logs
        """
        if not self.log_file.exists():
            return None

        with open(self.log_file, encoding="utf-8") as f:
            first_line = f.readline()

        if not first_line:
            return None

        # Extract timestamp from first line (format: "YYYY-MM-DD HH:MM:SS [LEVEL] message")
        parts = first_line.split(" ", 2)
        if len(parts) >= 2:
            return f"{parts[0]} {parts[1]}"

        return None
|
||||
225
.venv/lib/python3.11/site-packages/ralph_orchestrator/context.py
Normal file
225
.venv/lib/python3.11/site-packages/ralph_orchestrator/context.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# ABOUTME: Context management and optimization for Ralph Orchestrator
|
||||
# ABOUTME: Handles prompt caching, summarization, and context window management
|
||||
|
||||
"""Context management for Ralph Orchestrator."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('ralph-orchestrator.context')
|
||||
|
||||
|
||||
class ContextManager:
    """Manage prompt context and optimization.

    Keeps the effective prompt within ``max_context_size`` characters by
    caching the stable (header) portion of the prompt on disk and
    summarizing the remainder, while maintaining small rolling windows of
    dynamic context, recent error lines, and success patterns harvested
    from agent output.
    """

    def __init__(
        self,
        prompt_file: Path,
        max_context_size: int = 8000,
        cache_dir: Path = Path(".agent/cache"),
        prompt_text: Optional[str] = None
    ):
        """Initialize context manager.

        Args:
            prompt_file: Path to the main prompt file
            max_context_size: Maximum context size in characters
            cache_dir: Directory for caching context
            prompt_text: Direct prompt text (overrides prompt_file if provided)
        """
        self.prompt_file = prompt_file
        self.max_context_size = max_context_size
        self.cache_dir = cache_dir
        # Ensure the cache directory exists before any prefix caching happens.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.prompt_text = prompt_text  # Direct prompt text override

        # Context components
        self.stable_prefix: Optional[str] = None  # unchanging header extracted from the prompt
        self.dynamic_context: List[str] = []      # rolling summaries of recent agent output (kept to 5)
        self.error_history: List[str] = []        # recent error lines (kept to 5)
        self.success_patterns: List[str] = []     # recent success/completion lines (kept to 3)

        # Load initial prompt
        self._load_initial_prompt()

    def _load_initial_prompt(self):
        """Load and analyze the initial prompt.

        Reads the prompt (direct text or file) and extracts the leading run
        of markdown headers/blank lines as the stable prefix used for
        caching.  Read failures are logged and leave ``stable_prefix`` as
        None rather than raising.
        """
        # Use direct prompt text if provided
        if self.prompt_text:
            logger.info("Using direct prompt_text input")
            content = self.prompt_text
        elif self.prompt_file.exists():
            try:
                content = self.prompt_file.read_text()
            except UnicodeDecodeError as e:
                logger.warning(f"Encoding error reading {self.prompt_file}: {e}")
                return
            except PermissionError as e:
                logger.warning(f"Permission denied reading {self.prompt_file}: {e}")
                return
            except OSError as e:
                logger.warning(f"OS error reading {self.prompt_file}: {e}")
                return
        else:
            logger.info(f"Prompt file {self.prompt_file} not found")
            return

        # Extract stable prefix (instructions that don't change)
        lines = content.split('\n')
        stable_lines = []

        # Collect markdown headers plus the blank lines between them; the
        # first other line after collection has started ends the prefix.
        for line in lines:
            if line.startswith('#') or line.startswith('##'):
                stable_lines.append(line)
            # No longer breaking on completion markers
            elif len(stable_lines) > 0 and line.strip() == '':
                stable_lines.append(line)
            elif len(stable_lines) > 0:
                break

        self.stable_prefix = '\n'.join(stable_lines)
        logger.info(f"Extracted stable prefix: {len(self.stable_prefix)} chars")

    def get_prompt(self) -> str:
        """Get the current prompt with optimizations.

        Returns:
            The prompt text, possibly augmented with recent dynamic context
            and error history when it fits, or an optimized/summarized
            version when it exceeds ``max_context_size``.  Returns "" when
            no prompt source is readable.
        """
        # Use direct prompt text if provided
        if self.prompt_text:
            base_content = self.prompt_text
        elif self.prompt_file.exists():
            try:
                base_content = self.prompt_file.read_text()
            except UnicodeDecodeError as e:
                logger.warning(f"Encoding error reading {self.prompt_file}: {e}")
                return ""
            except PermissionError as e:
                logger.warning(f"Permission denied reading {self.prompt_file}: {e}")
                return ""
            except OSError as e:
                logger.warning(f"OS error reading {self.prompt_file}: {e}")
                return ""
        else:
            logger.warning(f"No prompt available: prompt_text={self.prompt_text is not None}, prompt_file={self.prompt_file}")
            return ""

        # Check if we need to optimize
        if len(base_content) > self.max_context_size:
            return self._optimize_prompt(base_content)

        # Add dynamic context if there's room
        if self.dynamic_context:
            context_addition = "\n\n## Previous Context\n" + "\n".join(self.dynamic_context[-3:])
            if len(base_content) + len(context_addition) < self.max_context_size:
                base_content += context_addition

        # Add error history if relevant
        if self.error_history:
            error_addition = "\n\n## Recent Errors to Avoid\n" + "\n".join(self.error_history[-2:])
            if len(base_content) + len(error_addition) < self.max_context_size:
                base_content += error_addition

        return base_content

    def _optimize_prompt(self, content: str) -> str:
        """Optimize a prompt that's too large.

        Strategy 1 writes the stable prefix to the cache directory and
        replaces it with a short reference marker; strategy 2 (fallback)
        summarizes the whole content.
        """
        logger.info("Optimizing large prompt")

        # Strategy 1: Use stable prefix caching
        if self.stable_prefix:
            # Cache the stable prefix under a content-hash filename so
            # identical prefixes share one cache file.
            prefix_hash = hashlib.sha256(self.stable_prefix.encode()).hexdigest()[:8]
            cache_file = self.cache_dir / f"prefix_{prefix_hash}.txt"

            if not cache_file.exists():
                cache_file.write_text(self.stable_prefix)

            # Reference the cached prefix instead of including it
            optimized = f"<!-- Using cached prefix {prefix_hash} -->\n"

            # Add the dynamic part
            # NOTE(review): assumes `content` still starts with `stable_prefix`;
            # if the file header changed since load, this slice cuts at the
            # wrong offset — confirm.
            dynamic_part = content[len(self.stable_prefix):]

            # Truncate if still too large
            if len(dynamic_part) > self.max_context_size - 100:
                dynamic_part = self._summarize_content(dynamic_part)

            optimized += dynamic_part
            return optimized

        # Strategy 2: Summarize the content
        return self._summarize_content(content)

    def _summarize_content(self, content: str) -> str:
        """Summarize content to fit within limits.

        Keeps only headers, lines mentioning IMPORTANT/ERROR, and unchecked
        task lines, then hard-truncates if still over the limit.
        """
        lines = content.split('\n')

        # Keep headers and key instructions
        important_lines = []
        for line in lines:
            if any([
                line.startswith('#'),
                # 'TODO' in line,
                'IMPORTANT' in line,
                'ERROR' in line,
                line.startswith('- [ ]'),  # Unchecked tasks
            ]):
                important_lines.append(line)

        summary = '\n'.join(important_lines)

        # If still too long, truncate
        if len(summary) > self.max_context_size:
            summary = summary[:self.max_context_size - 100] + "\n<!-- Content truncated -->"

        return summary

    def update_context(self, output: str):
        """Update dynamic context based on agent output.

        Harvests error lines and success/completion lines into their
        rolling histories, then appends a (possibly head+tail summarized)
        copy of the output to the dynamic context window.
        """
        # Extract key information from output
        if "error" in output.lower():
            # Track errors for learning
            error_lines = [line for line in output.split('\n') if 'error' in line.lower()]
            self.error_history.extend(error_lines[:2])

            # Keep only recent errors
            self.error_history = self.error_history[-5:]

        if "success" in output.lower() or "complete" in output.lower():
            # Track successful patterns
            success_lines = [line for line in output.split('\n')
                             if any(word in line.lower() for word in ['success', 'complete', 'done'])]
            self.success_patterns.extend(success_lines[:1])
            self.success_patterns = self.success_patterns[-3:]

        # Add to dynamic context (summarized: first 200 + last 200 chars)
        if len(output) > 500:
            summary = output[:200] + "..." + output[-200:]
            self.dynamic_context.append(summary)
        else:
            self.dynamic_context.append(output)

        # Keep dynamic context limited
        self.dynamic_context = self.dynamic_context[-5:]

    def add_error_feedback(self, error: str):
        """Add error feedback to context (kept to the 5 most recent)."""
        self.error_history.append(f"Error: {error}")
        self.error_history = self.error_history[-5:]

    def reset(self):
        """Reset dynamic context, error history, and success patterns."""
        self.dynamic_context = []
        self.error_history = []
        self.success_patterns = []
        logger.info("Context reset")

    def get_stats(self) -> Dict:
        """Get context statistics.

        Returns:
            Dict with prefix size, rolling-window item counts, and the
            number of cached prefix files on disk.
        """
        return {
            "stable_prefix_size": len(self.stable_prefix) if self.stable_prefix else 0,
            "dynamic_context_items": len(self.dynamic_context),
            "error_history_items": len(self.error_history),
            "success_patterns": len(self.success_patterns),
            "cache_files": len(list(self.cache_dir.glob("*.txt")))
        }
@@ -0,0 +1,254 @@
|
||||
# ABOUTME: Error formatter for Ralph Orchestrator
|
||||
# ABOUTME: Provides structured error messages with user-friendly suggestions
|
||||
|
||||
"""
|
||||
Error formatter for Ralph Orchestrator.
|
||||
|
||||
This module provides structured error messages with user-friendly suggestions
|
||||
and security-aware error sanitization for Claude SDK and adapter errors.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
class ErrorMessage:
    """A structured error: user-facing message plus a remediation hint.

    Attributes:
        message: The main error message shown to the user.
        suggestion: A helpful hint for resolving the error.
    """
    message: str
    suggestion: str

    def __str__(self) -> str:
        """Render as ``<message> | <suggestion>``."""
        return " | ".join((self.message, self.suggestion))
|
||||
|
||||
class ClaudeErrorFormatter:
    """Formats error messages with user-friendly suggestions.

    This class provides static methods to format various error types
    encountered during Claude SDK operations into structured messages
    with helpful suggestions for resolution.

    All methods use security-aware sanitization to prevent information
    disclosure of sensitive data in error messages.
    """

    @staticmethod
    def format_timeout_error(iteration: int, timeout: int) -> ErrorMessage:
        """Format timeout error message.

        Args:
            iteration: Current iteration number
            timeout: Timeout limit in seconds

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} exceeded timeout limit of {timeout}s",
            suggestion="Try: Increase iteration_timeout in config or simplify your prompt"
        )

    @staticmethod
    def format_process_terminated_error(iteration: int) -> ErrorMessage:
        """Format process termination error message.

        Args:
            iteration: Current iteration number

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Claude subprocess terminated unexpectedly",
            suggestion="Try: Check if Claude Code CLI is properly installed and has correct permissions"
        )

    @staticmethod
    def format_interrupted_error(iteration: int) -> ErrorMessage:
        """Format interrupted error message (SIGTERM received).

        Args:
            iteration: Current iteration number

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} was interrupted (SIGTERM)",
            suggestion="This usually happens when stopping Ralph - no action needed"
        )

    @staticmethod
    def format_command_failed_error(iteration: int) -> ErrorMessage:
        """Format command failed error message (exit code 1).

        Args:
            iteration: Current iteration number

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Claude CLI command failed",
            suggestion="Try: Check Claude CLI installation with 'claude --version' or verify API key with 'claude login'"
        )

    @staticmethod
    def format_connection_error(iteration: int) -> ErrorMessage:
        """Format connection error message.

        Args:
            iteration: Current iteration number

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Cannot connect to Claude CLI",
            suggestion="Try: Verify Claude Code CLI is installed: claude --version"
        )

    @staticmethod
    def format_rate_limit_error(iteration: int, retry_after: int = 60) -> ErrorMessage:
        """Format rate limit error message.

        Args:
            iteration: Current iteration number
            retry_after: Seconds until retry is recommended

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Rate limit exceeded",
            suggestion=f"Try: Wait {retry_after}s before retrying or reduce request frequency"
        )

    @staticmethod
    def format_authentication_error(iteration: int) -> ErrorMessage:
        """Format authentication error message.

        Args:
            iteration: Current iteration number

        Returns:
            Formatted error message with suggestion
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Authentication error",
            suggestion="Try: Check your API credentials or re-authenticate with 'claude login'"
        )

    @staticmethod
    def format_permission_error(iteration: int, path: str = "") -> ErrorMessage:
        """Format permission error message.

        Args:
            iteration: Current iteration number
            path: Optional path that caused the permission error

        Returns:
            Formatted error message with suggestion
        """
        # Sanitize path to avoid information disclosure: very long paths
        # (>= 100 chars) are dropped entirely rather than truncated.
        safe_path = path if path and len(path) < 100 else ""
        if safe_path:
            return ErrorMessage(
                message=f"Iteration {iteration} failed: Permission denied for '{safe_path}'",
                suggestion="Try: Check file permissions or run with appropriate privileges"
            )
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Permission denied",
            suggestion="Try: Check file/directory permissions or run with appropriate privileges"
        )

    @staticmethod
    def format_generic_error(iteration: int, error_type: str, error_str: str) -> ErrorMessage:
        """Format generic error message with security sanitization.

        Args:
            iteration: Current iteration number
            error_type: Type of the exception (e.g., 'ValueError')
            error_str: String representation of the error

        Returns:
            Formatted error message with sanitized content
        """
        # Import here to avoid circular imports
        from .security import SecurityValidator

        # Sanitize error string to prevent information disclosure
        sanitized_error_str = SecurityValidator.mask_sensitive_data(error_str)

        # Truncate very long error messages (200-char cap including "...")
        if len(sanitized_error_str) > 200:
            sanitized_error_str = sanitized_error_str[:197] + "..."

        return ErrorMessage(
            message=f"Iteration {iteration} failed: {error_type}: {sanitized_error_str}",
            suggestion="Check logs for details or try reducing prompt complexity"
        )

    @staticmethod
    def format_error_from_exception(iteration: int, exception: Exception) -> ErrorMessage:
        """Format error message from exception.

        Analyzes the exception type and message to provide the most
        appropriate error format with helpful suggestions.  Pattern checks
        run in order, so the first match wins.

        Args:
            iteration: Current iteration number
            exception: The exception that occurred

        Returns:
            Formatted error message with appropriate suggestion
        """
        error_type = type(exception).__name__
        error_str = str(exception)

        # Match error patterns and provide specific suggestions

        # Process transport issues (subprocess terminated)
        if "ProcessTransport is not ready" in error_str:
            return ClaudeErrorFormatter.format_process_terminated_error(iteration)

        # SIGTERM interruption (exit code 143 = 128 + 15)
        if "Command failed with exit code 143" in error_str:
            return ClaudeErrorFormatter.format_interrupted_error(iteration)

        # General command failure (exit code 1)
        if "Command failed with exit code 1" in error_str:
            return ClaudeErrorFormatter.format_command_failed_error(iteration)

        # Connection errors
        if error_type == "CLIConnectionError" or "connection" in error_str.lower():
            return ClaudeErrorFormatter.format_connection_error(iteration)

        # Timeout errors
        # NOTE(review): type(...).__name__ is never the dotted string
        # "asyncio.TimeoutError", so that comparison only matches via the
        # plain "TimeoutError" name; the timeout value reported here is 0
        # because the actual limit is unknown at this point.
        if error_type in ("TimeoutError", "asyncio.TimeoutError") or "timeout" in error_str.lower():
            return ClaudeErrorFormatter.format_timeout_error(iteration, 0)

        # Rate limit errors
        if "rate limit" in error_str.lower() or error_type == "RateLimitError":
            return ClaudeErrorFormatter.format_rate_limit_error(iteration)

        # Authentication errors
        if "authentication" in error_str.lower() or "auth" in error_str.lower() or error_type == "AuthenticationError":
            return ClaudeErrorFormatter.format_authentication_error(iteration)

        # Permission errors
        if error_type == "PermissionError" or "permission denied" in error_str.lower():
            return ClaudeErrorFormatter.format_permission_error(iteration)

        # Fall back to generic error format
        return ClaudeErrorFormatter.format_generic_error(iteration, error_type, error_str)
@@ -0,0 +1,218 @@
|
||||
# ABOUTME: Logging configuration module for Ralph Orchestrator
|
||||
# ABOUTME: Provides centralized logging setup with proper formatters and handlers
|
||||
|
||||
"""Logging configuration for Ralph Orchestrator."""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class RalphLogger:
    """Centralized logging configuration for Ralph Orchestrator.

    Class-level (not instance) configuration: ``initialize`` is idempotent
    and wires up the "ralph" logger hierarchy from explicit arguments
    and/or ``RALPH_LOG_*`` environment variables.
    """

    # Logger names for different components
    ORCHESTRATOR = "ralph.orchestrator"
    ADAPTER_BASE = "ralph.adapter"
    ADAPTER_QCHAT = "ralph.adapter.qchat"
    ADAPTER_KIRO = "ralph.adapter.kiro"
    ADAPTER_CLAUDE = "ralph.adapter.claude"
    ADAPTER_GEMINI = "ralph.adapter.gemini"
    SAFETY = "ralph.safety"
    METRICS = "ralph.metrics"

    # Default log format
    DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    DETAILED_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(funcName)s() - %(message)s"

    # Guard so initialize() configures handlers only once per process.
    _initialized = False
    # Resolved log directory when a directory (not an explicit file) is used.
    _log_dir: Optional[Path] = None

    @classmethod
    def initialize(cls,
                   log_level: Optional[str] = None,
                   log_file: Optional[str] = None,
                   log_dir: Optional[str] = None,
                   console_output: Optional[bool] = None,
                   detailed_format: bool = False) -> None:
        """Initialize logging configuration.

        Explicit arguments win over the corresponding ``RALPH_LOG_*``
        environment variables; a second call is a no-op.

        Args:
            log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
            log_file: Path to log file (optional)
            log_dir: Directory for log files (optional)
            console_output: Whether to output to console
            detailed_format: Use detailed format with file/line info
        """
        if cls._initialized:
            return

        # Get configuration from environment variables
        log_level = log_level or os.getenv("RALPH_LOG_LEVEL", "INFO")
        log_file = log_file or os.getenv("RALPH_LOG_FILE")
        log_dir = log_dir or os.getenv("RALPH_LOG_DIR", ".logs")

        # Handle console_output properly - only use env var if not explicitly set
        if console_output is None:
            console_output = os.getenv("RALPH_LOG_CONSOLE", "true").lower() == "true"

        detailed_format = detailed_format or \
            os.getenv("RALPH_LOG_DETAILED", "false").lower() == "true"

        # Convert log level string to logging constant (unknown names fall back to INFO)
        numeric_level = getattr(logging, log_level.upper(), logging.INFO)

        # Choose format
        log_format = cls.DETAILED_FORMAT if detailed_format else cls.DEFAULT_FORMAT

        # Create formatter
        formatter = logging.Formatter(log_format)

        # Configure root logger of the "ralph" hierarchy (not the global root)
        root_logger = logging.getLogger("ralph")
        root_logger.setLevel(numeric_level)
        root_logger.handlers = []  # Clear existing handlers

        # Add console handler
        if console_output:
            console_handler = logging.StreamHandler(sys.stderr)
            console_handler.setFormatter(formatter)
            console_handler.setLevel(numeric_level)
            root_logger.addHandler(console_handler)

        # Add file handler if specified
        if log_file or log_dir:
            cls._setup_file_handler(root_logger, formatter, log_file, log_dir, numeric_level)

        cls._initialized = True

        # Suppress verbose INFO logs from claude-agent-sdk internals
        # The SDK logs operational details at INFO level (e.g., "Using bundled Claude Code CLI")
        logging.getLogger('claude_agent_sdk').setLevel(logging.WARNING)

        # Log initialization
        logger = logging.getLogger(cls.ORCHESTRATOR)
        logger.debug(f"Logging initialized - Level: {log_level}, Console: {console_output}, "
                     f"File: {log_file or 'None'}, Dir: {log_dir or 'None'}")

    @classmethod
    def _setup_file_handler(cls,
                            logger: logging.Logger,
                            formatter: logging.Formatter,
                            log_file: Optional[str],
                            log_dir: Optional[str],
                            level: int) -> None:
        """Setup file handler for logging.

        Args:
            logger: Logger to add handler to
            formatter: Log formatter
            log_file: Specific log file path
            log_dir: Directory for log files
            level: Logging level
        """
        # Determine log file path: explicit file wins over directory default.
        if log_file:
            log_path = Path(log_file)
        else:
            # Use log directory with default filename
            cls._log_dir = Path(log_dir)
            cls._log_dir.mkdir(parents=True, exist_ok=True)
            log_path = cls._log_dir / "ralph_orchestrator.log"

        # Create parent directories if needed
        log_path.parent.mkdir(parents=True, exist_ok=True)

        # Use rotating file handler; rotation sizing is env-configurable.
        max_bytes = int(os.getenv("RALPH_LOG_MAX_BYTES", "10485760"))  # 10MB default
        backup_count = int(os.getenv("RALPH_LOG_BACKUP_COUNT", "5"))

        file_handler = logging.handlers.RotatingFileHandler(
            log_path,
            maxBytes=max_bytes,
            backupCount=backup_count
        )
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level)
        logger.addHandler(file_handler)

    @classmethod
    def get_logger(cls, name: str) -> logging.Logger:
        """Get a logger instance.

        Lazily runs ``initialize()`` with defaults on first use.

        Args:
            name: Logger name (use class constants for consistency)

        Returns:
            Configured logger instance
        """
        if not cls._initialized:
            cls.initialize()

        return logging.getLogger(name)

    @classmethod
    def log_config(cls) -> Dict[str, Any]:
        """Get current logging configuration.

        Returns:
            Dictionary with current logging settings (level, handlers with
            their types/levels/files, log directory, init state)
        """
        root_logger = logging.getLogger("ralph")

        config = {
            "level": logging.getLevelName(root_logger.level),
            "handlers": [],
            "log_dir": str(cls._log_dir) if cls._log_dir else None,
            "initialized": cls._initialized
        }

        for handler in root_logger.handlers:
            handler_info = {
                "type": handler.__class__.__name__,
                "level": logging.getLevelName(handler.level)
            }

            # File-based handlers expose their target path via baseFilename.
            if hasattr(handler, 'baseFilename'):
                handler_info["file"] = handler.baseFilename

            config["handlers"].append(handler_info)

        return config

    @classmethod
    def set_level(cls, level: str, logger_name: Optional[str] = None) -> None:
        """Dynamically set logging level.

        Args:
            level: New logging level (unknown names fall back to INFO)
            logger_name: Specific logger to update (None for the "ralph" root)
        """
        numeric_level = getattr(logging, level.upper(), logging.INFO)

        if logger_name:
            logger = logging.getLogger(logger_name)
        else:
            logger = logging.getLogger("ralph")

        logger.setLevel(numeric_level)

        # Update handlers so the new level applies end-to-end.
        for handler in logger.handlers:
            handler.setLevel(numeric_level)
||||
|
||||
# Convenience function for getting loggers
|
||||
def get_logger(name: str) -> logging.Logger:
    """Get a configured logger instance.

    Module-level convenience wrapper around ``RalphLogger.get_logger``,
    which lazily runs ``RalphLogger.initialize()`` on first use.

    Args:
        name: Logger name (prefer the ``RalphLogger`` name constants)

    Returns:
        Logger instance
    """
    return RalphLogger.get_logger(name)
626
.venv/lib/python3.11/site-packages/ralph_orchestrator/main.py
Executable file
626
.venv/lib/python3.11/site-packages/ralph_orchestrator/main.py
Executable file
@@ -0,0 +1,626 @@
|
||||
#!/usr/bin/env python3
|
||||
# ABOUTME: Ralph orchestrator main loop implementation with multi-agent support
|
||||
# ABOUTME: Implements the core Ralph Wiggum technique with continuous iteration
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import threading
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from .orchestrator import RalphOrchestrator
|
||||
|
||||
|
||||
# Configuration defaults
DEFAULT_MAX_ITERATIONS = 100
DEFAULT_MAX_RUNTIME = 14400  # 4 hours
DEFAULT_PROMPT_FILE = "PROMPT.md"
DEFAULT_CHECKPOINT_INTERVAL = 5
DEFAULT_RETRY_DELAY = 2
DEFAULT_MAX_TOKENS = 1000000  # 1M tokens total
DEFAULT_MAX_COST = 50.0  # $50 USD
DEFAULT_CONTEXT_WINDOW = 200000  # 200K token context window
DEFAULT_CONTEXT_THRESHOLD = 0.8  # Trigger summarization at 80% of context
DEFAULT_METRICS_INTERVAL = 10  # Log metrics every 10 iterations
DEFAULT_MAX_PROMPT_SIZE = 10485760  # 10MB max prompt file size
DEFAULT_COMPLETION_PROMISE = "LOOP_COMPLETE"

# Token costs per million (approximate, USD per 1M tokens)
TOKEN_COSTS = {
    "claude": {"input": 3.0, "output": 15.0},  # Claude 3.5 Sonnet
    "q": {"input": 0.5, "output": 1.5},  # Estimated
    "kiro": {"input": 0.5, "output": 1.5},  # Estimated
    "gemini": {"input": 0.5, "output": 1.5}  # Gemini Pro
}

# Setup logging
# NOTE: runs at import time and configures the root logger with a stream
# handler — an intentional module-level side effect of importing this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
logger = logging.getLogger('ralph-orchestrator')
|
||||
class AgentType(Enum):
    """Supported AI agent types.

    ``AUTO`` is not a concrete agent: it tells the orchestrator to select
    one via the configured ``agent_priority`` order.
    """
    CLAUDE = "claude"
    Q = "q"
    KIRO = "kiro"
    GEMINI = "gemini"
    ACP = "acp"
    AUTO = "auto"  # auto-select per agent_priority
||||
class ConfigValidator:
    """Validates Ralph configuration settings.

    Each ``validate_*`` helper checks one configuration parameter and
    returns a list of error strings (empty when the value is acceptable);
    each ``get_warning_*`` helper returns advisory warning strings for
    legal-but-suspicious values.
    """

    # Validation thresholds
    LARGE_DELAY_THRESHOLD_SECONDS = 3600  # 1 hour
    SHORT_TIMEOUT_THRESHOLD_SECONDS = 10  # Very short timeout
    TYPICAL_AI_ITERATION_MIN_SECONDS = 30  # Typical minimum time for AI iteration
    TYPICAL_AI_ITERATION_MAX_SECONDS = 300  # Typical maximum time for AI iteration

    # Reasonable limits to prevent resource exhaustion
    MAX_ITERATIONS_LIMIT = 100000
    MAX_RUNTIME_LIMIT = 604800  # 1 week in seconds
    MAX_TOKENS_LIMIT = 100000000  # 100M tokens
    MAX_COST_LIMIT = 10000.0  # $10K USD

    @staticmethod
    def validate_max_iterations(max_iterations: int) -> List[str]:
        """Validate max iterations parameter."""
        if max_iterations < 0:
            return ["Max iterations must be non-negative"]
        if max_iterations > ConfigValidator.MAX_ITERATIONS_LIMIT:
            return [f"Max iterations exceeds limit ({ConfigValidator.MAX_ITERATIONS_LIMIT})"]
        return []

    @staticmethod
    def validate_max_runtime(max_runtime: int) -> List[str]:
        """Validate max runtime parameter."""
        if max_runtime < 0:
            return ["Max runtime must be non-negative"]
        if max_runtime > ConfigValidator.MAX_RUNTIME_LIMIT:
            return [f"Max runtime exceeds limit ({ConfigValidator.MAX_RUNTIME_LIMIT}s)"]
        return []

    @staticmethod
    def validate_checkpoint_interval(checkpoint_interval: int) -> List[str]:
        """Validate checkpoint interval parameter."""
        if checkpoint_interval < 0:
            return ["Checkpoint interval must be non-negative"]
        return []

    @staticmethod
    def validate_retry_delay(retry_delay: int) -> List[str]:
        """Validate retry delay parameter."""
        if retry_delay < 0:
            return ["Retry delay must be non-negative"]
        if retry_delay > ConfigValidator.LARGE_DELAY_THRESHOLD_SECONDS:
            return [f"Retry delay exceeds limit ({ConfigValidator.LARGE_DELAY_THRESHOLD_SECONDS}s)"]
        return []

    @staticmethod
    def validate_max_tokens(max_tokens: int) -> List[str]:
        """Validate max tokens parameter."""
        if max_tokens < 0:
            return ["Max tokens must be non-negative"]
        if max_tokens > ConfigValidator.MAX_TOKENS_LIMIT:
            return [f"Max tokens exceeds limit ({ConfigValidator.MAX_TOKENS_LIMIT})"]
        return []

    @staticmethod
    def validate_max_cost(max_cost: float) -> List[str]:
        """Validate max cost parameter."""
        if max_cost < 0:
            return ["Max cost must be non-negative"]
        if max_cost > ConfigValidator.MAX_COST_LIMIT:
            return [f"Max cost exceeds limit (${ConfigValidator.MAX_COST_LIMIT})"]
        return []

    @staticmethod
    def validate_context_threshold(context_threshold: float) -> List[str]:
        """Validate context threshold parameter (must lie in [0.0, 1.0])."""
        if 0.0 <= context_threshold <= 1.0:
            return []
        return ["Context threshold must be between 0.0 and 1.0"]

    @staticmethod
    def validate_prompt_file(prompt_file: str) -> List[str]:
        """Validate prompt file exists and is readable."""
        path = Path(prompt_file)
        if not path.exists():
            return [f"Prompt file not found: {prompt_file}"]
        if not path.is_file():
            return [f"Prompt file is not a regular file: {prompt_file}"]
        return []

    @staticmethod
    def get_warning_large_delay(retry_delay: int) -> List[str]:
        """Check for unusually large delay values."""
        if retry_delay <= ConfigValidator.LARGE_DELAY_THRESHOLD_SECONDS:
            return []
        return [
            f"Warning: Retry delay is very large ({retry_delay}s = {retry_delay/60:.1f}m). "
            f"Did you mean to use minutes instead of seconds?"
        ]

    @staticmethod
    def get_warning_single_iteration(max_iterations: int) -> List[str]:
        """Check for max_iterations=1."""
        if max_iterations != 1:
            return []
        return [
            "Warning: max_iterations is 1. "
            "Ralph is designed for continuous loops. Did you mean 0 (infinite)?"
        ]

    @staticmethod
    def get_warning_short_timeout(max_runtime: int) -> List[str]:
        """Check for very short runtime limits (0 means unlimited, so skipped)."""
        if not (0 < max_runtime < ConfigValidator.SHORT_TIMEOUT_THRESHOLD_SECONDS):
            return []
        return [
            f"Warning: Max runtime is very short ({max_runtime}s). "
            f"AI iterations typically take {ConfigValidator.TYPICAL_AI_ITERATION_MIN_SECONDS}-"
            f"{ConfigValidator.TYPICAL_AI_ITERATION_MAX_SECONDS} seconds."
        ]
||||
@dataclass
class AdapterConfig:
    """Configuration for individual adapters"""
    # Whether this adapter may be selected/used at all.
    enabled: bool = True
    # Extra arguments passed through to the adapter (presumably CLI flags — confirm against adapter usage).
    args: List[str] = field(default_factory=list)
    # Additional environment variables for the adapter.
    env: Dict[str, str] = field(default_factory=dict)
    # Per-invocation timeout (presumably seconds, matching other *_SECONDS settings — confirm).
    timeout: int = 300
    # Maximum retry attempts on failure.
    max_retries: int = 3
    # Adapter-specific tool permission settings; shape is adapter-defined.
    tool_permissions: Dict[str, Any] = field(default_factory=dict)
|
||||
@dataclass
class RalphConfig:
    """Configuration for Ralph orchestrator.

    Thread-safe configuration class with RLock protection for mutable fields.
    Provides both direct attribute access (backwards compatible) and thread-safe
    getter/setter methods for concurrent access scenarios.
    """

    # Core configuration fields
    agent: AgentType = AgentType.AUTO
    # Agent selection and fallback priority (used when agent=auto, and for fallback ordering)
    # Valid values: "acp", "claude", "gemini", "qchat" (also accepts aliases: "codex"->"acp", "q"->"qchat")
    # NOTE(review): the default below also lists "kiro", which the comment above omits — confirm.
    agent_priority: List[str] = field(default_factory=lambda: ["claude", "kiro", "qchat", "gemini", "acp"])
    prompt_file: str = DEFAULT_PROMPT_FILE
    prompt_text: Optional[str] = None  # Direct prompt text (overrides prompt_file)
    completion_promise: Optional[str] = DEFAULT_COMPLETION_PROMISE  # String to match in agent output to stop
    max_iterations: int = DEFAULT_MAX_ITERATIONS
    max_runtime: int = DEFAULT_MAX_RUNTIME
    checkpoint_interval: int = DEFAULT_CHECKPOINT_INTERVAL
    retry_delay: int = DEFAULT_RETRY_DELAY
    archive_prompts: bool = True
    git_checkpoint: bool = True
    verbose: bool = False
    dry_run: bool = False
    max_tokens: int = DEFAULT_MAX_TOKENS
    max_cost: float = DEFAULT_MAX_COST
    context_window: int = DEFAULT_CONTEXT_WINDOW
    context_threshold: float = DEFAULT_CONTEXT_THRESHOLD
    metrics_interval: int = DEFAULT_METRICS_INTERVAL
    enable_metrics: bool = True
    max_prompt_size: int = DEFAULT_MAX_PROMPT_SIZE
    allow_unsafe_paths: bool = False
    agent_args: List[str] = field(default_factory=list)
    adapters: Dict[str, AdapterConfig] = field(default_factory=dict)

    # Output formatting configuration
    output_format: str = "rich"  # "plain", "rich", or "json"
    output_verbosity: str = "normal"  # "quiet", "normal", "verbose", "debug"
    show_token_usage: bool = True  # Display token usage after iterations
    show_timestamps: bool = True  # Include timestamps in output

    # Thread safety lock - not included in initialization/equals
    _lock: threading.RLock = field(
        default_factory=threading.RLock, init=False, repr=False, compare=False
    )

    # Thread-safe property access methods for mutable fields
    # NOTE(review): the setters use object.__setattr__ even though the dataclass
    # is not frozen; plain assignment would behave identically here — confirm
    # whether this anticipates a frozen=True change.
    def get_max_iterations(self) -> int:
        """Thread-safe access to max_iterations property."""
        with self._lock:
            return self.max_iterations

    def set_max_iterations(self, value: int) -> None:
        """Thread-safe setting of max_iterations property."""
        with self._lock:
            object.__setattr__(self, 'max_iterations', value)

    def get_max_runtime(self) -> int:
        """Thread-safe access to max_runtime property."""
        with self._lock:
            return self.max_runtime

    def set_max_runtime(self, value: int) -> None:
        """Thread-safe setting of max_runtime property."""
        with self._lock:
            object.__setattr__(self, 'max_runtime', value)

    def get_checkpoint_interval(self) -> int:
        """Thread-safe access to checkpoint_interval property."""
        with self._lock:
            return self.checkpoint_interval

    def set_checkpoint_interval(self, value: int) -> None:
        """Thread-safe setting of checkpoint_interval property."""
        with self._lock:
            object.__setattr__(self, 'checkpoint_interval', value)

    def get_retry_delay(self) -> int:
        """Thread-safe access to retry_delay property."""
        with self._lock:
            return self.retry_delay

    def set_retry_delay(self, value: int) -> None:
        """Thread-safe setting of retry_delay property."""
        with self._lock:
            object.__setattr__(self, 'retry_delay', value)

    def get_max_tokens(self) -> int:
        """Thread-safe access to max_tokens property."""
        with self._lock:
            return self.max_tokens

    def set_max_tokens(self, value: int) -> None:
        """Thread-safe setting of max_tokens property."""
        with self._lock:
            object.__setattr__(self, 'max_tokens', value)

    def get_max_cost(self) -> float:
        """Thread-safe access to max_cost property."""
        with self._lock:
            return self.max_cost

    def set_max_cost(self, value: float) -> None:
        """Thread-safe setting of max_cost property."""
        with self._lock:
            object.__setattr__(self, 'max_cost', value)

    def get_verbose(self) -> bool:
        """Thread-safe access to verbose property."""
        with self._lock:
            return self.verbose

    def set_verbose(self, value: bool) -> None:
        """Thread-safe setting of verbose property."""
        with self._lock:
            object.__setattr__(self, 'verbose', value)

    @classmethod
    def from_yaml(cls, config_path: str) -> 'RalphConfig':
        """Load configuration from YAML file.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            A RalphConfig populated from the file's recognized keys; unknown
            keys are silently dropped.

        Raises:
            FileNotFoundError: If config_path does not exist.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_file, 'r') as f:
            config_data = yaml.safe_load(f)
        # NOTE(review): yaml.safe_load returns None for an empty file, which
        # would make the membership tests below raise TypeError — confirm that
        # empty config files are not expected.

        # Convert agent string to AgentType enum
        if 'agent' in config_data:
            config_data['agent'] = AgentType(config_data['agent'])

        # Process adapter configurations
        if 'adapters' in config_data:
            adapter_configs = {}
            for name, adapter_data in config_data['adapters'].items():
                if isinstance(adapter_data, dict):
                    adapter_configs[name] = AdapterConfig(**adapter_data)
                else:
                    # Simple boolean enable/disable
                    adapter_configs[name] = AdapterConfig(enabled=bool(adapter_data))
            config_data['adapters'] = adapter_configs

        # Filter out unknown keys
        valid_keys = {f.name for f in cls.__dataclass_fields__.values()}
        filtered_data = {k: v for k, v in config_data.items() if k in valid_keys}

        return cls(**filtered_data)

    def get_adapter_config(self, adapter_name: str) -> AdapterConfig:
        """Get configuration for a specific adapter.

        Returns a default AdapterConfig when the adapter is not configured.
        """
        with self._lock:
            return self.adapters.get(adapter_name, AdapterConfig())

    def validate(self) -> List[str]:
        """Validate configuration settings.

        Returns:
            List of validation errors (empty if valid).
        """
        errors = []

        with self._lock:
            errors.extend(ConfigValidator.validate_max_iterations(self.max_iterations))
            errors.extend(ConfigValidator.validate_max_runtime(self.max_runtime))
            errors.extend(ConfigValidator.validate_checkpoint_interval(self.checkpoint_interval))
            errors.extend(ConfigValidator.validate_retry_delay(self.retry_delay))
            errors.extend(ConfigValidator.validate_max_tokens(self.max_tokens))
            errors.extend(ConfigValidator.validate_max_cost(self.max_cost))
            errors.extend(ConfigValidator.validate_context_threshold(self.context_threshold))

        return errors

    def get_warnings(self) -> List[str]:
        """Get configuration warnings (non-blocking issues).

        Returns:
            List of warning messages.
        """
        warnings = []

        with self._lock:
            warnings.extend(ConfigValidator.get_warning_large_delay(self.retry_delay))
            warnings.extend(ConfigValidator.get_warning_single_iteration(self.max_iterations))
            warnings.extend(ConfigValidator.get_warning_short_timeout(self.max_runtime))

        return warnings

    def create_output_formatter(self):
        """Create an output formatter based on configuration settings.

        Returns:
            OutputFormatter instance configured according to settings.
        """
        from ralph_orchestrator.output import VerbosityLevel, create_formatter

        # Map verbosity string to enum
        verbosity_map = {
            "quiet": VerbosityLevel.QUIET,
            "normal": VerbosityLevel.NORMAL,
            "verbose": VerbosityLevel.VERBOSE,
            "debug": VerbosityLevel.DEBUG,
        }

        with self._lock:
            # Unrecognized verbosity strings silently fall back to NORMAL.
            verbosity = verbosity_map.get(self.output_verbosity.lower(), VerbosityLevel.NORMAL)
            return create_formatter(
                format_type=self.output_format,
                verbosity=verbosity,
            )
|
||||
|
||||
def main():
    """Main entry point.

    Parses command-line arguments, builds a RalphConfig, and runs the
    orchestrator. The value returned by orchestrator.run() is propagated to
    the caller (used as the process exit code by the __main__ guard).
    """
    parser = argparse.ArgumentParser(
        description="Ralph Wiggum Orchestrator - Put AI in a loop until done"
    )

    parser.add_argument(
        "--agent", "-a",
        type=str,
        choices=["claude", "q", "kiro", "gemini", "acp", "auto"],
        default="auto",
        help="AI agent to use (default: auto-detect)"
    )

    # Note: dest="prompt", so the parsed value is read as args.prompt below.
    parser.add_argument(
        "--prompt-file", "-P",
        type=str,
        default=DEFAULT_PROMPT_FILE,
        dest="prompt",
        help="Prompt file path (default: PROMPT.md)"
    )

    parser.add_argument(
        "--prompt-text", "-p",
        type=str,
        default=None,
        help="Direct prompt text (overrides --prompt-file)"
    )

    parser.add_argument(
        "--completion-promise",
        type=str,
        default=DEFAULT_COMPLETION_PROMISE,
        help=f"Stop when agent output contains this exact string (default: {DEFAULT_COMPLETION_PROMISE})"
    )

    parser.add_argument(
        "--max-iterations", "-i",
        type=int,
        default=DEFAULT_MAX_ITERATIONS,
        help=f"Maximum iterations (default: {DEFAULT_MAX_ITERATIONS})"
    )

    parser.add_argument(
        "--max-runtime", "-t",
        type=int,
        default=DEFAULT_MAX_RUNTIME,
        help=f"Maximum runtime in seconds (default: {DEFAULT_MAX_RUNTIME})"
    )

    parser.add_argument(
        "--checkpoint-interval", "-c",
        type=int,
        default=DEFAULT_CHECKPOINT_INTERVAL,
        help=f"Checkpoint interval (default: {DEFAULT_CHECKPOINT_INTERVAL})"
    )

    parser.add_argument(
        "--retry-delay", "-r",
        type=int,
        default=DEFAULT_RETRY_DELAY,
        help=f"Retry delay in seconds (default: {DEFAULT_RETRY_DELAY})"
    )

    parser.add_argument(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Maximum total tokens (default: {DEFAULT_MAX_TOKENS:,})"
    )

    parser.add_argument(
        "--max-cost",
        type=float,
        default=DEFAULT_MAX_COST,
        help=f"Maximum cost in USD (default: ${DEFAULT_MAX_COST:.2f})"
    )

    parser.add_argument(
        "--context-window",
        type=int,
        default=DEFAULT_CONTEXT_WINDOW,
        help=f"Context window size in tokens (default: {DEFAULT_CONTEXT_WINDOW:,})"
    )

    # "%%" is required: argparse applies %-formatting to help strings.
    parser.add_argument(
        "--context-threshold",
        type=float,
        default=DEFAULT_CONTEXT_THRESHOLD,
        help=f"Context summarization threshold (default: {DEFAULT_CONTEXT_THRESHOLD:.1f} = {DEFAULT_CONTEXT_THRESHOLD*100:.0f}%%)"
    )

    parser.add_argument(
        "--metrics-interval",
        type=int,
        default=DEFAULT_METRICS_INTERVAL,
        help=f"Metrics logging interval (default: {DEFAULT_METRICS_INTERVAL})"
    )

    parser.add_argument(
        "--no-metrics",
        action="store_true",
        help="Disable metrics collection"
    )

    parser.add_argument(
        "--max-prompt-size",
        type=int,
        default=DEFAULT_MAX_PROMPT_SIZE,
        help=f"Maximum prompt file size in bytes (default: {DEFAULT_MAX_PROMPT_SIZE})"
    )

    parser.add_argument(
        "--allow-unsafe-paths",
        action="store_true",
        help="Allow potentially unsafe prompt paths (use with caution)"
    )

    parser.add_argument(
        "--no-git",
        action="store_true",
        help="Disable git checkpointing"
    )

    parser.add_argument(
        "--no-archive",
        action="store_true",
        help="Disable prompt archiving"
    )

    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose output"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Dry run mode (don't execute agents)"
    )

    # Output formatting options
    parser.add_argument(
        "--output-format",
        type=str,
        choices=["plain", "rich", "json"],
        default="rich",
        help="Output format (default: rich)"
    )

    parser.add_argument(
        "--output-verbosity",
        type=str,
        choices=["quiet", "normal", "verbose", "debug"],
        default="normal",
        help="Output verbosity level (default: normal)"
    )

    parser.add_argument(
        "--no-token-usage",
        action="store_true",
        help="Disable token usage display"
    )

    parser.add_argument(
        "--no-timestamps",
        action="store_true",
        help="Disable timestamps in output"
    )

    # Positional catch-all: anything left over is forwarded to the agent.
    parser.add_argument(
        "agent_args",
        nargs="*",
        help="Additional arguments to pass to the AI agent"
    )

    args = parser.parse_args()

    # Configure logging level
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create config
    # Negative "--no-*" flags are inverted into the positive config fields.
    config = RalphConfig(
        agent=AgentType(args.agent),
        prompt_file=args.prompt,
        prompt_text=args.prompt_text,
        max_iterations=args.max_iterations,
        max_runtime=args.max_runtime,
        checkpoint_interval=args.checkpoint_interval,
        retry_delay=args.retry_delay,
        archive_prompts=not args.no_archive,
        git_checkpoint=not args.no_git,
        verbose=args.verbose,
        dry_run=args.dry_run,
        max_tokens=args.max_tokens,
        max_cost=args.max_cost,
        context_window=args.context_window,
        context_threshold=args.context_threshold,
        metrics_interval=args.metrics_interval,
        enable_metrics=not args.no_metrics,
        max_prompt_size=args.max_prompt_size,
        allow_unsafe_paths=args.allow_unsafe_paths,
        agent_args=args.agent_args,
        completion_promise=args.completion_promise,
        # Output formatting options
        output_format=args.output_format,
        output_verbosity=args.output_verbosity,
        show_token_usage=not args.no_token_usage,
        show_timestamps=not args.no_timestamps,
    )

    # Run orchestrator
    orchestrator = RalphOrchestrator(config)
    return orchestrator.run()
|
||||
|
||||
# Allow running this module directly; main()'s return value becomes the exit code.
if __name__ == "__main__":
    sys.exit(main())
|
||||
345
.venv/lib/python3.11/site-packages/ralph_orchestrator/metrics.py
Normal file
345
.venv/lib/python3.11/site-packages/ralph_orchestrator/metrics.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# ABOUTME: Metrics tracking and cost calculation for Ralph Orchestrator
|
||||
# ABOUTME: Monitors performance, usage, and costs across different AI tools
|
||||
|
||||
"""Metrics and cost tracking for Ralph Orchestrator."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Any
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
class TriggerReason(str, Enum):
    """Reasons why an iteration was triggered.

    Used for per-iteration telemetry to understand why the orchestrator
    started each iteration, enabling analysis of orchestration patterns.
    Mixing in str means each member also behaves as its string value.
    """
    INITIAL = "initial"  # First iteration of a session
    TASK_INCOMPLETE = "task_incomplete"  # Previous iteration didn't complete task
    PREVIOUS_SUCCESS = "previous_success"  # Previous iteration succeeded, continuing
    RECOVERY = "recovery"  # Recovering from a previous failure
    LOOP_DETECTED = "loop_detected"  # Loop detection triggered intervention
    SAFETY_LIMIT = "safety_limit"  # Safety limits triggered
    USER_STOP = "user_stop"  # User requested stop
|
||||
|
||||
|
||||
@dataclass
class Metrics:
    """Track orchestration metrics."""

    iterations: int = 0
    successful_iterations: int = 0
    failed_iterations: int = 0
    errors: int = 0
    checkpoints: int = 0
    rollbacks: int = 0
    start_time: float = field(default_factory=time.time)

    def elapsed_hours(self) -> float:
        """Hours elapsed since this metrics object was created."""
        seconds = time.time() - self.start_time
        return seconds / 3600

    def success_rate(self) -> float:
        """Fraction of finished iterations that succeeded (0.0 when none finished)."""
        finished = self.successful_iterations + self.failed_iterations
        return self.successful_iterations / finished if finished else 0.0

    def to_dict(self) -> Dict:
        """Snapshot all counters plus derived values as a plain dictionary."""
        counter_names = (
            "iterations",
            "successful_iterations",
            "failed_iterations",
            "errors",
            "checkpoints",
            "rollbacks",
        )
        snapshot = {name: getattr(self, name) for name in counter_names}
        snapshot["elapsed_hours"] = self.elapsed_hours()
        snapshot["success_rate"] = self.success_rate()
        return snapshot

    def to_json(self) -> str:
        """Serialize the metrics snapshot as pretty-printed JSON."""
        return json.dumps(self.to_dict(), indent=2)
|
||||
|
||||
|
||||
class CostTracker:
    """Track costs across different AI tools."""

    # Cost per 1K tokens (approximate)
    COSTS = {
        "claude": {
            "input": 0.003,   # $3 per 1M input tokens
            "output": 0.015   # $15 per 1M output tokens
        },
        "gemini": {
            "input": 0.00025,  # $0.25 per 1M input tokens
            "output": 0.001    # $1 per 1M output tokens
        },
        "qchat": {
            "input": 0.0,  # Free/local
            "output": 0.0
        },
        "acp": {
            "input": 0.0,  # ACP doesn't provide billing info
            "output": 0.0  # Cost depends on underlying agent
        },
        "gpt-4": {
            "input": 0.03,  # $30 per 1M input tokens
            "output": 0.06  # $60 per 1M output tokens
        }
    }

    def __init__(self):
        """Initialize an empty tracker with no recorded usage."""
        self.total_cost = 0.0
        self.costs_by_tool: Dict[str, float] = {}
        self.usage_history: List[Dict] = []

    def add_usage(
        self,
        tool: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """Record one usage event and return its computed cost.

        Args:
            tool: Name of the AI tool (unknown tools are billed as free tier)
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Cost for this usage
        """
        # Tools without a rate table fall back to the free/local tier.
        if tool not in self.COSTS:
            tool = "qchat"

        rates = self.COSTS[tool]
        event_cost = (input_tokens / 1000) * rates["input"] + (output_tokens / 1000) * rates["output"]

        # Accumulate overall and per-tool totals.
        self.total_cost += event_cost
        self.costs_by_tool[tool] = self.costs_by_tool.get(tool, 0.0) + event_cost

        # Append an audit record for this event.
        self.usage_history.append({
            "timestamp": time.time(),
            "tool": tool,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": event_cost
        })

        return event_cost

    def get_summary(self) -> Dict:
        """Summarize totals, per-tool costs, and average cost per event."""
        events = len(self.usage_history)
        return {
            "total_cost": self.total_cost,
            "costs_by_tool": self.costs_by_tool,
            "usage_count": events,
            "average_cost": self.total_cost / events if events else 0
        }

    def to_json(self) -> str:
        """Serialize the cost summary as pretty-printed JSON."""
        return json.dumps(self.get_summary(), indent=2)
|
||||
|
||||
|
||||
@dataclass
class IterationStats:
    """Memory-efficient iteration statistics tracking.

    Tracks per-iteration details (duration, success/failure, errors) while
    limiting stored iterations to prevent memory leaks in long-running sessions.
    """

    total: int = 0  # Highest iteration number seen (but see record_success/record_failure)
    successes: int = 0  # Count of successful iterations
    failures: int = 0  # Count of failed iterations
    start_time: datetime | None = None  # Session start; filled by __post_init__ if unset
    current_iteration: int = 0  # Most recently started/recorded iteration number
    iterations: List[Dict[str, Any]] = field(default_factory=list)  # Bounded per-iteration detail records
    max_iterations_stored: int = 1000  # Memory limit for stored iterations
    max_preview_length: int = 500  # Max chars for output preview truncation

    def __post_init__(self) -> None:
        """Initialize start time if not set."""
        if self.start_time is None:
            self.start_time = datetime.now()

    def record_start(self, iteration: int) -> None:
        """Record iteration start.

        Args:
            iteration: Iteration number
        """
        self.current_iteration = iteration
        self.total = max(self.total, iteration)

    def record_success(self, iteration: int) -> None:
        """Record successful iteration.

        Args:
            iteration: Iteration number
        """
        # NOTE(review): unlike record_start/record_iteration this assigns
        # iteration directly (no max), so a lower number would shrink total —
        # confirm intended.
        self.total = iteration
        self.successes += 1

    def record_failure(self, iteration: int) -> None:
        """Record failed iteration.

        Args:
            iteration: Iteration number
        """
        # NOTE(review): same direct assignment as record_success — confirm intended.
        self.total = iteration
        self.failures += 1

    def record_iteration(
        self,
        iteration: int,
        duration: float,
        success: bool,
        error: str,
        trigger_reason: str = "",
        output_preview: str = "",
        tokens_used: int = 0,
        cost: float = 0.0,
        tools_used: List[str] | None = None,
    ) -> None:
        """Record iteration with full details.

        Args:
            iteration: Iteration number
            duration: Duration in seconds
            success: Whether iteration was successful
            error: Error message if any
            trigger_reason: Why this iteration was triggered (from TriggerReason)
            output_preview: Preview of iteration output (truncated for privacy)
            tokens_used: Total tokens consumed in this iteration
            cost: Cost in dollars for this iteration
            tools_used: List of tools/MCPs invoked during iteration
        """
        # Update basic statistics
        self.total = max(self.total, iteration)
        self.current_iteration = iteration

        if success:
            self.successes += 1
        else:
            self.failures += 1

        # Truncate output preview for privacy (configurable length)
        if output_preview and len(output_preview) > self.max_preview_length:
            output_preview = output_preview[:self.max_preview_length] + "..."

        # Store detailed iteration information
        iteration_data = {
            "iteration": iteration,
            "duration": duration,
            "success": success,
            "error": error,
            "timestamp": datetime.now().isoformat(),
            "trigger_reason": trigger_reason,
            "output_preview": output_preview,
            "tokens_used": tokens_used,
            "cost": cost,
            "tools_used": tools_used or [],
        }
        self.iterations.append(iteration_data)

        # Enforce memory limit by evicting oldest entries
        if len(self.iterations) > self.max_iterations_stored:
            excess = len(self.iterations) - self.max_iterations_stored
            self.iterations = self.iterations[excess:]

    def get_success_rate(self) -> float:
        """Calculate success rate as percentage.

        Returns:
            Success rate (0-100)
        """
        total_attempts = self.successes + self.failures
        if total_attempts == 0:
            return 0.0
        return (self.successes / total_attempts) * 100

    def get_runtime(self) -> str:
        """Get human-readable runtime duration.

        Returns:
            Runtime string (e.g., "2h 30m 15s")
        """
        if self.start_time is None:
            return "Unknown"

        delta = datetime.now() - self.start_time
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, seconds = divmod(remainder, 60)

        # Omit leading zero units ("5m 3s", not "0h 5m 3s").
        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        if minutes > 0:
            return f"{minutes}m {seconds}s"
        return f"{seconds}s"

    def get_recent_iterations(self, count: int) -> List[Dict[str, Any]]:
        """Get most recent iterations.

        Args:
            count: Maximum number of iterations to return

        Returns:
            List of recent iteration data dictionaries
        """
        if count >= len(self.iterations):
            return self.iterations.copy()
        return self.iterations[-count:]

    def get_average_duration(self) -> float:
        """Calculate average iteration duration.

        Returns:
            Average duration in seconds, or 0.0 if no iterations
        """
        if not self.iterations:
            return 0.0
        total_duration = sum(it["duration"] for it in self.iterations)
        return total_duration / len(self.iterations)

    def get_error_messages(self) -> List[str]:
        """Extract error messages from failed iterations.

        Returns:
            List of non-empty error messages
        """
        return [
            it["error"]
            for it in self.iterations
            if not it["success"] and it["error"]
        ]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        Returns:
            Stats as dictionary (excludes iteration list for compatibility)
        """
        return {
            "total": self.total,
            "current": self.current_iteration,
            "successes": self.successes,
            "failures": self.failures,
            "success_rate": self.get_success_rate(),
            "runtime": self.get_runtime(),
            "start_time": self.start_time.isoformat() if self.start_time else None,
        }
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,121 @@
|
||||
# ABOUTME: Output formatter module initialization
|
||||
# ABOUTME: Exports base classes, formatter implementations, and legacy console classes
|
||||
|
||||
"""Output formatting module for Claude adapter responses.
|
||||
|
||||
This module provides:
|
||||
|
||||
1. Legacy output utilities (backward compatible):
|
||||
- DiffStats, DiffFormatter, RalphConsole - Rich terminal utilities
|
||||
- RICH_AVAILABLE - Rich library availability flag
|
||||
|
||||
2. New formatter classes for structured output:
|
||||
- PlainTextFormatter: Basic text output without colors
|
||||
- RichTerminalFormatter: Rich terminal output with colors and panels
|
||||
- JsonFormatter: Structured JSON output for programmatic consumption
|
||||
|
||||
Example usage (new formatters):
|
||||
from ralph_orchestrator.output import (
|
||||
RichTerminalFormatter,
|
||||
VerbosityLevel,
|
||||
ToolCallInfo,
|
||||
)
|
||||
|
||||
formatter = RichTerminalFormatter(verbosity=VerbosityLevel.VERBOSE)
|
||||
tool_info = ToolCallInfo(tool_name="Read", tool_id="abc123", input_params={"path": "test.py"})
|
||||
output = formatter.format_tool_call(tool_info, iteration=1)
|
||||
formatter.print(output)
|
||||
|
||||
Example usage (legacy console):
|
||||
from ralph_orchestrator.output import RalphConsole
|
||||
|
||||
console = RalphConsole()
|
||||
console.print_status("Processing...")
|
||||
console.print_success("Done!")
|
||||
"""
|
||||
|
||||
# Import legacy classes from console module (backward compatibility)
|
||||
from .console import (
|
||||
RICH_AVAILABLE,
|
||||
DiffFormatter,
|
||||
DiffStats,
|
||||
RalphConsole,
|
||||
)
|
||||
|
||||
# Import new formatter base classes
|
||||
from .base import (
|
||||
FormatContext,
|
||||
MessageType,
|
||||
OutputFormatter,
|
||||
TokenUsage,
|
||||
ToolCallInfo,
|
||||
VerbosityLevel,
|
||||
)
|
||||
|
||||
# Import content detection
|
||||
from .content_detector import ContentDetector, ContentType
|
||||
|
||||
# Import new formatter implementations
|
||||
from .json_formatter import JsonFormatter
|
||||
from .plain import PlainTextFormatter
|
||||
from .rich_formatter import RichTerminalFormatter
|
||||
|
||||
# Public API of this package; keep in sync with the imports above and the
# factory function defined below.
__all__ = [
    # Legacy exports (backward compatibility)
    "RICH_AVAILABLE",
    "DiffStats",
    "DiffFormatter",
    "RalphConsole",
    # New base classes
    "OutputFormatter",
    "VerbosityLevel",
    "MessageType",
    "TokenUsage",
    "ToolCallInfo",
    "FormatContext",
    # Content detection
    "ContentDetector",
    "ContentType",
    # New formatters
    "PlainTextFormatter",
    "RichTerminalFormatter",
    "JsonFormatter",
    # Factory function (defined later in this module)
    "create_formatter",
]
|
||||
|
||||
|
||||
def create_formatter(
    format_type: str = "rich",
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL,
    **kwargs,
) -> OutputFormatter:
    """Factory function to create appropriate formatter.

    Args:
        format_type: Type of formatter ("plain", "rich", "json")
        verbosity: Verbosity level for output
        **kwargs: Additional arguments passed to formatter constructor

    Returns:
        Configured OutputFormatter instance

    Raises:
        ValueError: If format_type is not recognized
    """
    # Aliases map to the same implementations ("text" -> plain, "terminal" -> rich).
    formatters = {
        "plain": PlainTextFormatter,
        "text": PlainTextFormatter,
        "rich": RichTerminalFormatter,
        "terminal": RichTerminalFormatter,
        "json": JsonFormatter,
    }

    key = format_type.lower()
    if key not in formatters:
        raise ValueError(
            f"Unknown format type: {format_type}. "
            f"Valid options: {', '.join(formatters.keys())}"
        )

    return formatters[key](verbosity=verbosity, **kwargs)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,404 @@
|
||||
# ABOUTME: Base classes and interfaces for output formatting
|
||||
# ABOUTME: Defines OutputFormatter ABC with verbosity levels, event types, and token tracking
|
||||
|
||||
"""Base classes for Claude adapter output formatting."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VerbosityLevel(Enum):
    """Verbosity levels for output formatting.

    Ordered from least (QUIET) to most (DEBUG) detailed output; numeric
    values allow level comparisons via .value.
    """

    QUIET = 0  # Only errors and final results
    NORMAL = 1  # Tool calls, assistant messages (no details)
    VERBOSE = 2  # Full tool inputs/outputs, detailed messages
    DEBUG = 3  # Everything including internal state
|
||||
|
||||
|
||||
class MessageType(Enum):
    """Types of messages that can be formatted."""

    SYSTEM = "system"  # System-level messages
    ASSISTANT = "assistant"  # Assistant (model) output
    USER = "user"  # User-originated content
    TOOL_CALL = "tool_call"  # A tool invocation
    TOOL_RESULT = "tool_result"  # The result of a tool invocation
    ERROR = "error"  # Error messages
    INFO = "info"  # Informational messages
    PROGRESS = "progress"  # Progress updates
|
||||
|
||||
|
||||
@dataclass
class TokenUsage:
    """Tracks token usage and costs."""

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cost: float = 0.0
    model: str = ""

    # Running totals across session
    session_input_tokens: int = 0
    session_output_tokens: int = 0
    session_total_tokens: int = 0
    session_cost: float = 0.0

    def add(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: float = 0.0,
        model: str = "",
    ) -> None:
        """Overwrite current-iteration counters and accumulate session totals.

        The per-iteration fields are replaced (not summed); the session_* fields
        grow across calls. An empty model string leaves the previous model name
        in place.
        """
        combined = input_tokens + output_tokens

        # Current-iteration values are replaced wholesale.
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.total_tokens = combined
        self.cost = cost
        if model:
            self.model = model

        # Session-wide values accumulate.
        self.session_input_tokens += input_tokens
        self.session_output_tokens += output_tokens
        self.session_total_tokens += combined
        self.session_cost += cost

    def reset_current(self) -> None:
        """Zero the current-iteration counters while preserving session totals."""
        self.input_tokens = 0
        self.output_tokens = 0
        self.total_tokens = 0
        self.cost = 0.0
|
||||
|
||||
|
||||
@dataclass
class ToolCallInfo:
    """Information about a tool call."""

    # Name of the invoked tool.
    tool_name: str
    # Unique identifier for this specific call.
    tool_id: str
    # Raw input parameters passed to the tool.
    input_params: Dict[str, Any] = field(default_factory=dict)
    # Wall-clock start/end of the call, when recorded by the caller.
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    # Tool output; shape depends on the tool and is not constrained here.
    result: Optional[Any] = None
    # True when the tool reported a failure.
    is_error: bool = False
    # Call duration in milliseconds, when measured by the caller.
    duration_ms: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
class FormatContext:
    """Context information for formatting operations."""

    # Orchestrator iteration this output belongs to.
    iteration: int = 0
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL
    # Filled with datetime.now() in __post_init__ when left as None.
    timestamp: Optional[datetime] = None
    # Filled with a fresh TokenUsage in __post_init__ when left as None.
    token_usage: Optional[TokenUsage] = None
    # Free-form extra data supplied by the caller.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Default time-sensitive/mutable fields lazily so every instance
        # gets its own timestamp and usage tracker.
        if self.timestamp is None:
            self.timestamp = datetime.now()
        if self.token_usage is None:
            self.token_usage = TokenUsage()
|
||||
|
||||
|
||||
class OutputFormatter(ABC):
    """Abstract base class for output formatters.

    Formatters handle rendering of Claude adapter events to different output
    formats (plain text, rich terminal, JSON, etc.). They support verbosity
    levels and consistent token usage tracking.
    """

    def __init__(self, verbosity: VerbosityLevel = VerbosityLevel.NORMAL) -> None:
        """Initialize formatter with verbosity level.

        Args:
            verbosity: Output verbosity level
        """
        self._verbosity = verbosity
        self._token_usage = TokenUsage()
        self._start_time = datetime.now()
        # Observers notified for every formatted event; see register_callback.
        self._callbacks: List[Callable[[MessageType, Any, FormatContext], None]] = []

    @property
    def verbosity(self) -> VerbosityLevel:
        """Get current verbosity level."""
        return self._verbosity

    @verbosity.setter
    def verbosity(self, level: VerbosityLevel) -> None:
        """Set verbosity level."""
        self._verbosity = level

    @property
    def token_usage(self) -> TokenUsage:
        """Get current token usage."""
        return self._token_usage

    def should_display(self, message_type: MessageType) -> bool:
        """Check if message type should be displayed at current verbosity.

        Args:
            message_type: Type of message to check

        Returns:
            True if message should be displayed
        """
        # Errors are always shown, regardless of verbosity.
        if message_type == MessageType.ERROR:
            return True

        if self._verbosity == VerbosityLevel.QUIET:
            return False

        if self._verbosity == VerbosityLevel.NORMAL:
            return message_type in (
                MessageType.ASSISTANT,
                MessageType.TOOL_CALL,
                MessageType.PROGRESS,
                MessageType.INFO,
            )

        # VERBOSE and DEBUG show everything
        return True

    def register_callback(
        self, callback: Callable[[MessageType, Any, FormatContext], None]
    ) -> None:
        """Register a callback for format events.

        Args:
            callback: Function to call with (message_type, content, context)
        """
        self._callbacks.append(callback)

    def _notify_callbacks(
        self, message_type: MessageType, content: Any, context: FormatContext
    ) -> None:
        """Notify all registered callbacks."""
        for callback in self._callbacks:
            try:
                callback(message_type, content, context)
            except Exception as e:
                # Log but don't let callback errors break formatting
                callback_name = getattr(callback, "__name__", repr(callback))
                _logger.debug(
                    "Callback %s failed for %s: %s: %s",
                    callback_name,
                    message_type,
                    type(e).__name__,
                    e,
                )

    def _create_context(
        self, iteration: int = 0, metadata: Optional[Dict[str, Any]] = None
    ) -> FormatContext:
        """Create a format context with current state.

        Args:
            iteration: Current iteration number
            metadata: Additional metadata

        Returns:
            FormatContext instance
        """
        return FormatContext(
            iteration=iteration,
            verbosity=self._verbosity,
            timestamp=datetime.now(),
            token_usage=self._token_usage,
            metadata=metadata or {},
        )

    def update_tokens(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: float = 0.0,
        model: str = "",
    ) -> None:
        """Update token usage tracking.

        Args:
            input_tokens: Number of input tokens used
            output_tokens: Number of output tokens used
            cost: Cost in USD
            model: Model name
        """
        self._token_usage.add(input_tokens, output_tokens, cost, model)

    @abstractmethod
    def format_tool_call(
        self,
        tool_info: ToolCallInfo,
        iteration: int = 0,
    ) -> str:
        """Format a tool call for display.

        Args:
            tool_info: Tool call information
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_tool_result(
        self,
        tool_info: ToolCallInfo,
        iteration: int = 0,
    ) -> str:
        """Format a tool result for display.

        Args:
            tool_info: Tool call info with result
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_assistant_message(
        self,
        message: str,
        iteration: int = 0,
    ) -> str:
        """Format an assistant message for display.

        Args:
            message: Assistant message text
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_system_message(
        self,
        message: str,
        iteration: int = 0,
    ) -> str:
        """Format a system message for display.

        Args:
            message: System message text
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_error(
        self,
        error: str,
        exception: Optional[Exception] = None,
        iteration: int = 0,
    ) -> str:
        """Format an error for display.

        Args:
            error: Error message
            exception: Optional exception object
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_progress(
        self,
        message: str,
        current: int = 0,
        total: int = 0,
        iteration: int = 0,
    ) -> str:
        """Format progress information for display.

        Args:
            message: Progress message
            current: Current progress value
            total: Total progress value
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_token_usage(self, show_session: bool = True) -> str:
        """Format token usage summary for display.

        Args:
            show_session: Include session totals

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_section_header(self, title: str, iteration: int = 0) -> str:
        """Format a section header for display.

        Args:
            title: Section title
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """

    @abstractmethod
    def format_section_footer(self) -> str:
        """Format a section footer for display.

        Returns:
            Formatted string representation
        """

    def summarize_content(self, content: str, max_length: int = 500) -> str:
        """Summarize long content for display.

        Content within ``max_length`` is returned unchanged; otherwise the
        head and tail are kept around a truncation marker line.

        Args:
            content: Content to summarize
            max_length: Maximum length before truncation

        Returns:
            Summarized content
        """
        if len(content) <= max_length:
            return content

        # Reserve ~20 chars for the marker line and split the budget between
        # head and tail. Clamp to >= 1: with max_length <= 20 the old
        # computation made half <= 0, and content[-0:] returned the WHOLE
        # string, producing output longer than the input.
        half = max((max_length - 20) // 2, 1)
        return f"{content[:half]}\n... [{len(content)} chars truncated] ...\n{content[-half:]}"

    def get_elapsed_time(self) -> float:
        """Get elapsed time since formatter creation.

        Returns:
            Elapsed time in seconds
        """
        return (datetime.now() - self._start_time).total_seconds()
|
||||
@@ -0,0 +1,915 @@
|
||||
# ABOUTME: Colored terminal output utilities using Rich
|
||||
# ABOUTME: Provides DiffFormatter, DiffStats, and RalphConsole for enhanced CLI output
|
||||
|
||||
"""Colored terminal output utilities using Rich."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import Rich components with fallback
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.markup import escape
|
||||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
Console = None # type: ignore
|
||||
Markdown = None # type: ignore
|
||||
Panel = None # type: ignore
|
||||
Syntax = None # type: ignore
|
||||
Table = None # type: ignore
|
||||
|
||||
def escape(x: str) -> str:
|
||||
"""Fallback escape function."""
|
||||
return str(x).replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
@dataclass
class DiffStats:
    """Statistics for diff content."""

    # Total added lines across the whole diff.
    additions: int = 0
    # Total removed lines across the whole diff.
    deletions: int = 0
    # Number of "diff --git" file headers seen.
    files: int = 0
    files_changed: dict[str, tuple[int, int]] = field(
        default_factory=dict
    )  # filename -> (additions, deletions)
|
||||
|
||||
|
||||
class DiffFormatter:
    """Formatter for enhanced diff visualization."""

    # Diff display constants
    MAX_CONTEXT_LINES = 3  # Maximum context lines to show before/after changes
    LARGE_DIFF_THRESHOLD = 100  # Lines count for "large diff" detection
    SEPARATOR_WIDTH = 60  # Width of visual separators
    LINE_NUM_WIDTH = 6  # Width for line number display
    # NOTE(review): LINE_NUM_WIDTH is not referenced within this class —
    # confirm whether external callers rely on it before removing.

    # Binary file patterns
    BINARY_EXTENSIONS = {
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".pdf",
        ".zip",
        ".tar",
        ".gz",
        ".so",
        ".pyc",
        ".exe",
        ".dll",
    }

    def __init__(self, console: "Console") -> None:
        """
        Initialize diff formatter.

        Args:
            console: Rich console for output
        """
        self.console = console

    def format_and_print(self, text: str) -> None:
        """
        Print diff with enhanced visualization and file path highlighting.

        Features:
        - Color-coded diff lines (additions, deletions, context)
        - File path highlighting with improved contrast
        - Diff statistics summary (+X/-Y lines) with per-file breakdown
        - Visual separation between file changes with subtle styling
        - Enhanced hunk headers with line range info and context highlighting
        - Smart context line limiting for large diffs
        - Improved spacing for better readability
        - Binary file detection and special handling
        - Empty diff detection with clear messaging

        Args:
            text: Diff text to render
        """
        # Without Rich, fall back to printing the raw diff text.
        if not RICH_AVAILABLE:
            print(text)
            return

        lines = text.split("\n")

        # Calculate diff statistics
        stats = self._calculate_stats(lines)

        # Handle empty diffs
        if stats.additions == 0 and stats.deletions == 0 and stats.files == 0:
            self.console.print("[dim italic]No changes detected[/dim italic]")
            return

        # Print summary if we have changes
        self._print_summary(stats)

        # Loop state: the current file header line, its extracted name,
        # how many consecutive context lines were printed, and whether we
        # are inside a run of +/- change lines.
        current_file = None
        current_file_name = None
        context_line_count = 0
        in_change_section = False

        # Branch order matters: "diff --git", "Binary files", "---"/"+++"
        # and "@@" must be checked before the generic "+"/"-" prefixes.
        # NOTE(review): a removed/context body line that itself starts with
        # "---" (e.g. a markdown rule) is styled as a file header here.
        for line in lines:
            # File headers - highlight with bold cyan
            if line.startswith("diff --git"):
                # Add visual separator between files (except first)
                if current_file is not None:
                    # Print per-file stats before separator
                    if current_file_name is not None:
                        self._print_file_stats(current_file_name, stats)
                    self.console.print()
                    self.console.print(f"[dim]{'─' * self.SEPARATOR_WIDTH}[/dim]")
                    self.console.print()

                current_file = line
                # Extract filename for stats tracking
                current_file_name = self._extract_filename(line)
                # Check for binary files
                if self._is_binary_file(line):
                    self.console.print(
                        f"[bold magenta]{line} [dim](binary)[/dim][/bold magenta]"
                    )
                else:
                    self.console.print(f"[bold cyan]{line}[/bold cyan]")
                context_line_count = 0
            elif line.startswith("Binary files"):
                # Binary file indicator
                self.console.print(f"[yellow]📦 {line}[/yellow]")
                continue
            elif line.startswith("---") or line.startswith("+++"):
                # File paths - extract and highlight
                if line.startswith("---"):
                    self.console.print(f"[bold red]{line}[/bold red]")
                else:  # +++
                    self.console.print(f"[bold green]{line}[/bold green]")
            # Hunk headers - enhanced with context
            elif line.startswith("@@"):
                # Add subtle spacing before hunk for better visual separation
                if context_line_count > 0:
                    self.console.print()

                # Extract line ranges for better readability
                hunk_info = self._format_hunk_header(line)
                self.console.print(f"[bold magenta]{hunk_info}[/bold magenta]")
                context_line_count = 0
                in_change_section = False
            # Added lines - enhanced with bold for better contrast
            elif line.startswith("+"):
                self.console.print(f"[bold green]{line}[/bold green]")
                in_change_section = True
                context_line_count = 0
            # Removed lines - enhanced with bold for better contrast
            elif line.startswith("-"):
                self.console.print(f"[bold red]{line}[/bold red]")
                in_change_section = True
                context_line_count = 0
            # Context lines
            else:
                # Only show limited context lines for large diffs
                if stats.additions + stats.deletions > self.LARGE_DIFF_THRESHOLD:
                    # Show context around changes only
                    if in_change_section:
                        if context_line_count < self.MAX_CONTEXT_LINES:
                            self.console.print(f"[dim]{line}[/dim]")
                            context_line_count += 1
                        elif context_line_count == self.MAX_CONTEXT_LINES:
                            # Emit the ellipsis marker exactly once per run.
                            self.console.print(
                                "[dim italic] ⋮ (context lines omitted for readability)[/dim italic]"
                            )
                            context_line_count += 1
                    else:
                        # Leading context - always show up to limit
                        if context_line_count < self.MAX_CONTEXT_LINES:
                            self.console.print(f"[dim]{line}[/dim]")
                            context_line_count += 1
                else:
                    # Small diff - show all context
                    self.console.print(f"[dim]{line}[/dim]")

        # Print final file stats
        if current_file_name:
            self._print_file_stats(current_file_name, stats)

        # Add spacing after diff for better separation from next content
        self.console.print()

    def _calculate_stats(self, lines: list[str]) -> DiffStats:
        """
        Calculate statistics from diff lines including per-file breakdown.

        Args:
            lines: List of diff lines

        Returns:
            DiffStats with additions, deletions, files count, and per-file breakdown
        """
        stats = DiffStats()
        current_file = None

        for line in lines:
            if line.startswith("diff --git"):
                stats.files += 1
                current_file = self._extract_filename(line)
                if current_file and current_file not in stats.files_changed:
                    stats.files_changed[current_file] = (0, 0)
            # "+++"/"---" file-path headers are excluded from the counts.
            elif line.startswith("+") and not line.startswith("+++"):
                stats.additions += 1
                if current_file and current_file in stats.files_changed:
                    adds, dels = stats.files_changed[current_file]
                    stats.files_changed[current_file] = (adds + 1, dels)
            elif line.startswith("-") and not line.startswith("---"):
                stats.deletions += 1
                if current_file and current_file in stats.files_changed:
                    adds, dels = stats.files_changed[current_file]
                    stats.files_changed[current_file] = (adds, dels + 1)

        return stats

    def _print_summary(self, stats: DiffStats) -> None:
        """
        Print diff statistics summary.

        Args:
            stats: Diff statistics
        """
        if stats.additions == 0 and stats.deletions == 0:
            return

        summary = "[bold cyan]📊 Changes:[/bold cyan] "
        if stats.additions > 0:
            summary += f"[green]+{stats.additions}[/green]"
        if stats.additions > 0 and stats.deletions > 0:
            summary += " "
        if stats.deletions > 0:
            summary += f"[red]-{stats.deletions}[/red]"
        if stats.files > 1:
            summary += f" [dim]({stats.files} files)[/dim]"
        self.console.print(summary)
        self.console.print()

    def _is_binary_file(self, diff_header: str) -> bool:
        """
        Check if diff is for a binary file based on extension.

        Args:
            diff_header: Diff header line (e.g., "diff --git a/file.png b/file.png")

        Returns:
            True if file appears to be binary
        """
        from pathlib import Path

        # Extract file path from diff header
        parts = diff_header.split()
        if len(parts) >= 3:
            file_path = parts[2]  # e.g., "a/file.png"
            ext = Path(file_path).suffix.lower()
            return ext in self.BINARY_EXTENSIONS
        return False

    def _extract_filename(self, diff_header: str) -> Optional[str]:
        """
        Extract filename from diff header line.

        Uses the "a/" (pre-image) path; for renames this differs from the
        "b/" path.

        Args:
            diff_header: Diff header line (e.g., "diff --git a/file.py b/file.py")

        Returns:
            Filename or None if not found
        """
        parts = diff_header.split()
        if len(parts) >= 3:
            # Extract from "a/file.py" or "b/file.py"
            file_path = parts[2]
            if file_path.startswith("a/") or file_path.startswith("b/"):
                return file_path[2:]
            return file_path
        return None

    def _print_file_stats(self, filename: str, stats: DiffStats) -> None:
        """
        Print per-file statistics with visual bar.

        Args:
            filename: Name of the file
            stats: DiffStats containing per-file breakdown
        """
        if filename and filename in stats.files_changed:
            adds, dels = stats.files_changed[filename]
            if adds > 0 or dels > 0:
                # Calculate visual bar proportions (max 30 chars)
                total_changes = adds + dels
                bar_width = min(30, total_changes)

                if total_changes > 0:
                    add_width = int((adds / total_changes) * bar_width)
                    del_width = bar_width - add_width

                    # Create visual bar
                    bar = ""
                    if add_width > 0:
                        bar += f"[bold green]{'▓' * add_width}[/bold green]"
                    if del_width > 0:
                        bar += f"[bold red]{'▓' * del_width}[/bold red]"

                    # Print stats with bar
                    summary = f" {bar} "
                    if adds > 0:
                        summary += f"[bold green]+{adds}[/bold green]"
                    if adds > 0 and dels > 0:
                        summary += " "
                    if dels > 0:
                        summary += f"[bold red]-{dels}[/bold red]"
                    self.console.print(summary)

    def _format_hunk_header(self, hunk: str) -> str:
        """
        Format hunk header with enhanced readability and context highlighting.

        Transforms: @@ -140,7 +140,7 @@ class RalphConsole:
        Into: @@ Lines 140-146 → 140-146 @@ class RalphConsole:
        (a 7-line hunk starting at 140 ends at 146).
        With context (function/class name) highlighted in cyan.

        Args:
            hunk: Original hunk header line

        Returns:
            Formatted hunk header with improved readability
        """
        # Extract line ranges using regex; the ",count" group is optional
        # (single-line hunks omit it), defaulting to 1 below.
        pattern = r"@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@(.*)$"
        match = re.search(pattern, hunk)

        if not match:
            return hunk

        old_start = int(match.group(1))
        old_count = int(match.group(2)) if match.group(2) else 1
        new_start = int(match.group(3))
        new_count = int(match.group(4)) if match.group(4) else 1
        context = match.group(5).strip()

        # Calculate end lines
        old_end = old_start + old_count - 1
        new_end = new_start + new_count - 1

        # Format with readable line ranges
        header = f"@@ Lines {old_start}-{old_end} → {new_start}-{new_end} @@"

        # Highlight context (function/class name) if present
        if context:
            # Highlight the context in cyan for better visibility
            header += f" [cyan]{context}[/cyan]"

        return header
|
||||
|
||||
|
||||
class RalphConsole:
|
||||
"""Rich console wrapper for Ralph output."""
|
||||
|
||||
# Display constants
|
||||
CLEAR_LINE_WIDTH = 80 # Characters to clear when clearing a line
|
||||
PROGRESS_BAR_WIDTH = 30 # Width of progress bar in characters
|
||||
COUNTDOWN_COLOR_CHANGE_THRESHOLD_HIGH = 5 # Seconds remaining for yellow
|
||||
COUNTDOWN_COLOR_CHANGE_THRESHOLD_LOW = 2 # Seconds remaining for red
|
||||
MARKDOWN_INDICATOR_THRESHOLD = 2 # Minimum markdown patterns to consider as markdown
|
||||
DIFF_SCAN_LINE_LIMIT = 5 # Number of lines to scan for diff indicators
|
||||
DIFF_HUNK_SCAN_CHARS = 100 # Characters to scan for diff hunk markers
|
||||
|
||||
# Regex patterns for content detection and formatting
|
||||
CODE_BLOCK_PATTERN = r"```(\w+)?\n(.*?)\n```"
|
||||
FILE_REF_PATTERN = r"(\S+\.[a-zA-Z0-9]+):(\d+)"
|
||||
INLINE_CODE_PATTERN = r"`([^`\n]+)`"
|
||||
HUNK_HEADER_PATTERN = r"@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@(.*)$"
|
||||
TABLE_SEPARATOR_PATTERN = r"^\s*\|[\s\-:|]+\|\s*$"
|
||||
MARKDOWN_HEADING_PATTERN = r"^#{1,6}\s+.+"
|
||||
MARKDOWN_UNORDERED_LIST_PATTERN = r"^[\*\-]\s+.+"
|
||||
MARKDOWN_ORDERED_LIST_PATTERN = r"^\d+\.\s+.+"
|
||||
MARKDOWN_BOLD_PATTERN = r"\*\*.+?\*\*"
|
||||
MARKDOWN_ITALIC_PATTERN = r"\*.+?\*"
|
||||
MARKDOWN_BLOCKQUOTE_PATTERN = r"^>\s+.+"
|
||||
MARKDOWN_TASK_LIST_PATTERN = r"^[\*\-]\s+\[([ xX])\]\s+.+"
|
||||
MARKDOWN_HORIZONTAL_RULE_PATTERN = r"^(\-{3,}|\*{3,}|_{3,})\s*$"
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize Rich console."""
|
||||
if RICH_AVAILABLE:
|
||||
self.console = Console()
|
||||
self.diff_formatter = DiffFormatter(self.console)
|
||||
else:
|
||||
self.console = None
|
||||
self.diff_formatter = None
|
||||
|
||||
def print_status(self, message: str, style: str = "cyan") -> None:
|
||||
"""Print status message."""
|
||||
if self.console:
|
||||
# Use markup escaping to prevent Rich from parsing brackets in the icon
|
||||
self.console.print(f"[{style}][[*]] {message}[/{style}]")
|
||||
else:
|
||||
print(f"[*] {message}")
|
||||
|
||||
def print_success(self, message: str) -> None:
|
||||
"""Print success message."""
|
||||
if self.console:
|
||||
self.console.print(f"[green]✓[/green] {message}")
|
||||
else:
|
||||
print(f"✓ {message}")
|
||||
|
||||
def print_error(self, message: str, severity: str = "error") -> None:
|
||||
"""
|
||||
Print error message with severity-based formatting.
|
||||
|
||||
Args:
|
||||
message: Error message to print
|
||||
severity: Error severity level ("critical", "error", "warning")
|
||||
"""
|
||||
severity_styles = {
|
||||
"critical": ("[red bold]⛔[/red bold]", "red bold"),
|
||||
"error": ("[red]✗[/red]", "red"),
|
||||
"warning": ("[yellow]⚠[/yellow]", "yellow"),
|
||||
}
|
||||
|
||||
icon, style = severity_styles.get(severity, severity_styles["error"])
|
||||
if self.console:
|
||||
self.console.print(f"{icon} [{style}]{message}[/{style}]")
|
||||
else:
|
||||
print(f"✗ {message}")
|
||||
|
||||
def print_warning(self, message: str) -> None:
|
||||
"""Print warning message."""
|
||||
if self.console:
|
||||
self.console.print(f"[yellow]⚠[/yellow] {message}")
|
||||
else:
|
||||
print(f"⚠ {message}")
|
||||
|
||||
def print_info(self, message: str) -> None:
|
||||
"""Print info message."""
|
||||
if self.console:
|
||||
self.console.print(f"[blue]ℹ[/blue] {message}")
|
||||
else:
|
||||
print(f"ℹ {message}")
|
||||
|
||||
def print_header(self, title: str) -> None:
|
||||
"""Print section header."""
|
||||
if self.console and Panel:
|
||||
self.console.print(
|
||||
Panel(title, style="green bold", border_style="green"),
|
||||
justify="left",
|
||||
)
|
||||
else:
|
||||
print(f"\n=== {title} ===\n")
|
||||
|
||||
def print_iteration_header(self, iteration: int) -> None:
|
||||
"""Print iteration header."""
|
||||
if self.console:
|
||||
self.console.print(
|
||||
f"\n[cyan bold]=== RALPH ITERATION {iteration} ===[/cyan bold]\n"
|
||||
)
|
||||
else:
|
||||
print(f"\n=== RALPH ITERATION {iteration} ===\n")
|
||||
|
||||
def print_stats(
|
||||
self,
|
||||
iteration: int,
|
||||
success_count: int,
|
||||
error_count: int,
|
||||
start_time: str,
|
||||
prompt_file: str,
|
||||
recent_lines: list[str],
|
||||
) -> None:
|
||||
"""
|
||||
Print statistics table.
|
||||
|
||||
Args:
|
||||
iteration: Current iteration number
|
||||
success_count: Number of successful iterations
|
||||
error_count: Number of failed iterations
|
||||
start_time: Start time string
|
||||
prompt_file: Prompt file name
|
||||
recent_lines: Recent log entries
|
||||
"""
|
||||
if not self.console or not Table:
|
||||
# Plain text fallback
|
||||
print("\nRALPH STATISTICS")
|
||||
print(f" Iteration: {iteration}")
|
||||
print(f" Successful: {success_count}")
|
||||
print(f" Failed: {error_count}")
|
||||
print(f" Started: {start_time}")
|
||||
print(f" Prompt: {prompt_file}")
|
||||
return
|
||||
|
||||
# Create stats table with better formatting
|
||||
table = Table(
|
||||
title="🤖 RALPH STATISTICS",
|
||||
show_header=True,
|
||||
header_style="bold yellow",
|
||||
border_style="cyan",
|
||||
)
|
||||
table.add_column("Metric", style="cyan bold", no_wrap=True, width=20)
|
||||
table.add_column("Value", style="white", width=40)
|
||||
|
||||
# Calculate success rate
|
||||
total = success_count + error_count
|
||||
success_rate = (success_count / total * 100) if total > 0 else 0
|
||||
|
||||
table.add_row("🔄 Current Iteration", str(iteration))
|
||||
table.add_row("✅ Successful", f"[green bold]{success_count}[/green bold]")
|
||||
table.add_row("❌ Failed", f"[red bold]{error_count}[/red bold]")
|
||||
|
||||
# Determine success rate color based on percentage
|
||||
if success_rate > 80:
|
||||
rate_color = "green"
|
||||
elif success_rate > 50:
|
||||
rate_color = "yellow"
|
||||
else:
|
||||
rate_color = "red"
|
||||
table.add_row("📊 Success Rate", f"[{rate_color}]{success_rate:.1f}%[/]")
|
||||
|
||||
table.add_row("🕐 Started", start_time or "Unknown")
|
||||
table.add_row("📝 Prompt", prompt_file)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
# Show recent activity with better formatting
|
||||
if recent_lines:
|
||||
self.console.print("\n[yellow bold]📋 RECENT ACTIVITY[/yellow bold]")
|
||||
for line in recent_lines:
|
||||
# Clean up log lines for display and escape Rich markup
|
||||
clean_line = escape(line.strip())
|
||||
if "[SUCCESS]" in clean_line:
|
||||
self.console.print(f" [green]▸[/green] {clean_line}")
|
||||
elif "[ERROR]" in clean_line:
|
||||
self.console.print(f" [red]▸[/red] {clean_line}")
|
||||
elif "[WARNING]" in clean_line:
|
||||
self.console.print(f" [yellow]▸[/yellow] {clean_line}")
|
||||
else:
|
||||
self.console.print(f" [blue]▸[/blue] {clean_line}")
|
||||
self.console.print()
|
||||
|
||||
def print_countdown(self, remaining: int, total: int) -> None:
|
||||
"""
|
||||
Print countdown timer with progress bar.
|
||||
|
||||
Args:
|
||||
remaining: Seconds remaining
|
||||
total: Total delay seconds
|
||||
"""
|
||||
# Guard against division by zero
|
||||
if total <= 0:
|
||||
return
|
||||
|
||||
# Calculate progress
|
||||
progress = (total - remaining) / total
|
||||
filled = int(self.PROGRESS_BAR_WIDTH * progress)
|
||||
bar = "█" * filled + "░" * (self.PROGRESS_BAR_WIDTH - filled)
|
||||
|
||||
# Color based on time remaining (using constants)
|
||||
if remaining > self.COUNTDOWN_COLOR_CHANGE_THRESHOLD_HIGH:
|
||||
color = "green"
|
||||
elif remaining > self.COUNTDOWN_COLOR_CHANGE_THRESHOLD_LOW:
|
||||
color = "yellow"
|
||||
else:
|
||||
color = "red"
|
||||
|
||||
if self.console:
|
||||
self.console.print(
|
||||
f"\r[{color}]⏳ [{bar}] {remaining}s / {total}s remaining[/{color}]",
|
||||
end="",
|
||||
)
|
||||
else:
|
||||
print(f"\r⏳ [{bar}] {remaining}s / {total}s remaining", end="")
|
||||
|
||||
def clear_line(self) -> None:
|
||||
"""Clear current line."""
|
||||
if self.console:
|
||||
self.console.print("\r" + " " * self.CLEAR_LINE_WIDTH + "\r", end="")
|
||||
else:
|
||||
print("\r" + " " * self.CLEAR_LINE_WIDTH + "\r", end="")
|
||||
|
||||
def print_separator(self) -> None:
|
||||
"""Print visual separator."""
|
||||
if self.console:
|
||||
self.console.print("\n[cyan]---[/cyan]\n")
|
||||
else:
|
||||
print("\n---\n")
|
||||
|
||||
def clear_screen(self) -> None:
|
||||
"""Clear screen."""
|
||||
if self.console:
|
||||
self.console.clear()
|
||||
else:
|
||||
print("\033[2J\033[H", end="")
|
||||
|
||||
    def print_message(self, text: str) -> None:
        """
        Print message with intelligent formatting and improved visual hierarchy.

        Detects and formats:
        - Code blocks (```language) with syntax highlighting
        - Diffs (lines starting with +, -, @@) with enhanced visualization
        - Markdown tables with proper rendering
        - Markdown headings, lists, emphasis with spacing
        - Inline code (`code`) with highlighting
        - Plain text with file path detection

        Args:
            text: Message text to print
        """
        # No Rich console available: emit raw text and bail out.
        if not self.console:
            print(text)
            return

        # Check if text contains code blocks
        if "```" in text:
            # Split text by code blocks and process each part.
            # NOTE(review): the i % 3 indexing below assumes
            # CODE_BLOCK_PATTERN contains exactly two capture groups
            # (language, code), so re.split yields
            # [text, lang, code, text, lang, code, ...] — confirm against
            # the pattern's definition.
            parts = re.split(self.CODE_BLOCK_PATTERN, text, flags=re.DOTALL)

            for i, part in enumerate(parts):
                if i % 3 == 0:  # Regular text between code blocks
                    if part.strip():
                        self._print_formatted_text(part)
                        # Add subtle spacing after text before code block
                        if i + 1 < len(parts):
                            self.console.print()
                elif i % 3 == 1:  # Language identifier
                    language = part or "text"
                    # The code body is the next split element (group 2).
                    code = parts[i + 1] if i + 1 < len(parts) else ""
                    if code.strip() and Syntax:
                        # Use syntax highlighting for code blocks with enhanced features
                        syntax = Syntax(
                            code,
                            language,
                            theme="monokai",
                            line_numbers=True,
                            word_wrap=True,
                            indent_guides=True,
                            padding=(1, 2),
                        )
                        self.console.print(syntax)
                        # Add spacing after code block if more content follows
                        if i + 2 < len(parts) and parts[i + 2].strip():
                            self.console.print()
        elif self._is_diff_content(text):
            # Format as diff with enhanced visualization
            if self.diff_formatter:
                self.diff_formatter.format_and_print(text)
            else:
                print(text)
        elif self._is_markdown_table(text):
            # Render markdown tables nicely
            self._print_markdown_table(text)
            # Add spacing after table
            self.console.print()
        elif self._is_markdown_content(text):
            # Render rich markdown with headings, lists, emphasis
            self._print_markdown(text)
        else:
            # Regular text - check for inline code and format accordingly
            self._print_formatted_text(text)
|
||||
|
||||
def _is_diff_content(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text appears to be diff content.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text looks like diff output
|
||||
"""
|
||||
diff_indicators = [
|
||||
text.startswith("diff --git"),
|
||||
text.startswith("--- "),
|
||||
text.startswith("+++ "),
|
||||
"@@" in text[: self.DIFF_HUNK_SCAN_CHARS], # Diff hunk markers
|
||||
any(
|
||||
line.startswith(("+", "-", "@@"))
|
||||
for line in text.split("\n")[: self.DIFF_SCAN_LINE_LIMIT]
|
||||
),
|
||||
]
|
||||
return any(diff_indicators)
|
||||
|
||||
def _is_markdown_table(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text appears to be a markdown table.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text looks like a markdown table
|
||||
"""
|
||||
lines = text.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
# Check for table separator line (e.g., |---|---|)
|
||||
for line in lines[:3]:
|
||||
if re.match(self.TABLE_SEPARATOR_PATTERN, line):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _print_markdown_table(self, text: str) -> None:
|
||||
"""
|
||||
Print markdown table with Rich formatting.
|
||||
|
||||
Args:
|
||||
text: Markdown table text
|
||||
"""
|
||||
if Markdown:
|
||||
# Use Rich's Markdown renderer for tables
|
||||
md = Markdown(text)
|
||||
self.console.print(md)
|
||||
else:
|
||||
print(text)
|
||||
|
||||
def _is_markdown_content(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text appears to contain rich markdown (headings, lists, etc.).
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text looks like markdown with formatting
|
||||
"""
|
||||
markdown_indicators = [
|
||||
re.search(self.MARKDOWN_HEADING_PATTERN, text, re.MULTILINE), # Headings
|
||||
re.search(
|
||||
self.MARKDOWN_UNORDERED_LIST_PATTERN, text, re.MULTILINE
|
||||
), # Unordered lists
|
||||
re.search(
|
||||
self.MARKDOWN_ORDERED_LIST_PATTERN, text, re.MULTILINE
|
||||
), # Ordered lists
|
||||
re.search(self.MARKDOWN_BOLD_PATTERN, text), # Bold
|
||||
re.search(self.MARKDOWN_ITALIC_PATTERN, text), # Italic
|
||||
re.search(
|
||||
self.MARKDOWN_BLOCKQUOTE_PATTERN, text, re.MULTILINE
|
||||
), # Blockquotes
|
||||
re.search(
|
||||
self.MARKDOWN_TASK_LIST_PATTERN, text, re.MULTILINE
|
||||
), # Task lists
|
||||
re.search(
|
||||
self.MARKDOWN_HORIZONTAL_RULE_PATTERN, text, re.MULTILINE
|
||||
), # Horizontal rules
|
||||
]
|
||||
# Return true if at least MARKDOWN_INDICATOR_THRESHOLD markdown indicators present
|
||||
threshold = self.MARKDOWN_INDICATOR_THRESHOLD
|
||||
return sum(bool(indicator) for indicator in markdown_indicators) >= threshold
|
||||
|
||||
def _preprocess_markdown(self, text: str) -> str:
|
||||
"""
|
||||
Preprocess markdown text for better rendering.
|
||||
|
||||
Handles:
|
||||
- Task lists with checkboxes (- [ ] and - [x])
|
||||
- Horizontal rules with visual enhancement
|
||||
- Code blocks with language hints
|
||||
|
||||
Args:
|
||||
text: Raw markdown text
|
||||
|
||||
Returns:
|
||||
Preprocessed markdown text
|
||||
"""
|
||||
lines = text.split("\n")
|
||||
processed_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Enhanced task lists with visual indicators
|
||||
if re.match(self.MARKDOWN_TASK_LIST_PATTERN, line):
|
||||
# Replace [ ] with ☐ and [x]/[X] with ☑
|
||||
line = re.sub(r"\[\s\]", "☐", line)
|
||||
line = re.sub(r"\[[xX]\]", "☑", line)
|
||||
|
||||
# Enhanced horizontal rules - make them more visible
|
||||
if re.match(self.MARKDOWN_HORIZONTAL_RULE_PATTERN, line):
|
||||
line = f"\n{'─' * 60}\n"
|
||||
|
||||
processed_lines.append(line)
|
||||
|
||||
return "\n".join(processed_lines)
|
||||
|
||||
def _print_markdown(self, text: str) -> None:
|
||||
"""
|
||||
Print markdown content with Rich formatting and improved spacing.
|
||||
|
||||
Args:
|
||||
text: Markdown text to render
|
||||
"""
|
||||
if not Markdown:
|
||||
print(text)
|
||||
return
|
||||
|
||||
# Add subtle spacing before markdown for visual separation
|
||||
has_heading = re.search(self.MARKDOWN_HEADING_PATTERN, text, re.MULTILINE)
|
||||
if has_heading:
|
||||
self.console.print()
|
||||
|
||||
# Preprocess markdown for enhanced features
|
||||
processed_text = self._preprocess_markdown(text)
|
||||
|
||||
md = Markdown(processed_text)
|
||||
self.console.print(md)
|
||||
|
||||
# Add spacing after markdown blocks for better separation from next content
|
||||
self.console.print()
|
||||
|
||||
def _print_formatted_text(self, text: str) -> None:
|
||||
"""
|
||||
Print text with basic formatting, inline code, file path highlighting, and error detection.
|
||||
|
||||
Args:
|
||||
text: Text to print
|
||||
"""
|
||||
if not self.console:
|
||||
print(text)
|
||||
return
|
||||
|
||||
# Check for error/exception patterns and apply special formatting
|
||||
if self._is_error_traceback(text):
|
||||
self._print_error_traceback(text)
|
||||
return
|
||||
|
||||
# First, highlight file paths with line numbers (e.g., "file.py:123")
|
||||
text = re.sub(
|
||||
self.FILE_REF_PATTERN,
|
||||
lambda m: (
|
||||
f"[bold yellow]{m.group(1)}[/bold yellow]:"
|
||||
f"[bold blue]{m.group(2)}[/bold blue]"
|
||||
),
|
||||
text,
|
||||
)
|
||||
|
||||
# Check for inline code (single backticks)
|
||||
if "`" in text and "```" not in text:
|
||||
# Replace inline code with Rich markup - improved visibility
|
||||
formatted_text = re.sub(
|
||||
self.INLINE_CODE_PATTERN,
|
||||
lambda m: f"[cyan on grey23]{m.group(1)}[/cyan on grey23]",
|
||||
text,
|
||||
)
|
||||
self.console.print(formatted_text, highlight=True)
|
||||
else:
|
||||
# Enable markup for file paths and highlighting for URLs
|
||||
self.console.print(text, markup=True, highlight=True)
|
||||
|
||||
def _is_error_traceback(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text appears to be an error traceback.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text looks like an error traceback
|
||||
"""
|
||||
error_indicators = [
|
||||
"Traceback (most recent call last):" in text,
|
||||
re.search(r'^\s*File ".*", line \d+', text, re.MULTILINE),
|
||||
re.search(
|
||||
r"^(Error|Exception|ValueError|TypeError|RuntimeError):",
|
||||
text,
|
||||
re.MULTILINE,
|
||||
),
|
||||
]
|
||||
return any(error_indicators)
|
||||
|
||||
def _print_error_traceback(self, text: str) -> None:
|
||||
"""
|
||||
Print error traceback with enhanced formatting.
|
||||
|
||||
Args:
|
||||
text: Error traceback text
|
||||
"""
|
||||
if not Syntax:
|
||||
print(text)
|
||||
return
|
||||
|
||||
# Use Python syntax highlighting for tracebacks
|
||||
try:
|
||||
syntax = Syntax(
|
||||
text,
|
||||
"python",
|
||||
theme="monokai",
|
||||
line_numbers=False,
|
||||
word_wrap=True,
|
||||
background_color="grey11",
|
||||
)
|
||||
self.console.print("\n[red bold]⚠ Error Traceback:[/red bold]")
|
||||
self.console.print(syntax)
|
||||
self.console.print()
|
||||
except Exception as e:
|
||||
# Fallback to simple red text if syntax highlighting fails
|
||||
_logger.warning("Syntax highlighting failed for traceback: %s: %s", type(e).__name__, e)
|
||||
self.console.print(f"[red]{text}[/red]")
|
||||
@@ -0,0 +1,247 @@
|
||||
# ABOUTME: Content type detection for smart output formatting
|
||||
# ABOUTME: Detects diffs, code blocks, markdown, tables, and tracebacks
|
||||
|
||||
"""Content type detection for intelligent output formatting."""
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ContentType(Enum):
    """Types of content that can be detected."""

    # Fallback when no richer structure is recognised
    PLAIN_TEXT = "plain_text"
    # Unified / git diff output
    DIFF = "diff"
    # Fenced ``` code blocks
    CODE_BLOCK = "code_block"
    # Rich markdown (headings, lists, emphasis, ...)
    MARKDOWN = "markdown"
    # Pipe-delimited markdown table
    MARKDOWN_TABLE = "markdown_table"
    # Python exception traceback
    ERROR_TRACEBACK = "error_traceback"
|
||||
|
||||
|
||||
class ContentDetector:
|
||||
"""Detects content types for smart formatting.
|
||||
|
||||
Analyzes text content to determine the most appropriate rendering method.
|
||||
Detection priority: code_block > diff > traceback > table > markdown > plain
|
||||
"""
|
||||
|
||||
# Detection constants
|
||||
DIFF_HUNK_SCAN_CHARS = 100
|
||||
DIFF_SCAN_LINE_LIMIT = 5
|
||||
MARKDOWN_INDICATOR_THRESHOLD = 2
|
||||
|
||||
# Regex patterns
|
||||
CODE_BLOCK_PATTERN = re.compile(r"```(\w+)?\n.*?\n```", re.DOTALL)
|
||||
MARKDOWN_HEADING_PATTERN = re.compile(r"^#{1,6}\s+.+", re.MULTILINE)
|
||||
MARKDOWN_UNORDERED_LIST_PATTERN = re.compile(r"^[\*\-]\s+.+", re.MULTILINE)
|
||||
MARKDOWN_ORDERED_LIST_PATTERN = re.compile(r"^\d+\.\s+.+", re.MULTILINE)
|
||||
MARKDOWN_BOLD_PATTERN = re.compile(r"\*\*.+?\*\*")
|
||||
MARKDOWN_ITALIC_PATTERN = re.compile(r"(?<!\*)\*(?!\*)[^*\n]+\*(?!\*)")
|
||||
MARKDOWN_BLOCKQUOTE_PATTERN = re.compile(r"^>\s+.+", re.MULTILINE)
|
||||
MARKDOWN_TASK_LIST_PATTERN = re.compile(r"^[\*\-]\s+\[([ xX])\]\s+.+", re.MULTILINE)
|
||||
MARKDOWN_HORIZONTAL_RULE_PATTERN = re.compile(r"^(\-{3,}|\*{3,}|_{3,})\s*$", re.MULTILINE)
|
||||
TABLE_SEPARATOR_PATTERN = re.compile(r"^\s*\|[\s\-:|]+\|\s*$", re.MULTILINE)
|
||||
TRACEBACK_FILE_LINE_PATTERN = re.compile(r'^\s*File ".*", line \d+', re.MULTILINE)
|
||||
TRACEBACK_ERROR_PATTERN = re.compile(
|
||||
r"^(Error|Exception|ValueError|TypeError|RuntimeError|KeyError|AttributeError|"
|
||||
r"IndexError|ImportError|FileNotFoundError|NameError|ZeroDivisionError):",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
def detect(self, text: str) -> ContentType:
|
||||
"""Detect the primary content type of text.
|
||||
|
||||
Detection priority ensures more specific types are matched first:
|
||||
1. Code blocks (```...```) - highest priority
|
||||
2. Diffs (git diff format)
|
||||
3. Error tracebacks (Python exceptions)
|
||||
4. Markdown tables (|...|)
|
||||
5. Rich markdown (headings, lists, etc.)
|
||||
6. Plain text (fallback)
|
||||
|
||||
Args:
|
||||
text: Text content to analyze
|
||||
|
||||
Returns:
|
||||
The detected ContentType
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return ContentType.PLAIN_TEXT
|
||||
|
||||
# Check in priority order
|
||||
if self.is_code_block(text):
|
||||
return ContentType.CODE_BLOCK
|
||||
|
||||
if self.is_diff(text):
|
||||
return ContentType.DIFF
|
||||
|
||||
if self.is_error_traceback(text):
|
||||
return ContentType.ERROR_TRACEBACK
|
||||
|
||||
if self.is_markdown_table(text):
|
||||
return ContentType.MARKDOWN_TABLE
|
||||
|
||||
if self.is_markdown(text):
|
||||
return ContentType.MARKDOWN
|
||||
|
||||
return ContentType.PLAIN_TEXT
|
||||
|
||||
def is_diff(self, text: str) -> bool:
|
||||
"""Check if text is diff content.
|
||||
|
||||
Detects git diff format including:
|
||||
- diff --git headers
|
||||
- --- and +++ file markers (as diff markers, not markdown hr)
|
||||
- @@ hunk markers
|
||||
|
||||
Note: We avoid matching lines that merely start with + or - as those
|
||||
could be markdown list items. Diff detection requires more specific
|
||||
markers like @@ hunks or diff --git headers.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text appears to be diff content
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# Check for definitive diff markers
|
||||
diff_indicators = [
|
||||
text.startswith("diff --git"),
|
||||
"@@" in text[: self.DIFF_HUNK_SCAN_CHARS],
|
||||
]
|
||||
|
||||
if any(diff_indicators):
|
||||
return True
|
||||
|
||||
# Check for --- a/ and +++ b/ patterns (file markers in unified diff)
|
||||
# These are more specific than just --- or +++ which could be markdown hr
|
||||
lines = text.split("\n")[: self.DIFF_SCAN_LINE_LIMIT]
|
||||
has_file_markers = (
|
||||
any(line.startswith("--- a/") or line.startswith("--- /") for line in lines)
|
||||
and any(line.startswith("+++ b/") or line.startswith("+++ /") for line in lines)
|
||||
)
|
||||
if has_file_markers:
|
||||
return True
|
||||
|
||||
# Check for @@ hunk pattern specifically
|
||||
return any(line.startswith("@@") for line in lines)
|
||||
|
||||
def is_code_block(self, text: str) -> bool:
|
||||
"""Check if text contains fenced code blocks.
|
||||
|
||||
Detects markdown-style code blocks with triple backticks.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text contains code blocks
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
return "```" in text and self.CODE_BLOCK_PATTERN.search(text) is not None
|
||||
|
||||
def is_markdown(self, text: str) -> bool:
|
||||
"""Check if text contains rich markdown formatting.
|
||||
|
||||
Requires at least MARKDOWN_INDICATOR_THRESHOLD indicators to avoid
|
||||
false positives on text that happens to contain a single markdown element.
|
||||
|
||||
Detected elements:
|
||||
- Headings (# Title)
|
||||
- Lists (- item, 1. item)
|
||||
- Emphasis (**bold**, *italic*)
|
||||
- Blockquotes (> quote)
|
||||
- Task lists (- [ ] task)
|
||||
- Horizontal rules (---)
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text appears to be markdown content
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
markdown_indicators = [
|
||||
self.MARKDOWN_HEADING_PATTERN.search(text),
|
||||
self.MARKDOWN_UNORDERED_LIST_PATTERN.search(text),
|
||||
self.MARKDOWN_ORDERED_LIST_PATTERN.search(text),
|
||||
self.MARKDOWN_BOLD_PATTERN.search(text),
|
||||
self.MARKDOWN_ITALIC_PATTERN.search(text),
|
||||
self.MARKDOWN_BLOCKQUOTE_PATTERN.search(text),
|
||||
self.MARKDOWN_TASK_LIST_PATTERN.search(text),
|
||||
self.MARKDOWN_HORIZONTAL_RULE_PATTERN.search(text),
|
||||
]
|
||||
|
||||
return sum(bool(indicator) for indicator in markdown_indicators) >= self.MARKDOWN_INDICATOR_THRESHOLD
|
||||
|
||||
def is_markdown_table(self, text: str) -> bool:
|
||||
"""Check if text is a markdown table.
|
||||
|
||||
Detects tables with pipe separators and header dividers.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text appears to be a markdown table
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
lines = text.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return False
|
||||
|
||||
# Check for table separator line (|---|---|) in first few lines
|
||||
for line in lines[:3]:
|
||||
if self.TABLE_SEPARATOR_PATTERN.match(line):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_error_traceback(self, text: str) -> bool:
|
||||
"""Check if text is an error traceback.
|
||||
|
||||
Detects Python exception tracebacks.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text appears to be an error traceback
|
||||
"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
error_indicators = [
|
||||
"Traceback (most recent call last):" in text,
|
||||
self.TRACEBACK_FILE_LINE_PATTERN.search(text),
|
||||
self.TRACEBACK_ERROR_PATTERN.search(text),
|
||||
]
|
||||
return any(error_indicators)
|
||||
|
||||
def extract_code_blocks(self, text: str) -> list[tuple[Optional[str], str]]:
|
||||
"""Extract code blocks from text.
|
||||
|
||||
Args:
|
||||
text: Text containing code blocks
|
||||
|
||||
Returns:
|
||||
List of (language, code) tuples. Language may be None.
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
blocks = []
|
||||
pattern = re.compile(r"```(\w+)?\n(.*?)\n```", re.DOTALL)
|
||||
for match in pattern.finditer(text):
|
||||
language = match.group(1)
|
||||
code = match.group(2)
|
||||
blocks.append((language, code))
|
||||
return blocks
|
||||
@@ -0,0 +1,416 @@
|
||||
# ABOUTME: JSON output formatter for programmatic consumption
|
||||
# ABOUTME: Produces structured JSON output for parsing by other tools
|
||||
|
||||
"""JSON output formatter for Claude adapter."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import (
|
||||
MessageType,
|
||||
OutputFormatter,
|
||||
ToolCallInfo,
|
||||
VerbosityLevel,
|
||||
)
|
||||
|
||||
|
||||
class JsonFormatter(OutputFormatter):
    """JSON formatter for programmatic output consumption.

    Produces structured JSON output suitable for parsing by other tools,
    logging systems, or downstream processing pipelines.
    """

    def __init__(
        self,
        verbosity: VerbosityLevel = VerbosityLevel.NORMAL,
        pretty: bool = True,
        include_timestamps: bool = True,
    ) -> None:
        """Initialize JSON formatter.

        Args:
            verbosity: Output verbosity level
            pretty: Pretty-print JSON with indentation
            include_timestamps: Include timestamps in output
        """
        super().__init__(verbosity)
        # Controls indentation in _to_json
        self._pretty = pretty
        # When True, _create_event stamps each event with an ISO timestamp
        self._include_timestamps = include_timestamps
        # Every event emitted so far, in order (see _record_event)
        self._events: List[Dict[str, Any]] = []

    def _to_json(self, obj: Dict[str, Any]) -> str:
        """Convert object to JSON string.

        Args:
            obj: Dictionary to serialize

        Returns:
            JSON string
        """
        # default=str makes otherwise unserializable values (e.g. datetime)
        # degrade to their string form instead of raising.
        if self._pretty:
            return json.dumps(obj, indent=2, default=str, ensure_ascii=False)
        return json.dumps(obj, default=str, ensure_ascii=False)

    def _create_event(
        self,
        event_type: str,
        data: Dict[str, Any],
        iteration: int = 0,
    ) -> Dict[str, Any]:
        """Create a structured event object.

        Args:
            event_type: Type of event
            data: Event data
            iteration: Current iteration number

        Returns:
            Event dictionary
        """
        event: Dict[str, Any] = {
            "type": event_type,
            "iteration": iteration,
            "data": data,
        }

        if self._include_timestamps:
            event["timestamp"] = datetime.now().isoformat()

        return event

    def _record_event(self, event: Dict[str, Any]) -> None:
        """Record event for later retrieval.

        Args:
            event: Event dictionary to record
        """
        self._events.append(event)

    def get_events(self) -> List[Dict[str, Any]]:
        """Get all recorded events.

        Returns:
            List of event dictionaries
        """
        # Shallow copy so callers cannot mutate the internal list
        return self._events.copy()

    def clear_events(self) -> None:
        """Clear recorded events."""
        self._events.clear()

    def format_tool_call(
        self,
        tool_info: ToolCallInfo,
        iteration: int = 0,
    ) -> str:
        """Format a tool call as JSON.

        Args:
            tool_info: Tool call information
            iteration: Current iteration number

        Returns:
            JSON string representation (empty string when filtered by verbosity)
        """
        if not self.should_display(MessageType.TOOL_CALL):
            return ""

        # Callbacks fire even though the event has not been recorded yet
        context = self._create_context(iteration)
        self._notify_callbacks(MessageType.TOOL_CALL, tool_info, context)

        data: Dict[str, Any] = {
            "tool_name": tool_info.tool_name,
            "tool_id": tool_info.tool_id,
        }

        # Input parameters are only included at VERBOSE and above
        if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
            data["input_params"] = tool_info.input_params

        if tool_info.start_time:
            data["start_time"] = tool_info.start_time.isoformat()

        event = self._create_event("tool_call", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_tool_result(
        self,
        tool_info: ToolCallInfo,
        iteration: int = 0,
    ) -> str:
        """Format a tool result as JSON.

        Args:
            tool_info: Tool call info with result
            iteration: Current iteration number

        Returns:
            JSON string representation (empty string when filtered by verbosity)
        """
        if not self.should_display(MessageType.TOOL_RESULT):
            return ""

        context = self._create_context(iteration)
        self._notify_callbacks(MessageType.TOOL_RESULT, tool_info, context)

        data: Dict[str, Any] = {
            "tool_name": tool_info.tool_name,
            "tool_id": tool_info.tool_id,
            "is_error": tool_info.is_error,
        }

        if tool_info.duration_ms is not None:
            data["duration_ms"] = tool_info.duration_ms

        # Results are only included at VERBOSE+; long string results are
        # truncated to 1000 chars with truncation metadata attached.
        if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
            result = tool_info.result
            if isinstance(result, str) and len(result) > 1000:
                data["result"] = self.summarize_content(result, 1000)
                data["result_truncated"] = True
                data["result_full_length"] = len(result)
            else:
                data["result"] = result

        if tool_info.end_time:
            data["end_time"] = tool_info.end_time.isoformat()

        event = self._create_event("tool_result", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_assistant_message(
        self,
        message: str,
        iteration: int = 0,
    ) -> str:
        """Format an assistant message as JSON.

        Args:
            message: Assistant message text
            iteration: Current iteration number

        Returns:
            JSON string representation (empty string when filtered by verbosity)
        """
        if not self.should_display(MessageType.ASSISTANT):
            return ""

        context = self._create_context(iteration)
        self._notify_callbacks(MessageType.ASSISTANT, message, context)

        data: Dict[str, Any] = {}

        # At NORMAL verbosity, long messages are summarized; at other
        # levels the full text is kept.
        if self._verbosity == VerbosityLevel.NORMAL and len(message) > 1000:
            data["message"] = self.summarize_content(message, 1000)
            data["message_truncated"] = True
            data["message_full_length"] = len(message)
        else:
            data["message"] = message
            data["message_truncated"] = False

        data["message_length"] = len(message)

        event = self._create_event("assistant_message", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_system_message(
        self,
        message: str,
        iteration: int = 0,
    ) -> str:
        """Format a system message as JSON.

        Args:
            message: System message text
            iteration: Current iteration number

        Returns:
            JSON string representation (empty string when filtered by verbosity)
        """
        if not self.should_display(MessageType.SYSTEM):
            return ""

        context = self._create_context(iteration)
        self._notify_callbacks(MessageType.SYSTEM, message, context)

        data = {
            "message": message,
        }

        event = self._create_event("system_message", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_error(
        self,
        error: str,
        exception: Optional[Exception] = None,
        iteration: int = 0,
    ) -> str:
        """Format an error as JSON.

        Errors are always emitted regardless of verbosity (no
        should_display gate, unlike the other format_* methods).

        Args:
            error: Error message
            exception: Optional exception object
            iteration: Current iteration number

        Returns:
            JSON string representation
        """
        context = self._create_context(iteration)
        self._notify_callbacks(MessageType.ERROR, error, context)

        data: Dict[str, Any] = {
            "error": error,
        }

        if exception:
            data["exception_type"] = type(exception).__name__
            data["exception_str"] = str(exception)

            # Full traceback only at VERBOSE and above
            if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
                import traceback

                data["traceback"] = traceback.format_exception(
                    type(exception), exception, exception.__traceback__
                )

        event = self._create_event("error", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_progress(
        self,
        message: str,
        current: int = 0,
        total: int = 0,
        iteration: int = 0,
    ) -> str:
        """Format progress information as JSON.

        Args:
            message: Progress message
            current: Current progress value
            total: Total progress value
            iteration: Current iteration number

        Returns:
            JSON string representation (empty string when filtered by verbosity)
        """
        if not self.should_display(MessageType.PROGRESS):
            return ""

        context = self._create_context(iteration)
        self._notify_callbacks(MessageType.PROGRESS, message, context)

        data: Dict[str, Any] = {
            "message": message,
            "current": current,
            "total": total,
        }

        # Guard avoids ZeroDivisionError for indeterminate progress
        if total > 0:
            data["percentage"] = round((current / total) * 100, 1)

        event = self._create_event("progress", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_token_usage(self, show_session: bool = True) -> str:
        """Format token usage summary as JSON.

        NOTE(review): unlike the other format_* methods this does NOT call
        _record_event, so token-usage events never appear in get_events() /
        get_summary() — confirm whether that is intentional.

        Args:
            show_session: Include session totals

        Returns:
            JSON string representation
        """
        usage = self._token_usage

        data: Dict[str, Any] = {
            "current": {
                "input_tokens": usage.input_tokens,
                "output_tokens": usage.output_tokens,
                "total_tokens": usage.total_tokens,
                "cost": usage.cost,
            },
        }

        if show_session:
            data["session"] = {
                "input_tokens": usage.session_input_tokens,
                "output_tokens": usage.session_output_tokens,
                "total_tokens": usage.session_total_tokens,
                "cost": usage.session_cost,
            }

        if usage.model:
            data["model"] = usage.model

        event = self._create_event("token_usage", data, 0)
        return self._to_json(event)

    def format_section_header(self, title: str, iteration: int = 0) -> str:
        """Format a section header as JSON.

        Args:
            title: Section title
            iteration: Current iteration number

        Returns:
            JSON string representation
        """
        data = {
            "title": title,
            "elapsed_seconds": self.get_elapsed_time(),
        }

        event = self._create_event("section_start", data, iteration)
        self._record_event(event)
        return self._to_json(event)

    def format_section_footer(self) -> str:
        """Format a section footer as JSON.

        Returns:
            JSON string representation
        """
        data = {
            "elapsed_seconds": self.get_elapsed_time(),
        }

        event = self._create_event("section_end", data, 0)
        self._record_event(event)
        return self._to_json(event)

    def get_summary(self) -> Dict[str, Any]:
        """Get a summary of all recorded events.

        Returns:
            Summary dictionary with counts and totals
        """
        # Tally events by their "type" field
        event_counts: Dict[str, int] = {}
        for event in self._events:
            event_type = event.get("type", "unknown")
            event_counts[event_type] = event_counts.get(event_type, 0) + 1

        return {
            "total_events": len(self._events),
            "event_counts": event_counts,
            "token_usage": {
                "total_tokens": self._token_usage.session_total_tokens,
                "total_cost": self._token_usage.session_cost,
            },
            "elapsed_seconds": self.get_elapsed_time(),
        }

    def export_events(self) -> str:
        """Export all recorded events as a JSON array.

        Returns:
            JSON string with all events
        """
        return self._to_json({"events": self._events, "summary": self.get_summary()})
|
||||
@@ -0,0 +1,298 @@
|
||||
# ABOUTME: Plain text output formatter for non-terminal environments
|
||||
# ABOUTME: Provides basic text formatting without colors or special characters
|
||||
|
||||
"""Plain text output formatter for Claude adapter."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .base import (
|
||||
MessageType,
|
||||
OutputFormatter,
|
||||
ToolCallInfo,
|
||||
VerbosityLevel,
|
||||
)
|
||||
|
||||
|
||||
class PlainTextFormatter(OutputFormatter):
|
||||
"""Plain text formatter for environments without rich terminal support.
|
||||
|
||||
Produces readable output without ANSI codes, colors, or special characters.
|
||||
Suitable for logging to files or basic terminal output.
|
||||
"""
|
||||
|
||||
# Formatting constants
|
||||
SEPARATOR_WIDTH = 60
|
||||
HEADER_CHAR = "="
|
||||
SUBHEADER_CHAR = "-"
|
||||
SECTION_CHAR = "#"
|
||||
|
||||
    def __init__(self, verbosity: VerbosityLevel = VerbosityLevel.NORMAL) -> None:
        """Initialize plain text formatter.

        Args:
            verbosity: Output verbosity level (stored by the base class
                and consulted by the format_* methods)
        """
        super().__init__(verbosity)
|
||||
|
||||
def _timestamp(self) -> str:
|
||||
"""Get formatted timestamp string."""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
def _separator(self, char: str = "-", width: int = None) -> str:
|
||||
"""Create a separator line."""
|
||||
return char * (width or self.SEPARATOR_WIDTH)
|
||||
|
||||
def format_tool_call(
|
||||
self,
|
||||
tool_info: ToolCallInfo,
|
||||
iteration: int = 0,
|
||||
) -> str:
|
||||
"""Format a tool call for plain text display.
|
||||
|
||||
Args:
|
||||
tool_info: Tool call information
|
||||
iteration: Current iteration number
|
||||
|
||||
Returns:
|
||||
Formatted string representation
|
||||
"""
|
||||
if not self.should_display(MessageType.TOOL_CALL):
|
||||
return ""
|
||||
|
||||
context = self._create_context(iteration)
|
||||
self._notify_callbacks(MessageType.TOOL_CALL, tool_info, context)
|
||||
|
||||
lines = [
|
||||
self._separator(),
|
||||
f"[{self._timestamp()}] TOOL CALL: {tool_info.tool_name}",
|
||||
f" ID: {tool_info.tool_id[:12]}...",
|
||||
]
|
||||
|
||||
if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
|
||||
if tool_info.input_params:
|
||||
lines.append(" Input Parameters:")
|
||||
for key, value in tool_info.input_params.items():
|
||||
value_str = str(value)
|
||||
if len(value_str) > 100:
|
||||
value_str = value_str[:97] + "..."
|
||||
lines.append(f" {key}: {value_str}")
|
||||
|
||||
lines.append(self._separator())
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_tool_result(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool result for plain text display.

    Args:
        tool_info: Tool call info with result
        iteration: Current iteration number

    Returns:
        Formatted string representation ("" when filtered by verbosity)
    """
    if not self.should_display(MessageType.TOOL_RESULT):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_RESULT, tool_info, context)

    status = "ERROR" if tool_info.is_error else "Success"
    # duration_ms may be None/0 for untimed or instantaneous calls.
    duration = f" ({tool_info.duration_ms}ms)" if tool_info.duration_ms else ""

    # NOTE(review): unlike format_tool_call, no leading separator or
    # timestamp is emitted here — confirm the asymmetry is intentional.
    lines = [
        f"[TOOL RESULT] {tool_info.tool_name}{duration}",
        f" ID: {tool_info.tool_id[:12]}...",
        f" Status: {status}",
    ]

    # Output is only rendered at VERBOSE and above; long results are summarized.
    if self._verbosity.value >= VerbosityLevel.VERBOSE.value and tool_info.result:
        result_str = str(tool_info.result)
        if len(result_str) > 500:
            result_str = self.summarize_content(result_str, 500)
        lines.append(" Output:")
        for line in result_str.split("\n"):
            lines.append(f" {line}")

    lines.append(self._separator())
    return "\n".join(lines)
|
||||
|
||||
def format_assistant_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format an assistant message for plain text display.

    Args:
        message: Assistant message text
        iteration: Current iteration number

    Returns:
        Formatted string, or "" when suppressed by verbosity settings
    """
    if not self.should_display(MessageType.ASSISTANT):
        return ""

    # Observers are notified even when QUIET suppresses the rendered text.
    self._notify_callbacks(
        MessageType.ASSISTANT, message, self._create_context(iteration)
    )

    if self._verbosity == VerbosityLevel.QUIET:
        return ""

    # At NORMAL verbosity, condense long messages to keep logs readable.
    body = message
    if self._verbosity == VerbosityLevel.NORMAL and len(body) > 1000:
        body = self.summarize_content(body, 1000)

    return f"[{self._timestamp()}] ASSISTANT:\n{body}\n"
|
||||
|
||||
def format_system_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format a system message for plain text display.

    Args:
        message: System message text
        iteration: Current iteration number

    Returns:
        Formatted string, or "" when SYSTEM messages are filtered out
    """
    if not self.should_display(MessageType.SYSTEM):
        return ""

    self._notify_callbacks(
        MessageType.SYSTEM, message, self._create_context(iteration)
    )
    return f"[{self._timestamp()}] SYSTEM: {message}\n"
|
||||
|
||||
def format_error(
    self,
    error: str,
    exception: Optional[Exception] = None,
    iteration: int = 0,
) -> str:
    """Format an error for plain text display.

    Errors are always rendered — there is no should_display() gate.

    Args:
        error: Error message
        exception: Optional exception object
        iteration: Current iteration number

    Returns:
        Formatted string representation
    """
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.ERROR, error, context)

    lines = [
        self._separator(self.HEADER_CHAR),
        f"[{self._timestamp()}] ERROR (Iteration {iteration})",
        f" Message: {error}",
    ]

    # Full traceback only at VERBOSE verbosity and above.
    if exception and self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        lines.append(f" Type: {type(exception).__name__}")
        import traceback

        tb = "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))
        lines.append(" Traceback:")
        for line in tb.split("\n"):
            lines.append(f" {line}")

    lines.append(self._separator(self.HEADER_CHAR))
    return "\n".join(lines)
|
||||
|
||||
def format_progress(
    self,
    message: str,
    current: int = 0,
    total: int = 0,
    iteration: int = 0,
) -> str:
    """Format progress information for plain text display.

    Args:
        message: Progress message
        current: Current progress value
        total: Total progress value (0 means indeterminate)
        iteration: Current iteration number

    Returns:
        Formatted progress line with an ASCII bar when total is known
    """
    if not self.should_display(MessageType.PROGRESS):
        return ""

    self._notify_callbacks(
        MessageType.PROGRESS, message, self._create_context(iteration)
    )

    # Indeterminate progress: no total to compute a percentage from.
    if total <= 0:
        return f"[...] {message}"

    bar_width = 30
    filled = int(bar_width * current / total)
    bar = "#" * filled + "-" * (bar_width - filled)
    pct = (current / total) * 100
    return f"[{bar}] {pct:.1f}% - {message}"
|
||||
|
||||
def format_token_usage(self, show_session: bool = True) -> str:
    """Format token usage summary for plain text display.

    Args:
        show_session: Include cumulative session totals

    Returns:
        Formatted string representation framed by sub-separator rules
    """
    usage = self._token_usage
    lines = [
        self._separator(self.SUBHEADER_CHAR),
        "TOKEN USAGE:",
        f" Current: {usage.total_tokens:,} tokens (${usage.cost:.4f})",
        f" Input: {usage.input_tokens:,} | Output: {usage.output_tokens:,}",
    ]

    # Session totals accumulate across iterations.
    if show_session:
        lines.extend([
            f" Session: {usage.session_total_tokens:,} tokens (${usage.session_cost:.4f})",
            f" Input: {usage.session_input_tokens:,} | Output: {usage.session_output_tokens:,}",
        ])

    if usage.model:
        lines.append(f" Model: {usage.model}")

    lines.append(self._separator(self.SUBHEADER_CHAR))
    return "\n".join(lines)
|
||||
|
||||
def format_section_header(self, title: str, iteration: int = 0) -> str:
    """Format a section header for plain text display.

    Args:
        title: Section title
        iteration: Current iteration number (0 omits the suffix)

    Returns:
        Header framed by heavy rules, preceded by a blank line
    """
    heading = title if not iteration else f"{title} (Iteration {iteration})"
    rule = self._separator(self.HEADER_CHAR)
    return "\n".join(["", rule, heading, rule])
|
||||
|
||||
def format_section_footer(self) -> str:
    """Format a section footer for plain text display.

    Returns:
        A sub-separator rule followed by the elapsed wall-clock time
    """
    rule = self._separator(self.SUBHEADER_CHAR)
    elapsed = self.get_elapsed_time()
    return f"\n{rule}\nElapsed: {elapsed:.1f}s\n"
|
||||
@@ -0,0 +1,737 @@
|
||||
# ABOUTME: Rich terminal formatter with colors, panels, and progress indicators
|
||||
# ABOUTME: Provides visually enhanced output using the Rich library with smart content detection
|
||||
|
||||
"""Rich terminal output formatter for Claude adapter.
|
||||
|
||||
This formatter provides intelligent content detection and rendering:
|
||||
- Diffs are rendered with color-coded additions/deletions
|
||||
- Code blocks get syntax highlighting
|
||||
- Markdown is rendered with proper formatting
|
||||
- Error tracebacks are highlighted for readability
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import Optional
|
||||
|
||||
from .base import (
|
||||
MessageType,
|
||||
OutputFormatter,
|
||||
ToolCallInfo,
|
||||
VerbosityLevel,
|
||||
)
|
||||
from .content_detector import ContentDetector, ContentType
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import Rich components with a graceful fallback. Every name bound
# on the success path is also bound (to None) on failure so that later
# references can never raise NameError when Rich is absent.
try:
    from rich.console import Console
    from rich.panel import Panel
    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
    from rich.syntax import Syntax
    from rich.markup import escape

    RICH_AVAILABLE = True
except ImportError:
    RICH_AVAILABLE = False
    Console = None  # type: ignore
    Panel = None  # type: ignore
    # BUGFIX: previously only Console and Panel were nulled here; Progress,
    # the column helpers, Syntax, and escape were left unbound, which could
    # raise NameError on any code path that evaluates them without first
    # short-circuiting on RICH_AVAILABLE.
    Progress = None  # type: ignore
    SpinnerColumn = None  # type: ignore
    TextColumn = None  # type: ignore
    BarColumn = None  # type: ignore
    Syntax = None  # type: ignore
    escape = None  # type: ignore
|
||||
|
||||
class RichTerminalFormatter(OutputFormatter):
    """Rich terminal formatter with colors, panels, and progress indicators.

    Provides visually enhanced output using the Rich library for terminal
    display. Falls back to plain text if Rich is not available.
    """

    # Color scheme: Rich style strings keyed by semantic role.
    COLORS = {
        "tool_name": "bold cyan",
        "tool_id": "dim",
        "success": "bold green",
        "error": "bold red",
        "warning": "yellow",
        "info": "blue",
        "timestamp": "dim white",
        "header": "bold magenta",
        "assistant": "white",
        "system": "dim cyan",
        "token_input": "green",
        "token_output": "yellow",
        "cost": "bold yellow",
    }

    # Icons keyed by semantic role.
    # NOTE(review): every icon value here is an empty string — the original
    # glyphs (presumably emoji/unicode) appear to have been stripped in
    # transit; confirm against the upstream source before relying on them.
    ICONS = {
        "tool": "",
        "success": "",
        "error": "",
        "warning": "",
        "info": "",
        "assistant": "",
        "system": "",
        "token": "",
        "clock": "",
        "progress": "",
    }
|
||||
|
||||
def __init__(
    self,
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL,
    console: Optional["Console"] = None,
    smart_detection: bool = True,
) -> None:
    """Initialize rich terminal formatter.

    Args:
        verbosity: Output verbosity level
        console: Optional Rich console instance (creates new if None)
        smart_detection: Enable smart content detection (diff, code, markdown)
    """
    super().__init__(verbosity)
    # Availability is captured once at construction; all render paths
    # branch on this flag rather than re-probing the import.
    self._rich_available = RICH_AVAILABLE
    self._smart_detection = smart_detection
    self._content_detector = ContentDetector() if smart_detection else None

    if RICH_AVAILABLE:
        self._console = console or Console()
        # Import DiffFormatter for diff rendering (function-local import;
        # presumably to avoid a circular module dependency — confirm).
        from .console import DiffFormatter
        self._diff_formatter = DiffFormatter(self._console)
    else:
        self._console = None
        self._diff_formatter = None
|
||||
|
||||
@property
def console(self) -> Optional["Console"]:
    """Get the Rich console instance (None when Rich is unavailable)."""
    return self._console
|
||||
|
||||
def _timestamp(self) -> str:
    """Get formatted HH:MM:SS timestamp, wrapped in Rich markup when available."""
    stamp = datetime.now().strftime("%H:%M:%S")
    if not self._rich_available:
        return stamp
    return f"[{self.COLORS['timestamp']}]{stamp}[/]"
|
||||
|
||||
def _full_timestamp(self) -> str:
    """Get full date+time timestamp, wrapped in Rich markup when available."""
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    if not self._rich_available:
        return stamp
    return f"[{self.COLORS['timestamp']}]{stamp}[/]"
|
||||
|
||||
def format_tool_call(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool call for rich terminal display.

    Args:
        tool_info: Tool call information
        iteration: Current iteration number

    Returns:
        Formatted string representation (may contain Rich markup; ""
        when filtered out by verbosity)
    """
    if not self.should_display(MessageType.TOOL_CALL):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_CALL, tool_info, context)

    # Degrade gracefully when Rich is not installed.
    if not self._rich_available:
        return self._format_tool_call_plain(tool_info)

    # Build rich formatted output
    icon = self.ICONS["tool"]
    name_color = self.COLORS["tool_name"]
    id_color = self.COLORS["tool_id"]

    lines = [
        f"{icon} [{name_color}]TOOL CALL: {tool_info.tool_name}[/]",
        f" [{id_color}]ID: {tool_info.tool_id[:12]}...[/]",
    ]

    # Input parameters are only rendered at VERBOSE and above.
    if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        if tool_info.input_params:
            lines.append(f" [{self.COLORS['info']}]Input Parameters:[/]")
            for key, value in tool_info.input_params.items():
                value_str = str(value)
                if len(value_str) > 100:
                    value_str = value_str[:97] + "..."
                # Escape Rich markup in values so user data cannot inject styling.
                if self._rich_available:
                    value_str = escape(value_str)
                lines.append(f" - {key}: {value_str}")

    return "\n".join(lines)
|
||||
|
||||
def _format_tool_call_plain(self, tool_info: ToolCallInfo) -> str:
    """Plain fallback for tool call formatting (no Rich markup)."""
    parts = [
        f"TOOL CALL: {tool_info.tool_name}",
        f" ID: {tool_info.tool_id[:12]}...",
    ]
    params = tool_info.input_params
    if params:
        parts.extend(f" {key}: {str(value)[:100]}" for key, value in params.items())
    return "\n".join(parts)
|
||||
|
||||
def format_tool_result(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool result for rich terminal display.

    Args:
        tool_info: Tool call info with result
        iteration: Current iteration number

    Returns:
        Formatted string representation (may contain Rich markup; ""
        when filtered out by verbosity)
    """
    if not self.should_display(MessageType.TOOL_RESULT):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_RESULT, tool_info, context)

    if not self._rich_available:
        return self._format_tool_result_plain(tool_info)

    # Determine status styling
    if tool_info.is_error:
        status_icon = self.ICONS["error"]
        status_color = self.COLORS["error"]
        status_text = "ERROR"
    else:
        status_icon = self.ICONS["success"]
        status_color = self.COLORS["success"]
        status_text = "Success"

    duration = f" ({tool_info.duration_ms}ms)" if tool_info.duration_ms else ""

    lines = [
        f"{status_icon} [{status_color}]TOOL RESULT: {tool_info.tool_name}{duration}[/]",
        f" [{self.COLORS['tool_id']}]ID: {tool_info.tool_id[:12]}...[/]",
        f" Status: [{status_color}]{status_text}[/]",
    ]

    if self._verbosity.value >= VerbosityLevel.VERBOSE.value and tool_info.result:
        result_str = str(tool_info.result)
        if len(result_str) > 500:
            result_str = self.summarize_content(result_str, 500)
        # Escape Rich markup so result content cannot inject styling.
        if self._rich_available:
            result_str = escape(result_str)
        lines.append(f" [{self.COLORS['info']}]Output:[/]")
        result_lines = result_str.split("\n")
        for line in result_lines[:20]:  # Limit displayed lines
            lines.append(f" {line}")
        # BUGFIX: the hidden-line count was previously count("\n") - 20,
        # which is off by one (text with N newlines has N + 1 lines), and
        # the `count("\n") > 20` guard omitted the notice when exactly one
        # line was hidden.
        hidden = len(result_lines) - 20
        if hidden > 0:
            lines.append(f" [{self.COLORS['timestamp']}]... ({hidden} more lines)[/]")

    return "\n".join(lines)
|
||||
|
||||
def _format_tool_result_plain(self, tool_info: ToolCallInfo) -> str:
    """Plain fallback for tool result formatting (no Rich markup)."""
    outcome = "ERROR" if tool_info.is_error else "Success"
    parts = [
        f"TOOL RESULT: {tool_info.tool_name}",
        f" Status: {outcome}",
    ]
    if tool_info.result:
        parts.append(f" Output: {str(tool_info.result)[:200]}")
    return "\n".join(parts)
|
||||
|
||||
def format_assistant_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format an assistant message for rich terminal display.

    With smart_detection enabled, detects and renders:
    - Diffs with color-coded additions/deletions
    - Code blocks with syntax highlighting
    - Markdown with proper formatting
    - Error tracebacks with special highlighting

    Args:
        message: Assistant message text
        iteration: Current iteration number

    Returns:
        Formatted string representation ("" when suppressed)
    """
    if not self.should_display(MessageType.ASSISTANT):
        return ""

    # Callbacks fire before the QUIET check so observers always see the message.
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.ASSISTANT, message, context)

    if self._verbosity == VerbosityLevel.QUIET:
        return ""

    # Summarize if needed (only for normal verbosity)
    display_message = message
    if self._verbosity == VerbosityLevel.NORMAL and len(message) > 1000:
        display_message = self.summarize_content(message, 1000)

    if not self._rich_available:
        return f"ASSISTANT: {display_message}"

    # Use smart detection if enabled
    if self._smart_detection and self._content_detector:
        content_type = self._content_detector.detect(display_message)
        return self._render_smart_content(display_message, content_type)

    # Fallback to simple formatting; escape() guards against markup injection.
    icon = self.ICONS["assistant"]
    return f"{icon} [{self.COLORS['assistant']}]{escape(display_message)}[/]"
|
||||
|
||||
def _render_smart_content(self, text: str, content_type: ContentType) -> str:
    """Render content based on detected type.

    Args:
        text: Text content to render
        content_type: Detected content type

    Returns:
        Formatted string (may include Rich markup)
    """
    # Dispatch table; markdown tables reuse the markdown renderer.
    renderers = {
        ContentType.DIFF: self._render_diff,
        ContentType.CODE_BLOCK: self._render_code_blocks,
        ContentType.MARKDOWN: self._render_markdown,
        ContentType.MARKDOWN_TABLE: self._render_markdown,
        ContentType.ERROR_TRACEBACK: self._render_traceback,
    }
    renderer = renderers.get(content_type)
    if renderer is not None:
        return renderer(text)

    # Plain text - escape and format
    icon = self.ICONS["assistant"]
    return f"{icon} [{self.COLORS['assistant']}]{escape(text)}[/]"
|
||||
|
||||
def _render_diff(self, text: str) -> str:
    """Render diff content with colors.

    Uses the DiffFormatter for enhanced diff visualization with
    color-coded additions/deletions and file statistics.

    Args:
        text: Diff text to render

    Returns:
        "" when the diff was printed directly to the console, otherwise
        a dim-markup fallback string containing the raw diff text.
    """
    if self._diff_formatter and self._console:
        # DiffFormatter prints directly, so capture would require buffer
        # For now, print directly and return marker
        self._diff_formatter.format_and_print(text)
        return ""  # Already printed
    return f"[dim]{text}[/dim]"
|
||||
|
||||
def _render_code_blocks(self, text: str) -> str:
    """Render text with code blocks using syntax highlighting.

    Extracts fenced ``` blocks, renders them with Rich Syntax, and
    formats the surrounding text.

    Args:
        text: Text containing code blocks

    Returns:
        Rendered string captured from a temporary console (includes
        ANSI escape codes since force_terminal is set)
    """
    if not self._console or not self._content_detector:
        return f"[dim]{text}[/dim]"

    # Split by code blocks and render each part
    parts = []
    pattern = r"```(\w+)?\n(.*?)\n```"
    last_end = 0

    for match in re.finditer(pattern, text, re.DOTALL):
        # Text before code block
        before = text[last_end:match.start()].strip()
        if before:
            parts.append(("text", before))

        # Code block; default to the plain "text" lexer when untagged.
        language = match.group(1) or "text"
        code = match.group(2)
        parts.append(("code", language, code))
        last_end = match.end()

    # Text after last code block
    after = text[last_end:].strip()
    if after:
        parts.append(("text", after))

    # Render to string buffer using a throwaway console so output is
    # captured rather than written to the live terminal.
    buffer = StringIO()
    temp_console = Console(file=buffer, force_terminal=True, width=100)

    for part in parts:
        if part[0] == "text":
            temp_console.print(part[1], markup=True, highlight=True)
        else:  # code
            _, language, code = part
            syntax = Syntax(
                code,
                language,
                theme="monokai",
                line_numbers=True,
                word_wrap=True,
            )
            temp_console.print(syntax)

    return buffer.getvalue()
|
||||
|
||||
def _render_markdown(self, text: str) -> str:
    """Render markdown content with Rich formatting.

    Uses Rich's Markdown renderer for headings, lists, emphasis, etc.

    Args:
        text: Markdown text to render

    Returns:
        Rendered markdown string, or the raw text when the console or
        Rich's markdown module is unavailable
    """
    if not self._console:
        return text

    try:
        from rich.markdown import Markdown

        # Preprocess for task lists
        processed = self._preprocess_markdown(text)

        # Capture the render into a buffer rather than printing live.
        buffer = StringIO()
        temp_console = Console(file=buffer, force_terminal=True, width=100)
        temp_console.print(Markdown(processed))
        return buffer.getvalue()
    except ImportError:
        # Partial Rich install without markdown support: degrade to raw text.
        return text
|
||||
|
||||
def _preprocess_markdown(self, text: str) -> str:
    """Preprocess markdown for enhanced rendering.

    Converts task list checkboxes to visual indicators:
    ``[ ]`` -> ☐ and ``[x]``/``[X]`` -> ☑.

    Args:
        text: Raw markdown

    Returns:
        Preprocessed markdown
    """
    # Note: substitutions match anywhere in the text, not only at the
    # start of list items.
    unchecked_replaced = re.sub(r"\[\s\]", "☐", text)
    return re.sub(r"\[[xX]\]", "☑", unchecked_replaced)
|
||||
|
||||
def _render_traceback(self, text: str) -> str:
    """Render error traceback with syntax highlighting.

    Uses Python syntax highlighting for better readability.

    Args:
        text: Traceback text to render

    Returns:
        Formatted traceback string (red-markup fallback on failure)
    """
    if not self._console:
        return f"[red]{text}[/red]"

    try:
        # Capture into a buffer so callers decide when/where to print.
        buffer = StringIO()
        temp_console = Console(file=buffer, force_terminal=True, width=100)
        temp_console.print("[red bold]⚠ Error Traceback:[/red bold]")
        syntax = Syntax(
            text,
            "python",
            theme="monokai",
            line_numbers=False,
            word_wrap=True,
            background_color="grey11",
        )
        temp_console.print(syntax)
        return buffer.getvalue()
    except Exception as e:
        # Rendering is best-effort: log and fall back to plain red text.
        _logger.warning("Rich traceback rendering failed: %s: %s", type(e).__name__, e)
        return f"[red]{escape(text)}[/red]"
|
||||
|
||||
def print_smart(self, message: str, iteration: int = 0) -> None:
    """Print message with smart content detection directly to console.

    This is the preferred method for displaying assistant messages
    as it handles all content types appropriately and prints directly.

    Args:
        message: Message text to print
        iteration: Current iteration number
    """
    if not self.should_display(MessageType.ASSISTANT):
        return

    if self._verbosity == VerbosityLevel.QUIET:
        return

    if not self._console:
        print(f"ASSISTANT: {message}")
        return

    # Use smart detection
    if self._smart_detection and self._content_detector:
        content_type = self._content_detector.detect(message)

        if content_type == ContentType.DIFF:
            # DiffFormatter prints directly
            if self._diff_formatter:
                self._diff_formatter.format_and_print(message)
            # NOTE(review): this return also fires when _diff_formatter is
            # None, silently dropping diff content — confirm intended.
            return

    # For other content types, use format and print
    formatted = self.format_assistant_message(message, iteration)
    if formatted:
        self._console.print(formatted, markup=True)
|
||||
|
||||
def format_system_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format a system message for rich terminal display.

    Args:
        message: System message text
        iteration: Current iteration number

    Returns:
        Formatted string representation ("" when filtered out)
    """
    if not self.should_display(MessageType.SYSTEM):
        return ""

    self._notify_callbacks(
        MessageType.SYSTEM, message, self._create_context(iteration)
    )

    if not self._rich_available:
        return f"SYSTEM: {message}"

    return f"{self.ICONS['system']} [{self.COLORS['system']}]SYSTEM: {message}[/]"
|
||||
|
||||
def format_error(
    self,
    error: str,
    exception: Optional[Exception] = None,
    iteration: int = 0,
) -> str:
    """Format an error for rich terminal display.

    Errors are always rendered — there is no should_display() gate.

    Args:
        error: Error message
        exception: Optional exception object
        iteration: Current iteration number

    Returns:
        Formatted string representation
    """
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.ERROR, error, context)

    if not self._rich_available:
        return f"ERROR: {error}"

    icon = self.ICONS["error"]
    color = self.COLORS["error"]

    lines = [
        f"\n{icon} [{color}]ERROR (Iteration {iteration})[/]",
        f" [{color}]{error}[/]",
    ]

    # Traceback only at VERBOSE and above, capped to keep output bounded.
    if exception and self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        lines.append(f" [{self.COLORS['warning']}]Type: {type(exception).__name__}[/]")
        import traceback

        tb = "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))
        lines.append(f" [{self.COLORS['timestamp']}]Traceback:[/]")
        for line in tb.split("\n")[:15]:  # Limit traceback lines
            if line.strip():
                # escape() keeps markup-like text in the traceback literal.
                lines.append(f" {escape(line)}" if self._rich_available else f" {line}")

    return "\n".join(lines)
|
||||
|
||||
def format_progress(
    self,
    message: str,
    current: int = 0,
    total: int = 0,
    iteration: int = 0,
) -> str:
    """Format progress information for rich terminal display.

    Args:
        message: Progress message
        current: Current progress value
        total: Total progress value (0 means indeterminate)
        iteration: Current iteration number

    Returns:
        Formatted string representation ("" when filtered out)
    """
    if not self.should_display(MessageType.PROGRESS):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.PROGRESS, message, context)

    if not self._rich_available:
        if total > 0:
            pct = (current / total) * 100
            return f"[{pct:.0f}%] {message}"
        return f"[...] {message}"

    icon = self.ICONS["progress"]
    if total > 0:
        pct = (current / total) * 100
        bar_width = 20
        filled = int(bar_width * current / total)
        # NOTE(review): both bar glyphs are empty strings, so the bar
        # renders as nothing — the original unicode block characters were
        # likely lost in transit; confirm against the upstream source.
        bar = "" * filled + "" * (bar_width - filled)
        return f"{icon} [{self.COLORS['info']}][{bar}] {pct:.0f}%[/] {message}"
    return f"{icon} [{self.COLORS['info']}][...][/] {message}"
|
||||
|
||||
def format_token_usage(self, show_session: bool = True) -> str:
    """Format token usage summary for rich terminal display.

    Args:
        show_session: Include cumulative session totals

    Returns:
        Formatted string representation
    """
    usage = self._token_usage

    # Plain fallback: compact summary without markup.
    if not self._rich_available:
        lines = [
            f"TOKEN USAGE: {usage.total_tokens:,} (${usage.cost:.4f})",
        ]
        if show_session:
            lines.append(f" Session: {usage.session_total_tokens:,} (${usage.session_cost:.4f})")
        return "\n".join(lines)

    icon = self.ICONS["token"]
    input_color = self.COLORS["token_input"]
    output_color = self.COLORS["token_output"]
    cost_color = self.COLORS["cost"]

    lines = [
        f"\n{icon} [{self.COLORS['header']}]TOKEN USAGE[/]",
        f" Current: [{input_color}]{usage.input_tokens:,} in[/] | [{output_color}]{usage.output_tokens:,} out[/] | [{cost_color}]${usage.cost:.4f}[/]",
    ]

    if show_session:
        lines.append(
            f" Session: [{input_color}]{usage.session_input_tokens:,} in[/] | [{output_color}]{usage.session_output_tokens:,} out[/] | [{cost_color}]${usage.session_cost:.4f}[/]"
        )

    if usage.model:
        lines.append(f" [{self.COLORS['timestamp']}]Model: {usage.model}[/]")

    return "\n".join(lines)
|
||||
|
||||
def format_section_header(self, title: str, iteration: int = 0) -> str:
    """Format a section header for rich terminal display.

    Args:
        title: Section title
        iteration: Current iteration number (0 omits the suffix)

    Returns:
        Formatted string representation
    """
    if not self._rich_available:
        sep = "=" * 60
        header_title = f"{title} (Iteration {iteration})" if iteration else title
        return f"\n{sep}\n{header_title}\n{sep}"

    header_title = f"{title} (Iteration {iteration})" if iteration else title
    # NOTE(review): the separator glyph is an empty string, so the rich
    # header renders with no rule lines — likely a lost unicode character
    # (e.g. a box-drawing glyph); confirm against the upstream source.
    sep = "" * 50
    return f"\n[{self.COLORS['header']}]{sep}\n{header_title}\n{sep}[/]"
|
||||
|
||||
def format_section_footer(self) -> str:
    """Format a section footer for rich terminal display.

    Returns:
        Elapsed-time footer, with Rich markup when available
    """
    elapsed = self.get_elapsed_time()

    if self._rich_available:
        icon = self.ICONS["clock"]
        return f"\n[{self.COLORS['timestamp']}]{icon} Elapsed: {elapsed:.1f}s[/]\n"

    return f"\n{'=' * 50}\nElapsed: {elapsed:.1f}s\n"
|
||||
|
||||
def print(self, text: str) -> None:
    """Print formatted text to console.

    Args:
        text: Rich-formatted text to print
    """
    if self._console:
        self._console.print(text, markup=True)
    else:
        # Strip Rich markup tags ([style]...[/]) for plain output.
        # Uses the module-level `re` import; the redundant function-local
        # `import re` was removed.
        plain = re.sub(r"\[/?[^\]]+\]", "", text)
        print(plain)
|
||||
|
||||
def print_panel(self, content: str, title: str = "", border_style: str = "blue") -> None:
    """Print content in a Rich panel.

    Args:
        content: Content to display
        title: Panel title
        border_style: Panel border color
    """
    if self._console and self._rich_available and Panel:
        self._console.print(Panel(content, title=title, border_style=border_style))
        return

    # Plain fallback: emulate the panel with a simple header line.
    if title:
        print(f"\n=== {title} ===")
    print(content)
    print()
|
||||
|
||||
def create_progress_bar(self) -> Optional["Progress"]:
    """Create a Rich progress bar instance.

    Returns:
        Progress instance or None if Rich not available
    """
    # The _rich_available short-circuit ensures `Progress` is only
    # evaluated when the rich import succeeded.
    if not self._rich_available or not Progress:
        return None

    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=self._console,
    )
|
||||
156
.venv/lib/python3.11/site-packages/ralph_orchestrator/safety.py
Normal file
156
.venv/lib/python3.11/site-packages/ralph_orchestrator/safety.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# ABOUTME: Safety guardrails and circuit breakers for Ralph Orchestrator
|
||||
# ABOUTME: Prevents runaway loops and excessive costs
|
||||
|
||||
"""Safety mechanisms for Ralph Orchestrator."""
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('ralph-orchestrator.safety')
|
||||
|
||||
|
||||
@dataclass
class SafetyCheckResult:
    """Result of a safety check."""
    # True when it is safe to continue orchestration.
    passed: bool
    # Human-readable explanation of the failed check (None when passed).
    reason: Optional[str] = None
||||
|
||||
|
||||
class SafetyGuard:
    """Runtime guardrails for orchestration runs.

    Enforces hard budgets on iteration count, wall-clock time, dollar cost,
    and consecutive failures, and offers fuzzy-similarity loop detection
    over a rolling window of recent agent outputs.
    """

    def __init__(
        self,
        max_iterations: int = 100,
        max_runtime: int = 14400,  # 4 hours
        max_cost: float = 10.0,
        consecutive_failure_limit: int = 5
    ):
        """Store the configured budgets and reset all counters.

        Args:
            max_iterations: Maximum allowed iterations
            max_runtime: Maximum runtime in seconds
            max_cost: Maximum allowed cost in dollars
            consecutive_failure_limit: Max consecutive failures before stopping
        """
        self.max_iterations = max_iterations
        self.max_runtime = max_runtime
        self.max_cost = max_cost
        self.consecutive_failure_limit = consecutive_failure_limit
        self.consecutive_failures = 0
        # Loop-detection state: rolling window of the last few outputs and
        # the similarity ratio (0..1) at which output counts as a loop.
        self.recent_outputs: deque = deque(maxlen=5)
        self.loop_threshold: float = 0.9

    def check(
        self,
        iterations: int,
        elapsed_time: float,
        total_cost: float
    ) -> SafetyCheckResult:
        """Evaluate every stop condition, in priority order.

        Args:
            iterations: Current iteration count
            elapsed_time: Elapsed time in seconds
            total_cost: Total cost so far

        Returns:
            SafetyCheckResult indicating if it's safe to continue
        """
        def stop(reason: str) -> SafetyCheckResult:
            # Small local factory keeps the limit checks one line each.
            return SafetyCheckResult(passed=False, reason=reason)

        if iterations >= self.max_iterations:
            return stop(f"Reached maximum iterations ({self.max_iterations})")

        if elapsed_time >= self.max_runtime:
            return stop(f"Reached maximum runtime ({elapsed_time / 3600:.1f} hours)")

        if total_cost >= self.max_cost:
            return stop(f"Reached maximum cost (${total_cost:.2f})")

        if self.consecutive_failures >= self.consecutive_failure_limit:
            return stop(f"Too many consecutive failures ({self.consecutive_failures})")

        if iterations > 50:
            # High but not fatal: surface a warning and keep going.
            logger.warning(f"High iteration count: {iterations}")

        if iterations > 75 and elapsed_time / iterations > 300:
            # Past 75 iterations, abort when the average pace is worse
            # than one iteration per five minutes.
            return stop("Iterations taking too long on average")

        return SafetyCheckResult(passed=True)

    def record_success(self):
        """Record a successful iteration, clearing the failure streak."""
        self.consecutive_failures = 0

    def record_failure(self):
        """Record a failed iteration and log the current streak."""
        self.consecutive_failures += 1
        logger.warning(f"Consecutive failures: {self.consecutive_failures}")

    def reset(self):
        """Reset the failure counter and the loop-detection window."""
        self.consecutive_failures = 0
        self.recent_outputs.clear()

    def detect_loop(self, current_output: str) -> bool:
        """Detect whether the agent is looping, by output similarity.

        Compares *current_output* against the rolling window using
        rapidfuzz; a match at or above ``loop_threshold`` (90%) similarity
        counts as a loop. Detection degrades gracefully to "no loop" when
        rapidfuzz is missing or comparison fails.

        Args:
            current_output: The current agent output to check.

        Returns:
            True if loop detected (similar output found), False otherwise.
        """
        if not current_output:
            return False

        try:
            from rapidfuzz import fuzz
        except ImportError:
            # rapidfuzz not installed, skip loop detection
            logger.debug("rapidfuzz not installed, skipping loop detection")
            return False

        try:
            for earlier in self.recent_outputs:
                similarity = fuzz.ratio(current_output, earlier) / 100.0
                if similarity >= self.loop_threshold:
                    logger.warning(
                        f"Loop detected: {similarity:.1%} similarity to previous output"
                    )
                    return True

            # Only non-looping output enters the comparison window.
            self.recent_outputs.append(current_output)
            return False
        except Exception as e:
            logger.warning(f"Error in loop detection: {e}")
            return False
|
||||
@@ -0,0 +1,488 @@
|
||||
# ABOUTME: Security utilities for Ralph Orchestrator
|
||||
# ABOUTME: Provides input validation, path sanitization, and sensitive data protection
|
||||
|
||||
"""
|
||||
Security utilities for Ralph Orchestrator.
|
||||
|
||||
This module provides security hardening functions including input validation,
|
||||
path sanitization, and sensitive data protection.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger("ralph-orchestrator.security")
|
||||
|
||||
|
||||
class SecurityValidator:
    """Security validation utilities for Ralph Orchestrator.

    Provides path sanitization against directory traversal, configuration
    value validation, filename validation, and masking of sensitive data
    (API keys, passwords, tokens, credential file paths) in log output.
    """

    # Patterns for dangerous path components
    DANGEROUS_PATH_PATTERNS = [
        r"\.\.\/.*",  # Directory traversal (Unix)
        r"\.\.\\.*",  # Windows directory traversal
        r"^\.\.[\/\\]",  # Starts with parent directory
        r"[\/\\]\.\.[\/\\]",  # Contains parent directory
        r"[<>:\"|?*]",  # Invalid filename characters (Windows)
        r"[\x00-\x1f]",  # Control characters
        r"[\/\\]\.\.[\/\\]\.\.[\/\\]",  # Double traversal
    ]

    # Sensitive data patterns that should be masked (applied in order;
    # earlier patterns win when spans overlap).
    SENSITIVE_PATTERNS = [
        # API Keys
        (r"(sk-[a-zA-Z0-9]{10,})", r"sk-***********"),  # OpenAI API keys
        (r"(xai-[a-zA-Z0-9]{10,})", r"xai-***********"),  # xAI API keys
        (r"(AIza[a-zA-Z0-9_-]{35})", r"AIza***********"),  # Google API keys
        # Bearer tokens
        (r"(Bearer [a-zA-Z0-9\-_\.]{20,})", r"Bearer ***********"),
        # Passwords in various formats
        (
            r'(["\']?password["\']?\s*[:=]\s*["\']?)([^"\'\s]{3,})(["\']?)',
            r"\1*********\3",
        ),
        (r"(password\s*=\s*)([^\"'\s]{3,})", r"\1*********"),
        # Tokens in various formats
        (
            r'(token["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
            r"\1*********\3",
        ),
        (r"(token\s*=\s*)([a-zA-Z0-9\-_\.]{10,})", r"\1*********"),
        # Secrets
        (
            r'(secret["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
            r"\1*********\3",
        ),
        (r"(secret\s*=\s*)([a-zA-Z0-9\-_\.]{10,})", r"\1*********"),
        # Generic keys
        (
            r'(key["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
            r"\1*********\3",
        ),
        # API keys in various formats
        (
            r'(api[_-]?key["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
            r"\1*********\3",
        ),
        (r"(api[_-]?key\s*=\s*)([a-zA-Z0-9\-_\.]{10,})", r"\1*********"),
        # Sensitive file paths
        (
            r"(/[a-zA-Z0-9_\-\./]*\.ssh/[a-zA-Z0-9_\-\./]*)",
            r"[REDACTED_SSH_PATH]",
        ),  # SSH paths
        (
            r"(/[a-zA-Z0-9_\-\./]*\.ssh/id_[a-zA-Z0-9]*)",
            r"[REDACTED_SSH_KEY]",
        ),  # SSH private keys (usually already caught by the pattern above)
        (
            r"(/[a-zA-Z0-9_\-\./]*\.config/[a-zA-Z0-9_\-\./]*)",
            r"[REDACTED_CONFIG_PATH]",
        ),  # Config files
        (
            r"(/[a-zA-Z0-9_\-\./]*\.aws/[a-zA-Z0-9_\-\./]*)",
            r"[REDACTED_AWS_PATH]",
        ),  # AWS credentials
        (
            r"(/[a-zA-Z0-9_\-\./]*(passwd|shadow|group|hosts))",
            r"[REDACTED_SYSTEM_FILE]",
        ),  # System files
        (
            r"(C:\\\\[a-zA-Z0-9_\-\./]*\\\\System32\\\\[a-zA-Z0-9_\-\./]*)",
            r"[REDACTED_SYSTEM_PATH]",
        ),  # Windows system files
        (
            r"(/[a-zA-Z0-9_\-\./]*(id_rsa|id_dsa|id_ecdsa|id_ed25519))",
            r"[REDACTED_PRIVATE_KEY]",
        ),  # Private key files
    ]

    # Dangerous absolute path prefixes (string-prefix checked)
    DANGEROUS_ABS_PATHS = [
        "/etc",
        "/usr/bin",
        "/bin",
        "/sbin",
        "/root",
        "/var",
        "/opt",
        "/sys",
        "/proc",
        "/dev",
    ]

    @classmethod
    def sanitize_path(cls, path: str, base_dir: Optional[Path] = None) -> Path:
        """
        Sanitize a file path to prevent directory traversal attacks.

        Args:
            path: Input path to sanitize
            base_dir: Base directory to resolve relative paths against
                (defaults to the current working directory)

        Returns:
            Sanitized absolute Path

        Raises:
            ValueError: If path contains dangerous patterns or resolves to
                a dangerous system location / outside the base directory
        """
        if base_dir is None:
            base_dir = Path.cwd()

        # Convert to Path object
        try:
            input_path = Path(path)
        except (ValueError, OSError) as e:
            raise ValueError(f"Invalid path: {path}") from e

        # Check for dangerous patterns in the raw string form
        path_str = str(input_path)
        for pattern in cls.DANGEROUS_PATH_PATTERNS:
            if re.search(pattern, path_str, re.IGNORECASE):
                raise ValueError(f"Path contains dangerous pattern: {path}")

        # Check for dangerous absolute-path prefixes before resolution
        if input_path.is_absolute():
            for dangerous in cls.DANGEROUS_ABS_PATHS:
                if path_str.startswith(dangerous):
                    raise ValueError(
                        f"Path resolves to dangerous system location: {path_str}"
                    )

        # Resolve the path (symlinks, '.', '..')
        if input_path.is_absolute():
            resolved_path = input_path.resolve()
        else:
            resolved_path = (base_dir / input_path).resolve()

        # Ensure resolved path is within base directory or a safe location
        try:
            resolved_path.relative_to(base_dir.resolve())
        except ValueError:
            # Path lies outside the base directory.
            if input_path.is_absolute():
                # BUG FIX: previously the dangerous-location ValueError was
                # raised *inside* the try whose `except ValueError: continue`
                # immediately swallowed it, so paths under /home, /etc, ...
                # were silently accepted. Raise outside the probe instead.
                dangerous_paths = cls.DANGEROUS_ABS_PATHS + ["/home"]
                for dangerous in dangerous_paths:
                    try:
                        resolved_path.relative_to(dangerous)
                    except ValueError:
                        continue  # not under this prefix; try the next one
                    raise ValueError(
                        f"Path resolves to dangerous system location: {resolved_path}"
                    )
            else:
                # Relative path that goes outside base directory
                raise ValueError(
                    f"Path traversal detected: {path} -> {resolved_path}"
                ) from None

        return resolved_path

    @classmethod
    def validate_config_value(cls, key: str, value: Any) -> Any:
        """
        Validate and sanitize configuration values.

        Args:
            key: Configuration key
            value: Configuration value

        Returns:
            Sanitized value (ints parsed from strings, booleans parsed,
            focus text stripped of shell metacharacters)

        Raises:
            ValueError: If value is invalid or dangerous
        """
        if value is None:
            return value

        # Type-specific validation
        if key in ["delay", "stats_interval", "max_iterations", "iteration_timeout"]:
            if isinstance(value, str):
                try:
                    value = int(value)
                except ValueError as e:
                    raise ValueError(f"Invalid integer value for {key}: {value}") from e

            # Validate ranges
            if value < 0:
                raise ValueError(f"{key} must be non-negative, got: {value}")
            if key == "delay" and value > 86400:  # 24 hours
                raise ValueError(f"{key} too large (>24 hours): {value}")
            if key == "max_iterations" and value > 10000:
                raise ValueError(f"{key} too large (>10000): {value}")
            if key == "stats_interval" and value > 3600:  # 1 hour
                raise ValueError(f"{key} too large (>1 hour): {value}")
            if key == "iteration_timeout" and value > 7200:  # 2 hours
                raise ValueError(f"{key} too large (>2 hours): {value}")

        elif key in ["log_file", "pid_file", "prompt_file", "system_prompt_file"]:
            if isinstance(value, str):
                # Sanitize file paths for non-prompt files (validation only;
                # the original string value is returned unchanged)
                if key not in ["prompt_file", "system_prompt_file"]:
                    cls.sanitize_path(value)

        elif key in [
            "verbose",
            "dry_run",
            "clear_screen",
            "show_countdown",
            "inject_best_practices",
        ]:
            # Boolean validation
            if isinstance(value, str):
                value = cls._parse_bool_safe(value)
            elif not isinstance(value, bool):
                raise ValueError(f"Invalid boolean value for {key}: {value}")

        elif key == "focus":
            if isinstance(value, str):
                # Sanitize focus text - remove potential command injection
                value = re.sub(r"[;&|`$()]", "", value)
                if len(value) > 200:
                    value = value[:200]

        return value

    @classmethod
    def _parse_bool_safe(cls, value: str) -> bool:
        """
        Safely parse boolean values from strings.

        Ambiguous values default to False rather than raising.

        Args:
            value: String value to parse

        Returns:
            Boolean value
        """
        if not value or not value.strip():
            return False

        value_lower = value.lower().strip()

        # Remove any dangerous characters
        value_clean = re.sub(r"[;&|`$()]", "", value_lower)

        true_values = ("true", "1", "yes", "on")
        false_values = ("false", "0", "no", "off")

        if value_clean in true_values:
            return True
        elif value_clean in false_values:
            return False
        else:
            # Default to False for ambiguous values
            return False

    @classmethod
    def mask_sensitive_data(cls, text: str) -> str:
        """
        Mask sensitive data in text for logging.

        Args:
            text: Text to mask sensitive data in

        Returns:
            Text with sensitive data masked
        """
        masked_text = text
        for pattern, replacement in cls.SENSITIVE_PATTERNS:
            masked_text = re.sub(pattern, replacement, masked_text, flags=re.IGNORECASE)
        return masked_text

    @classmethod
    def validate_filename(cls, filename: str) -> str:
        """
        Validate a filename for security.

        Args:
            filename: Filename to validate

        Returns:
            Sanitized filename

        Raises:
            ValueError: If filename is invalid or dangerous
        """
        if not filename or not filename.strip():
            raise ValueError("Filename cannot be empty")

        # Check for path traversal attempts in filename
        # BUG FIX: these error messages previously printed the literal
        # "(unknown)" instead of the offending filename.
        if ".." in filename or "/" in filename or "\\" in filename:
            raise ValueError(f"Filename contains path traversal: {filename}")

        # Remove dangerous characters
        sanitized = re.sub(r'[<>:"|?*\x00-\x1f]', "", filename.strip())

        if not sanitized:
            raise ValueError("Filename contains only invalid characters")

        # Prevent reserved names (Windows)
        reserved_names = {
            "CON",
            "PRN",
            "AUX",
            "NUL",
            "COM1",
            "COM2",
            "COM3",
            "COM4",
            "COM5",
            "COM6",
            "COM7",
            "COM8",
            "COM9",
            "LPT1",
            "LPT2",
            "LPT3",
            "LPT4",
            "LPT5",
            "LPT6",
            "LPT7",
            "LPT8",
            "LPT9",
        }

        name_without_ext = sanitized.split(".")[0].upper()
        if name_without_ext in reserved_names:
            raise ValueError(f"Filename uses reserved name: {filename}")

        # Check for control characters (raises even though the regex above
        # would have stripped them: control characters signal a hostile input)
        if any(ord(char) < 32 for char in filename):
            raise ValueError(f"Filename contains control characters: {filename!r}")

        # Limit length
        if len(sanitized) > 255:
            sanitized = sanitized[:255]

        return sanitized

    @classmethod
    def create_secure_logger(
        cls, name: str, log_file: Optional[str] = None
    ) -> logging.Logger:
        """
        Create a logger with security features enabled.

        Args:
            name: Logger name
            log_file: Optional log file path

        Returns:
            Secure logger instance whose output has sensitive data masked

        NOTE: each call adds a new handler to the named logger; callers
        should create a given logger only once.
        """
        secure_logger = logging.getLogger(name)

        # Create custom formatter that masks sensitive data
        class SecureFormatter(logging.Formatter):
            def format(self, record):
                formatted = super().format(record)
                return cls.mask_sensitive_data(formatted)

        # Set up secure formatter
        if log_file:
            handler = logging.FileHandler(log_file)
        else:
            handler = logging.StreamHandler()

        handler.setFormatter(
            SecureFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )

        secure_logger.addHandler(handler)
        secure_logger.setLevel(logging.INFO)

        return secure_logger
|
||||
|
||||
|
||||
class PathTraversalProtection:
    """Guarded file I/O helpers that refuse traversal-style paths.

    Both helpers validate the path via SecurityValidator.sanitize_path
    before touching the filesystem, and re-raise permission failures with
    a clearer message.
    """

    @staticmethod
    def safe_file_read(file_path: str, base_dir: Optional[Path] = None) -> str:
        """
        Safely read a file with path traversal protection.

        Args:
            file_path: Path to file to read
            base_dir: Base directory for relative paths

        Returns:
            File content decoded as UTF-8

        Raises:
            ValueError: If path is dangerous or is not a regular file
            FileNotFoundError: If file doesn't exist
            PermissionError: If file cannot be read
        """
        validated = SecurityValidator.sanitize_path(file_path, base_dir)

        if not validated.exists():
            raise FileNotFoundError(f"File not found: {validated}")
        if not validated.is_file():
            raise ValueError(f"Path is not a file: {validated}")

        try:
            return validated.read_text(encoding="utf-8")
        except PermissionError as e:
            raise PermissionError(f"Cannot read file: {validated}") from e

    @staticmethod
    def safe_file_write(
        file_path: str, content: str, base_dir: Optional[Path] = None
    ) -> None:
        """
        Safely write to a file with path traversal protection.

        Missing parent directories are created.

        Args:
            file_path: Path to file to write
            content: Content to write
            base_dir: Base directory for relative paths

        Raises:
            ValueError: If path is dangerous
            PermissionError: If file cannot be written
        """
        validated = SecurityValidator.sanitize_path(file_path, base_dir)

        # Create parent directories if needed
        validated.parent.mkdir(parents=True, exist_ok=True)

        try:
            validated.write_text(content, encoding="utf-8")
        except PermissionError as e:
            raise PermissionError(f"Cannot write file: {validated}") from e
|
||||
|
||||
|
||||
# Security decorator for functions that handle file paths
|
||||
def secure_file_operation(base_dir: Optional[Path] = None):
    """
    Decorator to secure file operations against path traversal.

    Any positional or keyword argument that is a string containing a path
    separator is run through SecurityValidator.sanitize_path before the
    wrapped function is called; all other arguments pass through untouched.

    Args:
        base_dir: Base directory for relative paths
    """

    def decorator(func):
        def wrapper(*args, **kwargs):
            def _clean(value):
                # Heuristic: a string with a separator is treated as a path.
                if isinstance(value, str) and ("/" in value or "\\" in value):
                    return str(SecurityValidator.sanitize_path(value, base_dir))
                return value

            cleaned_args = tuple(_clean(arg) for arg in args)
            cleaned_kwargs = {key: _clean(val) for key, val in kwargs.items()}
            return func(*cleaned_args, **cleaned_kwargs)

        return wrapper

    return decorator
|
||||
@@ -0,0 +1,951 @@
|
||||
# ABOUTME: Enhanced verbose logging utilities for Ralph Orchestrator
|
||||
# ABOUTME: Provides session metrics, emergency shutdown, re-entrancy protection, Rich output
|
||||
|
||||
"""Enhanced verbose logging utilities for Ralph."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, TextIO, cast
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
# Unused imports commented out but kept for future reference:
|
||||
# from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
Console = None # type: ignore
|
||||
Markdown = None # type: ignore
|
||||
Syntax = None # type: ignore
|
||||
|
||||
# Import DiffFormatter for enhanced diff output
|
||||
try:
|
||||
from ralph_orchestrator.output import DiffFormatter
|
||||
except ImportError:
|
||||
DiffFormatter = None # type: ignore
|
||||
|
||||
|
||||
class TextIOProxy:
    """Minimal TextIO-like sink that appends console output to a file.

    The backing file is opened lazily on the first write, every operation
    is serialized behind a lock so threads can share one proxy, and I/O
    failures silence the proxy permanently instead of raising into the
    logging path.
    """

    def __init__(self, file_path: Path) -> None:
        """
        Remember the target path; nothing is opened until the first write.

        Args:
            file_path: Path to output file
        """
        self.file_path = file_path
        self._file: Optional[TextIO] = None
        self._closed = False
        self._lock = threading.Lock()

    def _ensure_open(self) -> Optional[TextIO]:
        """Open the backing file on demand; return None once unusable."""
        if self._closed:
            return None
        if self._file is None:
            try:
                self._file = open(self.file_path, "a", encoding="utf-8")
            except (OSError, IOError):
                # Could not open: go permanently silent.
                self._closed = True
                return None
        return self._file

    def write(self, text: str) -> int:
        """
        Append *text* to the file.

        Args:
            text: Text to write

        Returns:
            Number of characters written (0 once the proxy is disabled)
        """
        with self._lock:
            if self._closed:
                return 0
            try:
                handle = self._ensure_open()
                return handle.write(text) if handle is not None else 0
            except (ValueError, OSError, AttributeError):
                return 0

    def flush(self) -> None:
        """Flush pending output, swallowing any I/O failure."""
        with self._lock:
            if self._file and not self._closed:
                try:
                    self._file.flush()
                except (ValueError, OSError):
                    pass

    def close(self) -> None:
        """Close the backing file; idempotent and never raises."""
        with self._lock:
            if self._file and not self._closed:
                try:
                    self._file.close()
                except (ValueError, OSError):
                    pass
                finally:
                    self._closed = True
                    self._file = None

    def __del__(self) -> None:
        """Best-effort close when the proxy is garbage-collected."""
        self.close()
|
||||
|
||||
|
||||
class VerboseLogger:
|
||||
"""
|
||||
Enhanced verbose logger that captures detailed output to log files.
|
||||
|
||||
Features:
|
||||
- Session metrics tracking in JSON format
|
||||
- Emergency shutdown capability
|
||||
- Re-entrancy protection (prevent logging loops)
|
||||
- Console output with Rich library integration
|
||||
- Thread-safe operations
|
||||
|
||||
This logger captures all verbose output including:
|
||||
- Claude SDK messages with full content
|
||||
- Tool calls and results
|
||||
- Console output with formatting preserved
|
||||
- System events and status updates
|
||||
- Error details and tracebacks
|
||||
"""
|
||||
|
||||
_metrics: Dict[str, Any]
|
||||
|
||||
def __init__(self, log_dir: Optional[str] = None) -> None:
|
||||
"""
|
||||
Initialize verbose logger with thread safety.
|
||||
|
||||
Args:
|
||||
log_dir: Directory to store verbose log files (defaults to .agent in cwd)
|
||||
"""
|
||||
if log_dir is None:
|
||||
# Find repository root by looking for .git directory
|
||||
current_dir = Path.cwd()
|
||||
repo_root = current_dir
|
||||
|
||||
# Walk up to find .git directory or stop at filesystem root
|
||||
while repo_root.parent != repo_root:
|
||||
if (repo_root / ".git").exists():
|
||||
break
|
||||
repo_root = repo_root.parent
|
||||
|
||||
# Create .agent directory in repository root
|
||||
log_dir = str(repo_root / ".agent")
|
||||
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create timestamped log file for current session
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.verbose_log_file = self.log_dir / f"ralph_verbose_{timestamp}.log"
|
||||
self.raw_output_file = self.log_dir / f"ralph_raw_{timestamp}.log"
|
||||
self.metrics_file = self.log_dir / f"ralph_metrics_{timestamp}.json"
|
||||
|
||||
# Thread safety: Use both asyncio and threading locks
|
||||
self._lock = asyncio.Lock()
|
||||
self._thread_lock = threading.RLock() # Re-entrant lock for thread safety
|
||||
|
||||
# Initialize Rich console or fallback
|
||||
self._text_io_proxy = TextIOProxy(self.verbose_log_file)
|
||||
if RICH_AVAILABLE:
|
||||
self._console = Console(file=cast(TextIO, self._text_io_proxy), width=120)
|
||||
self._live_console = Console() # For live terminal output
|
||||
else:
|
||||
self._console = None
|
||||
self._live_console = None
|
||||
|
||||
# Initialize DiffFormatter for enhanced diff output
|
||||
if RICH_AVAILABLE and DiffFormatter and self._console:
|
||||
self._diff_formatter: Optional[DiffFormatter] = DiffFormatter(self._console)
|
||||
else:
|
||||
self._diff_formatter = None
|
||||
|
||||
self._raw_file_handle: Optional[TextIO] = None
|
||||
|
||||
# Emergency shutdown state
|
||||
self._emergency_shutdown = False
|
||||
self._emergency_event = threading.Event()
|
||||
|
||||
# Re-entrancy protection: Track if we're already logging
|
||||
self._logging_depth = 0
|
||||
self._logging_thread_ids: set = set()
|
||||
self._max_logging_depth = 3 # Prevent deep nesting
|
||||
|
||||
# Session metrics tracking
|
||||
self._metrics = {
|
||||
"session_start": datetime.now().isoformat(),
|
||||
"session_end": None,
|
||||
"messages": [],
|
||||
"tool_calls": [],
|
||||
"errors": [],
|
||||
"iterations": [],
|
||||
"total_tokens": 0,
|
||||
"total_cost": 0.0,
|
||||
}
|
||||
|
||||
def _can_log_safely(self) -> bool:
|
||||
"""
|
||||
Check if logging is safe to perform (re-entrancy and thread safety check).
|
||||
|
||||
Returns:
|
||||
True if logging is safe, False otherwise
|
||||
"""
|
||||
# Check emergency shutdown first
|
||||
if self._emergency_event.is_set():
|
||||
return False
|
||||
|
||||
# Get current thread ID
|
||||
current_thread_id = threading.current_thread().ident
|
||||
|
||||
# Use thread lock for safe access to shared state
|
||||
with self._thread_lock:
|
||||
# Check if this thread is already in the middle of logging
|
||||
if current_thread_id in self._logging_thread_ids:
|
||||
# Check nesting depth
|
||||
if self._logging_depth >= self._max_logging_depth:
|
||||
return False # Too deeply nested
|
||||
return True # Allow some nesting
|
||||
|
||||
# Check if any other thread is logging (to prevent excessive blocking)
|
||||
if len(self._logging_thread_ids) > 0:
|
||||
# Another thread is logging - we can still log but need to be careful
|
||||
pass
|
||||
|
||||
# Check async lock state (non-blocking)
|
||||
if self._lock.locked():
|
||||
return False # Async lock is held, skip logging
|
||||
|
||||
return True
|
||||
|
||||
def _enter_logging_context(self) -> bool:
|
||||
"""
|
||||
Enter a logging context safely.
|
||||
|
||||
Returns:
|
||||
True if we successfully entered the context, False otherwise
|
||||
"""
|
||||
current_thread_id = threading.current_thread().ident
|
||||
|
||||
with self._thread_lock:
|
||||
if self._logging_depth >= self._max_logging_depth:
|
||||
return False
|
||||
|
||||
if current_thread_id in self._logging_thread_ids:
|
||||
self._logging_depth += 1
|
||||
return True # Re-entrancy in same thread is okay with depth tracking
|
||||
|
||||
self._logging_thread_ids.add(current_thread_id)
|
||||
self._logging_depth = 1
|
||||
return True
|
||||
|
||||
def _exit_logging_context(self) -> None:
|
||||
"""Exit a logging context safely."""
|
||||
current_thread_id = threading.current_thread().ident
|
||||
|
||||
with self._thread_lock:
|
||||
self._logging_depth = max(0, self._logging_depth - 1)
|
||||
|
||||
if self._logging_depth == 0:
|
||||
self._logging_thread_ids.discard(current_thread_id)
|
||||
|
||||
def emergency_shutdown(self) -> None:
|
||||
"""Signal emergency shutdown to make logging operations non-blocking."""
|
||||
self._emergency_shutdown = True
|
||||
self._emergency_event.set()
|
||||
|
||||
def is_shutdown(self) -> bool:
|
||||
"""Check if emergency shutdown has been triggered."""
|
||||
return self._emergency_event.is_set()
|
||||
|
||||
def _print_to_file(self, text: str) -> None:
|
||||
"""Print text to the log file (Rich or plain)."""
|
||||
if self._console and RICH_AVAILABLE:
|
||||
self._console.print(text)
|
||||
else:
|
||||
self._text_io_proxy.write(text + "\n")
|
||||
self._text_io_proxy.flush()
|
||||
|
||||
def _print_to_terminal(self, text: str) -> None:
|
||||
"""Print text to the live terminal (Rich or plain)."""
|
||||
if self._live_console and RICH_AVAILABLE:
|
||||
self._live_console.print(text)
|
||||
else:
|
||||
print(text)
|
||||
|
||||
    async def log_message(
        self,
        message_type: str,
        content: Any,
        iteration: int = 0,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Log a detailed message with rich formatting preserved (thread-safe).

        Best-effort by design: when logging is unsafe (emergency shutdown,
        re-entrancy limit reached, or the async lock already held) the
        message is silently dropped rather than queued, and any internal
        error is reported to stderr instead of being raised to the caller.

        Args:
            message_type: Type of message (system, assistant, user, tool, etc.)
            content: Message content (text, dict, object)
            iteration: Current iteration number
            metadata: Additional metadata about the message
        """
        # Check if logging is safe (thread safety + re-entrancy)
        if not self._can_log_safely():
            return

        # Enter logging context safely
        if not self._enter_logging_context():
            return

        try:
            # Use non-blocking lock acquisition: drop the message instead of
            # awaiting the lock. NOTE: the locked() probe and the acquisition
            # below are deliberately not atomic — a race here only means one
            # extra message gets through or is dropped, which is acceptable
            # for a best-effort logger.
            if self._lock.locked():
                return

            async with self._lock:
                # Millisecond-precision timestamp (strip trailing microseconds).
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

                # Create log entry
                log_entry = {
                    "timestamp": timestamp,
                    "iteration": iteration,
                    "type": message_type,
                    "content": self._serialize_content(content),
                    "metadata": metadata or {},
                }

                # Write to verbose log with rich formatting
                self._print_to_file(f"\n{'='*80}")
                self._print_to_file(
                    f"[{timestamp}] Iteration {iteration} - {message_type}"
                )

                if metadata:
                    self._print_to_file(f"Metadata: {json.dumps(metadata, indent=2)}")

                self._print_to_file(f"{'='*80}\n")

                # Format content based on type: long strings are truncated in
                # the human-readable log (the raw log below keeps everything).
                if isinstance(content, str):
                    if len(content) > 2000:
                        preview = content[:1000]
                        self._print_to_file(preview)
                        self._print_to_file(
                            f"\n[Content truncated ({len(content)} chars total)]"
                        )
                    else:
                        self._print_to_file(content)
                elif isinstance(content, dict):
                    json_str = json.dumps(content, indent=2)
                    self._print_to_file(json_str)
                else:
                    self._print_to_file(str(content))

                # Write to raw log (complete content)
                await self._write_raw_log(log_entry)

                # Update metrics
                await self._update_metrics("message", log_entry)

        except Exception as e:
            # Logging must never crash the orchestrator: report and move on.
            try:
                print(f"Logging error in log_message: {e}", file=sys.stderr)
            except Exception:
                pass
        finally:
            self._exit_logging_context()
|
||||
|
||||
def log_message_sync(
|
||||
self,
|
||||
message_type: str,
|
||||
content: Any,
|
||||
iteration: int = 0,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Synchronous wrapper for log_message."""
|
||||
if self._emergency_event.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop() # Check if loop exists (raises RuntimeError if not)
|
||||
asyncio.create_task(
|
||||
self.log_message(message_type, content, iteration, metadata)
|
||||
)
|
||||
except RuntimeError:
|
||||
# No running loop, run directly
|
||||
asyncio.run(self.log_message(message_type, content, iteration, metadata))
|
||||
|
||||
    async def log_tool_call(
        self,
        tool_name: str,
        input_data: Any,
        result: Any,
        iteration: int,
        duration_ms: Optional[int] = None,
    ) -> None:
        """
        Log a detailed tool call with input and output (thread-safe).

        Writes a human-readable record (truncated previews; diff output gets
        enhanced formatting) to the verbose log and a complete JSON entry to
        the raw log, then updates in-memory metrics.

        Args:
            tool_name: Name of the tool that was called
            input_data: Tool input parameters
            result: Tool execution result
            iteration: Current iteration number
            duration_ms: Tool execution duration in milliseconds
        """
        if not self._can_log_safely():
            return

        if not self._enter_logging_context():
            return

        try:
            # Best-effort logging: if another coroutine already holds the
            # lock, drop this entry instead of blocking the caller.
            if self._lock.locked():
                return

            async with self._lock:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

                tool_entry = {
                    "timestamp": timestamp,
                    "iteration": iteration,
                    "tool_name": tool_name,
                    "input": self._serialize_content(input_data),
                    "result": self._serialize_content(result),
                    "duration_ms": duration_ms,
                    # NOTE(review): any non-None result counts as success,
                    # including falsy values such as "" or 0.
                    "success": result is not None,
                }

                # Write formatted tool call to verbose log
                # (duration_ms == 0 also renders as "unknown" — truthiness check)
                duration_text = f"{duration_ms}ms" if duration_ms else "unknown"
                self._print_to_file(f"\n{'-'*60}")
                self._print_to_file(f"TOOL CALL: {tool_name} ({duration_text})")

                # Format input — skipped entirely when falsy/empty.
                if input_data:
                    self._print_to_file("\nInput:")
                    if isinstance(input_data, (dict, list)):
                        input_json = json.dumps(input_data, indent=2)
                        # Keep head and tail of oversized payloads.
                        if len(input_json) > 1000:
                            input_json = (
                                input_json[:500]
                                + "\n ... [truncated] ...\n"
                                + input_json[-400:]
                            )
                        self._print_to_file(input_json)
                    else:
                        self._print_to_file(str(input_data)[:500])

                # Format result — skipped entirely when falsy/empty.
                if result:
                    self._print_to_file("\nResult:")
                    result_str = self._serialize_content(result)

                    # Check if result is diff content and format with DiffFormatter
                    if (
                        isinstance(result_str, str)
                        and self._is_diff_content(result_str)
                        and self._diff_formatter
                    ):
                        self._print_to_file(
                            "[Detected diff content - formatting with enhanced visualization]"
                        )
                        self._diff_formatter.format_and_print(result_str)
                    elif isinstance(result_str, str) and len(result_str) > 1500:
                        preview = (
                            result_str[:750]
                            + "\n ... [truncated] ...\n"
                            + result_str[-500:]
                        )
                        self._print_to_file(preview)
                    else:
                        self._print_to_file(str(result_str))

                self._print_to_file(f"{'-'*60}\n")

                await self._write_raw_log(tool_entry)
                await self._update_metrics("tool_call", tool_entry)

        except Exception as e:
            # Logging must never crash the orchestrator: report and move on.
            try:
                print(f"Logging error in log_tool_call: {e}", file=sys.stderr)
            except Exception:
                pass
        finally:
            self._exit_logging_context()
|
||||
|
||||
    async def log_error(
        self, error: Exception, iteration: int, context: Optional[str] = None
    ) -> None:
        """
        Log detailed error information with traceback (thread-safe).

        Writes a banner-framed block to the verbose log and a structured
        entry (including the formatted traceback) to the raw log.

        Args:
            error: Exception that occurred
            iteration: Current iteration number
            context: Additional context about when the error occurred
        """
        if not self._can_log_safely():
            return

        if not self._enter_logging_context():
            return

        try:
            # Best-effort: drop the entry when the lock is contended.
            if self._lock.locked():
                return

            async with self._lock:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

                error_entry = {
                    "timestamp": timestamp,
                    "iteration": iteration,
                    "error_type": type(error).__name__,
                    "error_message": str(error),
                    "context": context,
                    "traceback": self._get_traceback(error),
                }

                # Banner framing makes error blocks easy to scan in the log.
                self._print_to_file(f"\n{'!'*20} ERROR DETAILS {'!'*20}")
                self._print_to_file(f"[{timestamp}] Iteration {iteration}")
                self._print_to_file(f"Error Type: {type(error).__name__}")

                if context:
                    self._print_to_file(f"Context: {context}")

                self._print_to_file(f"Message: {str(error)}")

                traceback_str = self._get_traceback(error)
                if traceback_str:
                    self._print_to_file("\nTraceback:")
                    self._print_to_file(traceback_str)

                self._print_to_file(f"{'!'*20} END ERROR {'!'*20}\n")

                await self._write_raw_log(error_entry)
                await self._update_metrics("error", error_entry)

        except Exception as e:
            # Never let a logging failure propagate into the orchestrator.
            try:
                print(f"Logging error in log_error: {e}", file=sys.stderr)
            except Exception:
                pass
        finally:
            self._exit_logging_context()
|
||||
|
||||
    async def log_iteration_summary(
        self,
        iteration: int,
        duration: int,
        success: bool,
        message_count: int,
        stats: Dict[str, int],
        tokens_used: int = 0,
        cost: float = 0.0,
    ) -> None:
        """
        Log a detailed iteration summary.

        Also accumulates ``tokens_used`` and ``cost`` into the session
        totals tracked in ``self._metrics``.

        Args:
            iteration: Iteration number
            duration: Duration in seconds
            success: Whether iteration was successful
            message_count: Number of messages exchanged
            stats: Message type statistics
            tokens_used: Number of tokens used
            cost: Cost of this iteration
        """
        if self._emergency_event.is_set():
            return

        # Best-effort: drop the summary if the lock is already held.
        if self._lock.locked():
            return

        async with self._lock:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            summary_entry = {
                "timestamp": timestamp,
                "iteration": iteration,
                "duration_seconds": duration,
                "success": success,
                "message_count": message_count,
                "stats": stats,
                "tokens_used": tokens_used,
                "cost": cost,
            }

            status_icon = "SUCCESS" if success else "FAILED"

            self._print_to_file(f"\n{'#'*15} ITERATION SUMMARY {'#'*15}")
            self._print_to_file(f"{status_icon} - Iteration {iteration} - {duration}s")
            self._print_to_file(f"Timestamp: {timestamp}")
            self._print_to_file(f"Messages: {message_count}")
            self._print_to_file(f"Tokens: {tokens_used}")
            self._print_to_file(f"Cost: ${cost:.4f}")

            if stats:
                self._print_to_file("\nMessage Statistics:")
                for msg_type, count in stats.items():
                    if count > 0:  # omit message types that never occurred
                        self._print_to_file(f"  {msg_type}: {count}")

            self._print_to_file(f"{'#'*42}\n")

            await self._write_raw_log(summary_entry)
            await self._update_metrics("iteration", summary_entry)

            # Update total metrics
            self._metrics["total_tokens"] += tokens_used
            self._metrics["total_cost"] += cost
|
||||
|
||||
def _serialize_content(
|
||||
self, content: Any
|
||||
) -> str | Dict[Any, Any] | List[Any] | int | float:
|
||||
"""
|
||||
Serialize content to JSON-serializable format.
|
||||
|
||||
Args:
|
||||
content: Content to serialize
|
||||
|
||||
Returns:
|
||||
Serialized content
|
||||
"""
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif hasattr(content, "__dict__"):
|
||||
if hasattr(content, "text"):
|
||||
return {"text": content.text, "type": type(content).__name__}
|
||||
elif hasattr(content, "content"):
|
||||
return {"content": content.content, "type": type(content).__name__}
|
||||
else:
|
||||
return {"repr": str(content), "type": type(content).__name__}
|
||||
elif isinstance(content, (dict, list)):
|
||||
return content
|
||||
elif isinstance(content, (int, float, bool)):
|
||||
return content
|
||||
else:
|
||||
return str(content)
|
||||
except Exception:
|
||||
return f"<unserializable: {type(content).__name__}>"
|
||||
|
||||
def _is_diff_content(self, text: str) -> bool:
|
||||
"""
|
||||
Check if text appears to be diff content.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text looks like diff output
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return False
|
||||
|
||||
lines = text.split("\n")
|
||||
# Check first few lines for diff indicators
|
||||
diff_indicators = [
|
||||
any(line.startswith("diff --git") for line in lines[:5]),
|
||||
any(line.startswith("--- ") for line in lines[:5]),
|
||||
any(line.startswith("+++ ") for line in lines[:5]),
|
||||
any(line.startswith("@@") for line in lines[:10]),
|
||||
any(
|
||||
line.startswith(("+", "-"))
|
||||
and not line.startswith(("+++", "---"))
|
||||
for line in lines[:10]
|
||||
),
|
||||
]
|
||||
return any(diff_indicators)
|
||||
|
||||
    async def _write_raw_log(self, entry: Dict[str, Any]) -> None:
        """
        Write entry to raw log file.

        The JSONL file handle is opened lazily on first use; write and
        flush are each bounded by a 0.1s timeout so a slow disk can never
        stall the event loop or shutdown. On I/O failure the handle is
        closed and reset so the next call reopens the file.

        Args:
            entry: Log entry to write
        """
        if self._emergency_event.is_set():
            return

        try:
            if self._raw_file_handle is None:
                try:
                    # Use asyncio.to_thread to avoid blocking the event loop
                    self._raw_file_handle = await asyncio.to_thread(
                        open, self.raw_output_file, "a", encoding="utf-8"
                    )
                except (OSError, IOError):
                    return

            # One JSON object per line (JSONL); default=str stringifies
            # anything json can't handle natively (e.g. datetimes).
            json_line = json.dumps(entry, default=str, ensure_ascii=False) + "\n"

            try:
                await asyncio.wait_for(
                    asyncio.to_thread(self._raw_file_handle.write, json_line),
                    timeout=0.1,
                )
                await asyncio.wait_for(
                    asyncio.to_thread(self._raw_file_handle.flush), timeout=0.1
                )
            except asyncio.TimeoutError:
                # Slow disk: drop this entry rather than block.
                return
            except (OSError, IOError):
                # I/O failure: close and reset the handle so the next call
                # attempts a fresh open.
                if self._raw_file_handle:
                    try:
                        self._raw_file_handle.close()
                    except Exception:
                        pass
                self._raw_file_handle = None
                return

        except Exception:
            # Raw logging is best-effort; swallow anything unexpected.
            pass
|
||||
|
||||
    async def _update_metrics(self, entry_type: str, entry: Dict[str, Any]) -> None:
        """
        Update metrics tracking.

        Appends the entry to the matching in-memory list; unknown entry
        types are silently ignored. Metrics are flushed to disk on every
        10th message.

        Args:
            entry_type: Type of entry (message, tool_call, error, iteration)
            entry: The entry data
        """
        try:
            if entry_type == "message":
                self._metrics["messages"].append(entry)
            elif entry_type == "tool_call":
                self._metrics["tool_calls"].append(entry)
            elif entry_type == "error":
                self._metrics["errors"].append(entry)
            elif entry_type == "iteration":
                self._metrics["iterations"].append(entry)

            # Periodically save metrics (every 10 messages)
            if len(self._metrics["messages"]) % 10 == 0:
                await self._save_metrics()

        except Exception:
            # Metrics are best-effort; never propagate.
            pass
|
||||
|
||||
    async def _save_metrics(self) -> None:
        """Save metrics to file.

        Serializes the in-memory metrics (plus derived totals) to
        ``self.metrics_file``. Skipped entirely after emergency shutdown or
        when the lock is contended; the write itself is bounded by a 0.5s
        timeout so it can never block shutdown.
        """
        if self._emergency_event.is_set():
            return

        try:
            metrics_data = {
                **self._metrics,
                "session_last_update": datetime.now().isoformat(),
                "total_messages": len(self._metrics["messages"]),
                "total_tool_calls": len(self._metrics["tool_calls"]),
                "total_errors": len(self._metrics["errors"]),
                "total_iterations": len(self._metrics["iterations"]),
            }

            # Best-effort: skip the save entirely when the lock is held.
            if self._lock.locked():
                return

            async with self._lock:
                try:
                    await asyncio.wait_for(
                        asyncio.to_thread(
                            lambda: self.metrics_file.write_text(
                                json.dumps(metrics_data, indent=2, default=str)
                            )
                        ),
                        timeout=0.5,
                    )
                except asyncio.TimeoutError:
                    pass
        except Exception:
            # Never let metric persistence failures propagate.
            pass
|
||||
|
||||
def _get_traceback(self, error: Exception) -> str:
|
||||
"""
|
||||
Get formatted traceback from exception.
|
||||
|
||||
Args:
|
||||
error: Exception to get traceback from
|
||||
|
||||
Returns:
|
||||
Formatted traceback string
|
||||
"""
|
||||
import traceback
|
||||
|
||||
try:
|
||||
return "".join(
|
||||
traceback.format_exception(type(error), error, error.__traceback__)
|
||||
)
|
||||
except Exception:
|
||||
return f"Could not extract traceback: {str(error)}"
|
||||
|
||||
def get_session_metrics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current session metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing session metrics
|
||||
"""
|
||||
return {
|
||||
"session_start": self._metrics["session_start"],
|
||||
"session_end": self._metrics.get("session_end"),
|
||||
"total_messages": len(self._metrics["messages"]),
|
||||
"total_tool_calls": len(self._metrics["tool_calls"]),
|
||||
"total_errors": len(self._metrics["errors"]),
|
||||
"total_iterations": len(self._metrics["iterations"]),
|
||||
"total_tokens": self._metrics["total_tokens"],
|
||||
"total_cost": self._metrics["total_cost"],
|
||||
"log_files": {
|
||||
"verbose": str(self.verbose_log_file),
|
||||
"raw": str(self.raw_output_file),
|
||||
"metrics": str(self.metrics_file),
|
||||
},
|
||||
}
|
||||
|
||||
def print_to_console(
|
||||
self, message: str, style: Optional[str] = None, panel: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Print a message to the live console with Rich formatting.
|
||||
|
||||
Args:
|
||||
message: Message to print
|
||||
style: Rich style string (e.g., "bold red", "green")
|
||||
panel: Whether to wrap in a Rich panel
|
||||
"""
|
||||
if self._emergency_event.is_set():
|
||||
return
|
||||
|
||||
if self._live_console and RICH_AVAILABLE:
|
||||
if panel:
|
||||
self._live_console.print(Panel(message))
|
||||
elif style:
|
||||
self._live_console.print(f"[{style}]{message}[/{style}]")
|
||||
else:
|
||||
self._live_console.print(message)
|
||||
else:
|
||||
print(message)
|
||||
|
||||
def print_table(
|
||||
self, title: str, columns: List[str], rows: List[List[str]]
|
||||
) -> None:
|
||||
"""
|
||||
Print a formatted table to the console.
|
||||
|
||||
Args:
|
||||
title: Table title
|
||||
columns: Column headers
|
||||
rows: Table data rows
|
||||
"""
|
||||
if self._emergency_event.is_set():
|
||||
return
|
||||
|
||||
if self._live_console and RICH_AVAILABLE:
|
||||
table = Table(title=title)
|
||||
for col in columns:
|
||||
table.add_column(col)
|
||||
for row in rows:
|
||||
table.add_row(*row)
|
||||
self._live_console.print(table)
|
||||
else:
|
||||
# Plain text fallback
|
||||
print(f"\n{title}")
|
||||
print("-" * 40)
|
||||
print(" | ".join(columns))
|
||||
print("-" * 40)
|
||||
for row in rows:
|
||||
print(" | ".join(row))
|
||||
print("-" * 40)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close log files and save final metrics."""
|
||||
try:
|
||||
self._emergency_shutdown = True
|
||||
self._emergency_event.set()
|
||||
|
||||
# Update session end time
|
||||
self._metrics["session_end"] = datetime.now().isoformat()
|
||||
|
||||
# Save final metrics with timeout
|
||||
try:
|
||||
await asyncio.wait_for(self._save_metrics(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
# Close raw file handle
|
||||
if self._raw_file_handle:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(self._raw_file_handle.close), timeout=0.5
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self._raw_file_handle = None
|
||||
|
||||
# Write session summary
|
||||
try:
|
||||
if not self._lock.locked():
|
||||
async with self._lock:
|
||||
session_start = self._metrics["session_start"]
|
||||
total_duration = (
|
||||
datetime.now() - datetime.fromisoformat(session_start)
|
||||
).total_seconds()
|
||||
|
||||
self._print_to_file(f"\n{'='*80}")
|
||||
self._print_to_file("SESSION SUMMARY")
|
||||
self._print_to_file(f"Duration: {total_duration:.1f} seconds")
|
||||
self._print_to_file(
|
||||
f"Messages: {len(self._metrics['messages'])}"
|
||||
)
|
||||
self._print_to_file(
|
||||
f"Tool Calls: {len(self._metrics['tool_calls'])}"
|
||||
)
|
||||
self._print_to_file(f"Errors: {len(self._metrics['errors'])}")
|
||||
self._print_to_file(
|
||||
f"Iterations: {len(self._metrics['iterations'])}"
|
||||
)
|
||||
self._print_to_file(
|
||||
f"Total Tokens: {self._metrics['total_tokens']}"
|
||||
)
|
||||
self._print_to_file(
|
||||
f"Total Cost: ${self._metrics['total_cost']:.4f}"
|
||||
)
|
||||
self._print_to_file(f"Verbose log: {self.verbose_log_file}")
|
||||
self._print_to_file(f"Raw log: {self.raw_output_file}")
|
||||
self._print_to_file(f"Metrics: {self.metrics_file}")
|
||||
self._print_to_file(f"{'='*80}\n")
|
||||
except (RuntimeError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
# Close text IO proxy
|
||||
self._text_io_proxy.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error closing verbose logger: {e}", file=sys.stderr)
|
||||
|
||||
def close_sync(self) -> None:
|
||||
"""Synchronous close method."""
|
||||
try:
|
||||
asyncio.get_running_loop() # Check if loop exists (raises RuntimeError if not)
|
||||
asyncio.create_task(self.close())
|
||||
except RuntimeError:
|
||||
asyncio.run(self.close())
|
||||
@@ -0,0 +1,8 @@
|
||||
# ABOUTME: Web UI module for Ralph Orchestrator monitoring and control
|
||||
# ABOUTME: Provides real-time dashboard for agent execution and system metrics
|
||||
|
||||
"""Web UI module for Ralph Orchestrator monitoring."""
|
||||
|
||||
from .server import WebMonitor
|
||||
|
||||
__all__ = ['WebMonitor']
|
||||
@@ -0,0 +1,74 @@
|
||||
# ABOUTME: Entry point for running the Ralph Orchestrator web monitoring server
|
||||
# ABOUTME: Enables execution with `python -m ralph_orchestrator.web`
|
||||
|
||||
"""Entry point for the Ralph Orchestrator web monitoring server."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from .server import WebMonitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments and run the web monitoring server.

    Flags:
        --port       HTTP port (default 8080).
        --host       Bind address (default 0.0.0.0).
        --no-auth    Disable authentication (development only).
        --log-level  Root logging level.
    """
    parser = argparse.ArgumentParser(
        description="Ralph Orchestrator Web Monitoring Dashboard"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port to run the web server on (default: 8080)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind the server to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--no-auth",
        action="store_true",
        help="Disable authentication (not recommended for production)"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set logging level (default: INFO)"
    )

    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Create and run the web monitor
    monitor = WebMonitor(
        port=args.port,
        host=args.host,
        enable_auth=not args.no_auth
    )

    logger.info(f"Starting Ralph Orchestrator Web Monitor on {args.host}:{args.port}")
    if args.no_auth:
        logger.warning("Authentication is disabled - not recommended for production")
    else:
        # Do not advertise a hard-coded password here: the previous message
        # named "ralph-admin-2024", which does not match the auth module's
        # actual default ("admin123" via RALPH_WEB_PASSWORD). Point at the
        # env vars that really control the credentials instead.
        logger.info(
            "Authentication enabled - credentials are configured via the "
            "RALPH_WEB_USERNAME / RALPH_WEB_PASSWORD environment variables"
        )

    try:
        asyncio.run(monitor.run())
    except KeyboardInterrupt:
        logger.info("Web monitor stopped by user")
    except Exception as e:
        logger.error(f"Web monitor error: {e}", exc_info=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,186 @@
|
||||
# ABOUTME: Authentication module for Ralph Orchestrator web monitoring dashboard
|
||||
# ABOUTME: Provides JWT-based authentication with username/password login
|
||||
|
||||
"""Authentication module for the web monitoring server."""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import jwt
|
||||
from passlib.context import CryptContext
|
||||
from fastapi import HTTPException, Security, Depends, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Configuration
# NOTE: when RALPH_WEB_SECRET_KEY is unset, a fresh key is generated per
# process, so previously issued tokens become invalid after a restart.
SECRET_KEY = os.getenv("RALPH_WEB_SECRET_KEY", secrets.token_urlsafe(32))
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("RALPH_TOKEN_EXPIRE_MINUTES", "1440"))  # 24 hours default

# Default admin credentials (should be changed in production)
DEFAULT_USERNAME = os.getenv("RALPH_WEB_USERNAME", "admin")
DEFAULT_PASSWORD_HASH = os.getenv("RALPH_WEB_PASSWORD_HASH", None)

# Password hashing context, created exactly once. (Previously an identical
# throwaway CryptContext was also built inside the fallback branch below.)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# If no password hash is provided, generate one for the default password
if not DEFAULT_PASSWORD_HASH:
    default_password = os.getenv("RALPH_WEB_PASSWORD", "admin123")
    DEFAULT_PASSWORD_HASH = pwd_context.hash(default_password)

# HTTP Bearer token authentication
security = HTTPBearer()
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for the login endpoint."""
    username: str  # account name to authenticate as
    password: str  # plaintext password; verified against the stored bcrypt hash
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """Response body carrying a freshly issued JWT."""
    access_token: str  # signed JWT produced by AuthManager.create_access_token
    token_type: str = "bearer"  # OAuth2 token type; always "bearer"
    expires_in: int  # token lifetime (presumably seconds; set by the login route — TODO confirm)
|
||||
|
||||
|
||||
class AuthManager:
    """Manages authentication for the web server.

    Holds an in-memory user store (seeded with the default admin account)
    and handles bcrypt password hashing/verification plus JWT issuance and
    validation. State is process-local: users added at runtime are lost on
    restart.
    """

    def __init__(self):
        # Snapshot module-level configuration onto the instance.
        self.pwd_context = pwd_context
        self.secret_key = SECRET_KEY
        self.algorithm = ALGORITHM
        self.access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES

        # Simple in-memory user store (can be extended to use a database)
        self.users = {
            DEFAULT_USERNAME: {
                "username": DEFAULT_USERNAME,
                "hashed_password": DEFAULT_PASSWORD_HASH,
                "is_active": True,
                "is_admin": True
            }
        }

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Verify a password against its hash."""
        return self.pwd_context.verify(plain_password, hashed_password)

    def get_password_hash(self, password: str) -> str:
        """Hash a password."""
        return self.pwd_context.hash(password)

    def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
        """Authenticate a user by username and password.

        Returns:
            The user record on success; ``None`` when the user is unknown,
            the password is wrong, or the account is inactive.
        """
        user = self.users.get(username)
        if not user:
            return None
        if not self.verify_password(password, user["hashed_password"]):
            return None
        if not user.get("is_active", True):
            return None
        return user

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT access token.

        Args:
            data: Claims to embed (callers place the username under "sub").
            expires_delta: Optional custom lifetime; defaults to the
                configured ``access_token_expire_minutes``.

        Returns:
            The encoded, signed JWT string.
        """
        to_encode = data.copy()
        if expires_delta:
            expire = datetime.now(timezone.utc) + expires_delta
        else:
            expire = datetime.now(timezone.utc) + timedelta(minutes=self.access_token_expire_minutes)

        to_encode.update({"exp": expire, "iat": datetime.now(timezone.utc)})
        encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
        return encoded_jwt

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Verify and decode a JWT token.

        Returns:
            ``{"username": ..., "user": ...}`` for a valid token whose
            subject still maps to an active user.

        Raises:
            HTTPException: 401 for a missing subject, unknown/inactive
                user, expired signature, or otherwise invalid token.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
            # "sub" may be absent from the payload, hence the None check.
            # (HTTPException is not a jwt error, so raising it here passes
            # straight through the except clauses below.)
            username: str = payload.get("sub")
            if username is None:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid authentication credentials",
                    headers={"WWW-Authenticate": "Bearer"},
                )

            # Check if user still exists and is active
            user = self.users.get(username)
            if not user or not user.get("is_active", True):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="User not found or inactive",
                    headers={"WWW-Authenticate": "Bearer"},
                )

            return {"username": username, "user": user}

        except jwt.ExpiredSignatureError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            ) from None
        except jwt.InvalidTokenError:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
                headers={"WWW-Authenticate": "Bearer"},
            ) from None

    def add_user(self, username: str, password: str, is_admin: bool = False) -> bool:
        """Add a new user to the system.

        Returns:
            True on success; False (no changes) when the username exists.
        """
        if username in self.users:
            return False

        self.users[username] = {
            "username": username,
            "hashed_password": self.get_password_hash(password),
            "is_active": True,
            "is_admin": is_admin
        }
        return True

    def remove_user(self, username: str) -> bool:
        """Remove a user from the system.

        The default admin account can never be removed.
        """
        if username in self.users and username != DEFAULT_USERNAME:
            del self.users[username]
            return True
        return False

    def update_password(self, username: str, new_password: str) -> bool:
        """Update a user's password.

        Returns:
            True on success; False when the username is unknown.
        """
        if username not in self.users:
            return False

        self.users[username]["hashed_password"] = self.get_password_hash(new_password)
        return True
|
||||
|
||||
|
||||
# Global auth manager instance, shared by the FastAPI dependencies below.
auth_manager = AuthManager()
|
||||
|
||||
|
||||
async def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> Dict[str, Any]:
    """Get the current authenticated user from the token.

    FastAPI dependency: extracts the bearer token supplied by the module's
    ``security`` scheme and validates it via the global ``auth_manager``.

    Raises:
        HTTPException: 401 when the token is missing, invalid, or expired.
    """
    token = credentials.credentials
    user_data = auth_manager.verify_token(token)
    return user_data
|
||||
|
||||
|
||||
async def require_admin(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """Require the current user to be an admin.

    FastAPI dependency layered on :func:`get_current_user`.

    Raises:
        HTTPException: 403 when the authenticated user lacks admin rights.
    """
    if not current_user["user"].get("is_admin", False):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin privileges required"
        )
    return current_user
|
||||
@@ -0,0 +1,467 @@
|
||||
# ABOUTME: Database module for Ralph Orchestrator web monitoring
|
||||
# ABOUTME: Provides SQLite storage for execution history and metrics
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from contextlib import contextmanager
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DatabaseManager:
|
||||
"""Manages SQLite database for Ralph Orchestrator execution history."""
|
||||
|
||||
def __init__(self, db_path: Optional[Path] = None):
|
||||
"""Initialize database manager.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file (default: ~/.ralph/history.db)
|
||||
"""
|
||||
if db_path is None:
|
||||
config_dir = Path.home() / ".ralph"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
db_path = config_dir / "history.db"
|
||||
|
||||
self.db_path = db_path
|
||||
self._lock = threading.Lock()
|
||||
self._init_database()
|
||||
logger.info(f"Database initialized at {self.db_path}")
|
||||
|
||||
@contextmanager
|
||||
def _get_connection(self):
|
||||
"""Thread-safe context manager for database connections."""
|
||||
conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row # Enable column access by name
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
    def _init_database(self):
        """Initialize database schema.

        Creates the three history tables and their lookup indices if they
        do not already exist; safe to call on every startup.
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()

                # Create orchestrator_runs table (one row per orchestrator run)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS orchestrator_runs (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        orchestrator_id TEXT NOT NULL,
                        prompt_path TEXT NOT NULL,
                        start_time TIMESTAMP NOT NULL,
                        end_time TIMESTAMP,
                        status TEXT NOT NULL,
                        total_iterations INTEGER DEFAULT 0,
                        max_iterations INTEGER,
                        error_message TEXT,
                        metadata TEXT
                    )
                """)

                # Create iteration_history table (per-iteration detail rows)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS iteration_history (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        run_id INTEGER NOT NULL,
                        iteration_number INTEGER NOT NULL,
                        start_time TIMESTAMP NOT NULL,
                        end_time TIMESTAMP,
                        status TEXT NOT NULL,
                        current_task TEXT,
                        agent_output TEXT,
                        error_message TEXT,
                        metrics TEXT,
                        FOREIGN KEY (run_id) REFERENCES orchestrator_runs(id)
                    )
                """)

                # Create task_history table (per-task progress rows)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS task_history (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        run_id INTEGER NOT NULL,
                        task_description TEXT NOT NULL,
                        status TEXT NOT NULL,
                        start_time TIMESTAMP,
                        end_time TIMESTAMP,
                        iteration_count INTEGER DEFAULT 0,
                        error_message TEXT,
                        FOREIGN KEY (run_id) REFERENCES orchestrator_runs(id)
                    )
                """)

                # Create indices for better query performance
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_runs_orchestrator_id
                    ON orchestrator_runs(orchestrator_id)
                """)
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_runs_start_time
                    ON orchestrator_runs(start_time DESC)
                """)
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_iterations_run_id
                    ON iteration_history(run_id)
                """)
                cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_tasks_run_id
                    ON task_history(run_id)
                """)

                conn.commit()
|
||||
|
||||
    def create_run(self, orchestrator_id: str, prompt_path: str,
                   max_iterations: Optional[int] = None,
                   metadata: Optional[Dict[str, Any]] = None) -> int:
        """Create a new orchestrator run entry.

        The run is inserted with status "running" and the current time as
        its start time; metadata is stored JSON-encoded.

        Args:
            orchestrator_id: Unique ID of the orchestrator
            prompt_path: Path to the prompt file
            max_iterations: Maximum iterations for this run
            metadata: Additional metadata to store

        Returns:
            ID of the created run
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO orchestrator_runs
                    (orchestrator_id, prompt_path, start_time, status, max_iterations, metadata)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    orchestrator_id,
                    prompt_path,
                    datetime.now().isoformat(),
                    "running",
                    max_iterations,
                    json.dumps(metadata) if metadata else None
                ))
                conn.commit()
                return cursor.lastrowid
|
||||
|
||||
def update_run_status(self, run_id: int, status: str,
|
||||
error_message: Optional[str] = None,
|
||||
total_iterations: Optional[int] = None):
|
||||
"""Update the status of an orchestrator run.
|
||||
|
||||
Args:
|
||||
run_id: ID of the run to update
|
||||
status: New status (running, completed, failed, paused)
|
||||
error_message: Error message if failed
|
||||
total_iterations: Total iterations completed
|
||||
"""
|
||||
with self._lock:
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
updates = ["status = ?"]
|
||||
params = [status]
|
||||
|
||||
if status in ["completed", "failed"]:
|
||||
updates.append("end_time = ?")
|
||||
params.append(datetime.now().isoformat())
|
||||
|
||||
if error_message is not None:
|
||||
updates.append("error_message = ?")
|
||||
params.append(error_message)
|
||||
|
||||
if total_iterations is not None:
|
||||
updates.append("total_iterations = ?")
|
||||
params.append(total_iterations)
|
||||
|
||||
params.append(run_id)
|
||||
cursor.execute(f"""
|
||||
UPDATE orchestrator_runs
|
||||
SET {', '.join(updates)}
|
||||
WHERE id = ?
|
||||
""", params)
|
||||
conn.commit()
|
||||
|
||||
def add_iteration(self, run_id: int, iteration_number: int,
|
||||
current_task: Optional[str] = None,
|
||||
metrics: Optional[Dict[str, Any]] = None) -> int:
|
||||
"""Add a new iteration entry.
|
||||
|
||||
Args:
|
||||
run_id: ID of the parent run
|
||||
iteration_number: Iteration number
|
||||
current_task: Current task being executed
|
||||
metrics: Performance metrics for this iteration
|
||||
|
||||
Returns:
|
||||
ID of the created iteration
|
||||
"""
|
||||
with self._lock:
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO iteration_history
|
||||
(run_id, iteration_number, start_time, status, current_task, metrics)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
run_id,
|
||||
iteration_number,
|
||||
datetime.now().isoformat(),
|
||||
"running",
|
||||
current_task,
|
||||
json.dumps(metrics) if metrics else None
|
||||
))
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def update_iteration(self, iteration_id: int, status: str,
|
||||
agent_output: Optional[str] = None,
|
||||
error_message: Optional[str] = None):
|
||||
"""Update an iteration entry.
|
||||
|
||||
Args:
|
||||
iteration_id: ID of the iteration to update
|
||||
status: New status (running, completed, failed)
|
||||
agent_output: Output from the agent
|
||||
error_message: Error message if failed
|
||||
"""
|
||||
with self._lock:
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
UPDATE iteration_history
|
||||
SET status = ?, end_time = ?, agent_output = ?, error_message = ?
|
||||
WHERE id = ?
|
||||
""", (
|
||||
status,
|
||||
datetime.now().isoformat() if status != "running" else None,
|
||||
agent_output,
|
||||
error_message,
|
||||
iteration_id
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
def add_task(self, run_id: int, task_description: str) -> int:
|
||||
"""Add a task entry.
|
||||
|
||||
Args:
|
||||
run_id: ID of the parent run
|
||||
task_description: Description of the task
|
||||
|
||||
Returns:
|
||||
ID of the created task
|
||||
"""
|
||||
with self._lock:
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO task_history (run_id, task_description, status)
|
||||
VALUES (?, ?, ?)
|
||||
""", (run_id, task_description, "pending"))
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
def update_task_status(self, task_id: int, status: str,
|
||||
error_message: Optional[str] = None):
|
||||
"""Update task status.
|
||||
|
||||
Args:
|
||||
task_id: ID of the task to update
|
||||
status: New status (pending, in_progress, completed, failed)
|
||||
error_message: Error message if failed
|
||||
"""
|
||||
with self._lock:
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
if status == "in_progress":
|
||||
cursor.execute("""
|
||||
UPDATE task_history
|
||||
SET status = ?, start_time = ?
|
||||
WHERE id = ?
|
||||
""", (status, now, task_id))
|
||||
elif status in ["completed", "failed"]:
|
||||
cursor.execute("""
|
||||
UPDATE task_history
|
||||
SET status = ?, end_time = ?, error_message = ?
|
||||
WHERE id = ?
|
||||
""", (status, now, error_message, task_id))
|
||||
else:
|
||||
cursor.execute("""
|
||||
UPDATE task_history
|
||||
SET status = ?
|
||||
WHERE id = ?
|
||||
""", (status, task_id))
|
||||
|
||||
conn.commit()
|
||||
|
||||
def get_recent_runs(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""Get recent orchestrator runs.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of runs to return
|
||||
|
||||
Returns:
|
||||
List of run dictionaries
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM orchestrator_runs
|
||||
ORDER BY start_time DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
runs = []
|
||||
for row in cursor.fetchall():
|
||||
run = dict(row)
|
||||
if run.get('metadata'):
|
||||
run['metadata'] = json.loads(run['metadata'])
|
||||
runs.append(run)
|
||||
|
||||
return runs
|
||||
|
||||
def get_run_details(self, run_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get detailed information about a specific run.
|
||||
|
||||
Args:
|
||||
run_id: ID of the run
|
||||
|
||||
Returns:
|
||||
Run details with iterations and tasks
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get run info
|
||||
cursor.execute("SELECT * FROM orchestrator_runs WHERE id = ?", (run_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
run = dict(row)
|
||||
if run.get('metadata'):
|
||||
run['metadata'] = json.loads(run['metadata'])
|
||||
|
||||
# Get iterations
|
||||
cursor.execute("""
|
||||
SELECT * FROM iteration_history
|
||||
WHERE run_id = ?
|
||||
ORDER BY iteration_number
|
||||
""", (run_id,))
|
||||
|
||||
iterations = []
|
||||
for row in cursor.fetchall():
|
||||
iteration = dict(row)
|
||||
if iteration.get('metrics'):
|
||||
iteration['metrics'] = json.loads(iteration['metrics'])
|
||||
iterations.append(iteration)
|
||||
|
||||
run['iterations'] = iterations
|
||||
|
||||
# Get tasks
|
||||
cursor.execute("""
|
||||
SELECT * FROM task_history
|
||||
WHERE run_id = ?
|
||||
ORDER BY id
|
||||
""", (run_id,))
|
||||
|
||||
run['tasks'] = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
return run
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get database statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
stats = {}
|
||||
|
||||
# Total runs
|
||||
cursor.execute("SELECT COUNT(*) FROM orchestrator_runs")
|
||||
stats['total_runs'] = cursor.fetchone()[0]
|
||||
|
||||
# Runs by status
|
||||
cursor.execute("""
|
||||
SELECT status, COUNT(*)
|
||||
FROM orchestrator_runs
|
||||
GROUP BY status
|
||||
""")
|
||||
stats['runs_by_status'] = dict(cursor.fetchall())
|
||||
|
||||
# Total iterations
|
||||
cursor.execute("SELECT COUNT(*) FROM iteration_history")
|
||||
stats['total_iterations'] = cursor.fetchone()[0]
|
||||
|
||||
# Total tasks
|
||||
cursor.execute("SELECT COUNT(*) FROM task_history")
|
||||
stats['total_tasks'] = cursor.fetchone()[0]
|
||||
|
||||
# Average iterations per run
|
||||
cursor.execute("""
|
||||
SELECT AVG(total_iterations)
|
||||
FROM orchestrator_runs
|
||||
WHERE total_iterations > 0
|
||||
""")
|
||||
result = cursor.fetchone()[0]
|
||||
stats['avg_iterations_per_run'] = round(result, 2) if result else 0
|
||||
|
||||
# Success rate
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(CASE WHEN status = 'completed' THEN 1 END) * 100.0 /
|
||||
NULLIF(COUNT(*), 0)
|
||||
FROM orchestrator_runs
|
||||
WHERE status IN ('completed', 'failed')
|
||||
""")
|
||||
result = cursor.fetchone()[0]
|
||||
stats['success_rate'] = round(result, 2) if result else 0
|
||||
|
||||
return stats
|
||||
|
||||
def cleanup_old_records(self, days: int = 30):
|
||||
"""Remove records older than specified days.
|
||||
|
||||
Args:
|
||||
days: Number of days to keep
|
||||
"""
|
||||
with self._lock:
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get run IDs to delete (using SQLite datetime functions directly)
|
||||
cursor.execute("""
|
||||
SELECT id FROM orchestrator_runs
|
||||
WHERE datetime(start_time) < datetime('now', '-' || ? || ' days')
|
||||
""", (days,))
|
||||
run_ids = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if run_ids:
|
||||
# Delete iterations
|
||||
cursor.execute("""
|
||||
DELETE FROM iteration_history
|
||||
WHERE run_id IN ({})
|
||||
""".format(','.join('?' * len(run_ids))), run_ids)
|
||||
|
||||
# Delete tasks
|
||||
cursor.execute("""
|
||||
DELETE FROM task_history
|
||||
WHERE run_id IN ({})
|
||||
""".format(','.join('?' * len(run_ids))), run_ids)
|
||||
|
||||
# Delete runs
|
||||
cursor.execute("""
|
||||
DELETE FROM orchestrator_runs
|
||||
WHERE id IN ({})
|
||||
""".format(','.join('?' * len(run_ids))), run_ids)
|
||||
|
||||
conn.commit()
|
||||
logger.info(f"Cleaned up {len(run_ids)} old runs")
|
||||
@@ -0,0 +1,280 @@
|
||||
# ABOUTME: Implements rate limiting for API endpoints to prevent abuse
|
||||
# ABOUTME: Uses token bucket algorithm with configurable limits per endpoint
|
||||
|
||||
"""Rate limiting implementation for the Ralph web server."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Optional, Tuple
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
    """Token-bucket rate limiter keyed by client identifier.

    Every client gets a bucket holding up to ``capacity`` tokens, refilled at
    ``refill_rate`` tokens per ``refill_period`` seconds. Each request costs
    one token; clients that repeatedly hit an empty bucket are temporarily
    blocked outright.
    """

    def __init__(
        self,
        capacity: int = 100,
        refill_rate: float = 10.0,
        refill_period: float = 1.0,
        block_duration: float = 60.0
    ):
        """Initialize the rate limiter.

        Args:
            capacity: Maximum number of tokens a bucket can hold
            refill_rate: Tokens added per refill period
            refill_period: Seconds between refills
            block_duration: Seconds a client stays blocked after exhausting
                its tokens repeatedly
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.refill_period = refill_period
        self.block_duration = block_duration

        # identifier -> (tokens, last_refill_timestamp, consecutive_violations)
        self.buckets: Dict[str, Tuple[float, float, float]] = defaultdict(
            lambda: (float(capacity), time.time(), 0.0)
        )
        # identifier -> timestamp at which the block expires
        self.blocked_ips: Dict[str, float] = {}

        # Serializes all bucket reads/writes across coroutines.
        self._lock = asyncio.Lock()

    async def check_rate_limit(self, identifier: str) -> Tuple[bool, Optional[int]]:
        """Decide whether a request from *identifier* may proceed.

        Args:
            identifier: Unique client key (e.g. IP address)

        Returns:
            ``(True, None)`` when allowed, otherwise
            ``(False, retry_after_seconds)``.
        """
        async with self._lock:
            now = time.time()

            # Reject outright while a block is active; lift it once expired.
            block_until = self.blocked_ips.get(identifier)
            if block_until is not None:
                if now < block_until:
                    return False, int(block_until - now)
                del self.blocked_ips[identifier]

            tokens, refilled_at, strikes = self.buckets[identifier]

            # Top the bucket up proportionally to the time elapsed since the
            # last refill, capped at capacity.
            earned = ((now - refilled_at) / self.refill_period) * self.refill_rate
            tokens = min(self.capacity, tokens + earned)

            if tokens >= 1:
                # Spend one token and reset the violation streak.
                self.buckets[identifier] = (tokens - 1, now, 0)
                return True, None

            strikes += 1
            if strikes >= 5:
                # Persistent abuse: block the client and discard its bucket.
                self.blocked_ips[identifier] = now + self.block_duration
                del self.buckets[identifier]
                return False, int(self.block_duration)

            self.buckets[identifier] = (tokens, now, strikes)
            return False, int(self.refill_period / self.refill_rate)

    async def cleanup_old_buckets(self, max_age: float = 3600.0):
        """Discard long-idle buckets and expired blocks to bound memory.

        Args:
            max_age: Maximum idle age in seconds for inactive buckets
        """
        async with self._lock:
            now = time.time()

            # Buckets untouched for longer than max_age just waste memory.
            stale = [
                key for key, (_tokens, refilled_at, _strikes) in self.buckets.items()
                if now - refilled_at > max_age
            ]
            for key in stale:
                del self.buckets[key]

            # Blocks whose window has passed can be forgotten as well.
            to_remove = [
                key for key, block_end in self.blocked_ips.items()
                if now >= block_end
            ]
            for key in to_remove:
                del self.blocked_ips[key]

            if to_remove:
                logger.info(f"Cleaned up {len(to_remove)} expired rate limit entries")
|
||||
|
||||
|
||||
class RateLimitConfig:
    """Per-category rate-limit settings plus a lazy limiter registry."""

    # Default limits for different endpoint categories
    LIMITS = {
        "auth": {"capacity": 10, "refill_rate": 1.0, "refill_period": 60.0},  # 10 requests/minute
        "api": {"capacity": 100, "refill_rate": 10.0, "refill_period": 1.0},  # 100 requests/10 seconds
        "websocket": {"capacity": 10, "refill_rate": 1.0, "refill_period": 10.0},  # 10 connections/10 seconds
        "static": {"capacity": 200, "refill_rate": 20.0, "refill_period": 1.0},  # 200 requests/20 seconds
        "admin": {"capacity": 50, "refill_rate": 5.0, "refill_period": 1.0},  # 50 requests/5 seconds
    }

    @classmethod
    def get_limiter(cls, category: str) -> RateLimiter:
        """Return the shared RateLimiter for *category*, creating it on demand.

        Args:
            category: The category of endpoints to limit; unknown names fall
                back to the "api" settings.

        Returns:
            The cached RateLimiter instance for that category.
        """
        # The registry is created lazily on first use and shared class-wide.
        try:
            registry = cls._limiters
        except AttributeError:
            registry = cls._limiters = {}

        if category not in registry:
            settings = cls.LIMITS.get(category, cls.LIMITS["api"])
            registry[category] = RateLimiter(**settings)

        return registry[category]
|
||||
|
||||
|
||||
def rate_limit(category: str = "api"):
    """Decorator applying per-category rate limiting to a FastAPI endpoint.

    Args:
        category: The rate limit category to apply
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Prefer the first X-Forwarded-For hop so clients behind a proxy
            # are limited individually instead of as one proxy address.
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                client_ip = forwarded.split(",")[0].strip()
            else:
                client_ip = request.client.host if request.client else "unknown"

            limiter = RateLimitConfig.get_limiter(category)
            allowed, retry_after = await limiter.check_rate_limit(client_ip)

            if allowed:
                # Within budget: run the wrapped endpoint.
                return await func(request, *args, **kwargs)

            logger.warning(f"Rate limit exceeded for {client_ip} on {category} endpoint")
            response = JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": "Rate limit exceeded",
                    "retry_after": retry_after
                }
            )
            if retry_after:
                response.headers["Retry-After"] = str(retry_after)
            return response

        return wrapper
    return decorator
|
||||
|
||||
|
||||
async def setup_rate_limit_cleanup():
    """Spawn a background task that periodically prunes rate-limit state.

    Returns:
        The created asyncio cleanup task
    """
    async def cleanup_task():
        while True:
            try:
                # Sweep all category limiters every 5 minutes.
                await asyncio.sleep(300)
                for name in RateLimitConfig.LIMITS:
                    await RateLimitConfig.get_limiter(name).cleanup_old_buckets()
            except Exception as e:
                # Keep the loop alive on transient failures.
                logger.error(f"Error in rate limit cleanup: {e}")

    return asyncio.create_task(cleanup_task())
|
||||
|
||||
|
||||
# Middleware for global rate limiting
|
||||
async def rate_limit_middleware(request: Request, call_next):
    """Global rate limiting middleware for all HTTP requests.

    Args:
        request: The incoming request
        call_next: The next middleware or endpoint handler

    Returns:
        Either an immediate 429 JSON response or the downstream response.
    """
    path = request.url.path

    # First matching prefix decides the limiter category; anything else
    # falls back to the generic "api" bucket.
    prefix_map = (
        ("/api/auth", "auth"),
        ("/api/admin", "admin"),
        ("/ws", "websocket"),
        ("/static", "static"),
    )
    category = next(
        (name for prefix, name in prefix_map if path.startswith(prefix)),
        "api",
    )

    # Prefer the first X-Forwarded-For hop when behind a proxy.
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        client_ip = forwarded.split(",")[0].strip()
    else:
        client_ip = request.client.host if request.client else "unknown"

    limiter = RateLimitConfig.get_limiter(category)
    allowed, retry_after = await limiter.check_rate_limit(client_ip)

    if not allowed:
        logger.warning(f"Rate limit exceeded for {client_ip} on {path}")
        response = JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "detail": "Rate limit exceeded",
                "retry_after": retry_after
            }
        )
        if retry_after:
            response.headers["Retry-After"] = str(retry_after)
        return response

    # Within budget: hand off to the rest of the stack.
    return await call_next(request)
|
||||
@@ -0,0 +1,712 @@
|
||||
# ABOUTME: FastAPI web server for Ralph Orchestrator monitoring dashboard
|
||||
# ABOUTME: Provides REST API endpoints and WebSocket connections for real-time updates
|
||||
|
||||
"""FastAPI web server for Ralph Orchestrator monitoring."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, status
|
||||
from fastapi.responses import HTMLResponse, FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
import psutil
|
||||
|
||||
from ..orchestrator import RalphOrchestrator
|
||||
from .auth import (
|
||||
auth_manager, LoginRequest, TokenResponse,
|
||||
get_current_user, require_admin
|
||||
)
|
||||
from .database import DatabaseManager
|
||||
from .rate_limit import rate_limit_middleware, setup_rate_limit_cleanup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptUpdateRequest(BaseModel):
    """Request model for updating orchestrator prompt."""
    # Full replacement text for the orchestrator's prompt file.
    content: str
|
||||
|
||||
|
||||
class OrchestratorMonitor:
    """Monitors and manages orchestrator instances.

    Keeps the set of live RalphOrchestrator instances, mirrors their runs,
    iterations and tasks into the SQLite-backed DatabaseManager, and pushes
    state changes and system metrics to connected WebSocket clients.
    """

    def __init__(self):
        # Live orchestrators keyed by their string ID.
        self.active_orchestrators: Dict[str, RalphOrchestrator] = {}
        # In-memory record of past executions (not persisted by this class).
        self.execution_history: List[Dict[str, Any]] = []
        # Currently connected dashboard WebSocket clients.
        self.websocket_clients: List[WebSocket] = []
        # Latest collected metrics snapshots, e.g. metrics_cache["system"].
        self.metrics_cache: Dict[str, Any] = {}
        # Background task created by start_monitoring().
        self.system_metrics_task: Optional[asyncio.Task] = None
        self.database = DatabaseManager()
        self.active_runs: Dict[str, int] = {}  # Maps orchestrator_id to run_id
        self.active_iterations: Dict[str, int] = {}  # Maps orchestrator_id to iteration_id

    async def start_monitoring(self):
        """Start background monitoring tasks (idempotent)."""
        if not self.system_metrics_task:
            self.system_metrics_task = asyncio.create_task(self._monitor_system_metrics())

    async def stop_monitoring(self):
        """Stop background monitoring tasks and wait for cancellation."""
        if self.system_metrics_task:
            self.system_metrics_task.cancel()
            try:
                await self.system_metrics_task
            except asyncio.CancelledError:
                # Cancellation is the expected way for the task to finish.
                pass

    async def _monitor_system_metrics(self):
        """Collect system metrics roughly every 5s and broadcast them.

        Runs until cancelled; errors are logged and the loop continues.
        """
        while True:
            try:
                # Collect system metrics
                # NOTE(review): psutil.cpu_percent(interval=1) blocks the
                # event loop for ~1s each cycle — confirm this is acceptable
                # or move it to a thread executor.
                metrics = {
                    "timestamp": datetime.now().isoformat(),
                    "cpu_percent": psutil.cpu_percent(interval=1),
                    "memory": {
                        "total": psutil.virtual_memory().total,
                        "available": psutil.virtual_memory().available,
                        "percent": psutil.virtual_memory().percent
                    },
                    "active_processes": len(psutil.pids()),
                    "orchestrators": len(self.active_orchestrators)
                }

                self.metrics_cache["system"] = metrics

                # Broadcast to WebSocket clients
                await self._broadcast_to_clients({
                    "type": "system_metrics",
                    "data": metrics
                })

                await asyncio.sleep(5)  # Update every 5 seconds

            except Exception as e:
                # Keep the monitor loop alive on transient failures.
                logger.error(f"Error monitoring system metrics: {e}")
                await asyncio.sleep(5)

    async def _broadcast_to_clients(self, message: Dict[str, Any]):
        """Broadcast message to all connected WebSocket clients.

        Clients whose send fails are treated as disconnected and dropped.
        """
        disconnected_clients = []
        for client in self.websocket_clients:
            try:
                await client.send_json(message)
            except Exception:
                # Any send failure marks the connection as dead.
                disconnected_clients.append(client)

        # Remove disconnected clients
        for client in disconnected_clients:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    def _schedule_broadcast(self, message: Dict[str, Any]):
        """Schedule a broadcast to clients, handling both sync and async contexts."""
        try:
            # Check if there's a running event loop (raises RuntimeError if not)
            asyncio.get_running_loop()
            # If we're in an async context, schedule the broadcast
            asyncio.create_task(self._broadcast_to_clients(message))
        except RuntimeError:
            # No event loop running - we're in a sync context (e.g., during testing)
            # The broadcast will be skipped in this case
            pass

    async def broadcast_update(self, message: Dict[str, Any]):
        """Public method to broadcast updates to WebSocket clients."""
        await self._broadcast_to_clients(message)

    def register_orchestrator(self, orchestrator_id: str, orchestrator: RalphOrchestrator):
        """Register an orchestrator instance.

        Adds it to the active set, creates a matching database run (with its
        tasks when available), and notifies WebSocket clients.
        """
        self.active_orchestrators[orchestrator_id] = orchestrator

        # Create a new run in the database
        try:
            run_id = self.database.create_run(
                orchestrator_id=orchestrator_id,
                prompt_path=str(orchestrator.prompt_file),
                max_iterations=orchestrator.max_iterations,
                metadata={
                    "primary_tool": orchestrator.primary_tool,
                    "max_runtime": orchestrator.max_runtime
                }
            )
            self.active_runs[orchestrator_id] = run_id

            # Extract and store tasks if available
            if hasattr(orchestrator, 'task_queue'):
                for task in orchestrator.task_queue:
                    self.database.add_task(run_id, task['description'])
        except Exception as e:
            # Database persistence is best-effort; monitoring continues anyway.
            logger.error(f"Error creating database run for orchestrator {orchestrator_id}: {e}")

        self._schedule_broadcast({
            "type": "orchestrator_registered",
            "data": {"id": orchestrator_id, "timestamp": datetime.now().isoformat()}
        })

    def unregister_orchestrator(self, orchestrator_id: str):
        """Unregister an orchestrator instance.

        Finalizes its database run, drops all in-memory tracking for it, and
        notifies WebSocket clients. No-op for unknown IDs.
        """
        if orchestrator_id in self.active_orchestrators:
            # Update database run status
            if orchestrator_id in self.active_runs:
                try:
                    orchestrator = self.active_orchestrators[orchestrator_id]
                    # stop_requested distinguishes a natural finish from a user stop.
                    status = "completed" if not orchestrator.stop_requested else "stopped"
                    total_iterations = orchestrator.metrics.total_iterations if hasattr(orchestrator, 'metrics') else 0
                    self.database.update_run_status(
                        self.active_runs[orchestrator_id],
                        status=status,
                        total_iterations=total_iterations
                    )
                    del self.active_runs[orchestrator_id]
                except Exception as e:
                    logger.error(f"Error updating database run for orchestrator {orchestrator_id}: {e}")

            # Remove from active orchestrators
            del self.active_orchestrators[orchestrator_id]

            # Remove active iteration tracking if exists
            if orchestrator_id in self.active_iterations:
                del self.active_iterations[orchestrator_id]

            self._schedule_broadcast({
                "type": "orchestrator_unregistered",
                "data": {"id": orchestrator_id, "timestamp": datetime.now().isoformat()}
            })

    def get_orchestrator_status(self, orchestrator_id: str) -> Optional[Dict[str, Any]]:
        """Get status of a specific orchestrator, or None if unknown."""
        if orchestrator_id not in self.active_orchestrators:
            return None

        orchestrator = self.active_orchestrators[orchestrator_id]

        # Try to use the new get_orchestrator_state method if it exists
        if hasattr(orchestrator, 'get_orchestrator_state'):
            state = orchestrator.get_orchestrator_state()
            state['id'] = orchestrator_id  # Override with our ID
            return state
        else:
            # Fallback to old method for compatibility
            return {
                "id": orchestrator_id,
                "status": "running" if not orchestrator.stop_requested else "stopping",
                "metrics": orchestrator.metrics.to_dict(),
                "cost": orchestrator.cost_tracker.get_summary() if orchestrator.cost_tracker else None,
                "config": {
                    "primary_tool": orchestrator.primary_tool,
                    "max_iterations": orchestrator.max_iterations,
                    "max_runtime": orchestrator.max_runtime,
                    "prompt_file": str(orchestrator.prompt_file)
                }
            }

    def get_all_orchestrators_status(self) -> List[Dict[str, Any]]:
        """Get status of all orchestrators."""
        return [
            self.get_orchestrator_status(orch_id)
            for orch_id in self.active_orchestrators
        ]

    def start_iteration(self, orchestrator_id: str, iteration_number: int,
                        current_task: Optional[str] = None) -> Optional[int]:
        """Start tracking a new iteration.

        Args:
            orchestrator_id: ID of the orchestrator
            iteration_number: Iteration number
            current_task: Current task being executed

        Returns:
            Iteration ID if successful, None otherwise (no active run, or a
            database error occurred).
        """
        if orchestrator_id not in self.active_runs:
            return None

        try:
            iteration_id = self.database.add_iteration(
                run_id=self.active_runs[orchestrator_id],
                iteration_number=iteration_number,
                current_task=current_task,
                metrics=None  # Can be enhanced to include metrics
            )
            self.active_iterations[orchestrator_id] = iteration_id
            return iteration_id
        except Exception as e:
            logger.error(f"Error starting iteration for orchestrator {orchestrator_id}: {e}")
            return None

    def end_iteration(self, orchestrator_id: str, status: str = "completed",
                      agent_output: Optional[str] = None, error_message: Optional[str] = None):
        """End tracking for the current iteration.

        Args:
            orchestrator_id: ID of the orchestrator
            status: Status of the iteration (completed, failed)
            agent_output: Output from the agent
            error_message: Error message if failed
        """
        if orchestrator_id not in self.active_iterations:
            # Nothing in flight for this orchestrator; silently ignore.
            return

        try:
            self.database.update_iteration(
                iteration_id=self.active_iterations[orchestrator_id],
                status=status,
                agent_output=agent_output,
                error_message=error_message
            )
            del self.active_iterations[orchestrator_id]
        except Exception as e:
            logger.error(f"Error ending iteration for orchestrator {orchestrator_id}: {e}")
|
||||
|
||||
|
||||
class WebMonitor:
|
||||
"""Web monitoring server for Ralph Orchestrator."""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8080, enable_auth: bool = True):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.enable_auth = enable_auth
|
||||
self.monitor = OrchestratorMonitor()
|
||||
self.app = None
|
||||
self._setup_app()
|
||||
|
||||
def _setup_app(self):
|
||||
"""Setup FastAPI application."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
await self.monitor.start_monitoring()
|
||||
# Start rate limit cleanup task
|
||||
cleanup_task = await setup_rate_limit_cleanup()
|
||||
yield
|
||||
# Shutdown
|
||||
cleanup_task.cancel()
|
||||
try:
|
||||
await cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self.monitor.stop_monitoring()
|
||||
|
||||
self.app = FastAPI(
|
||||
title="Ralph Orchestrator Monitor",
|
||||
description="Real-time monitoring for Ralph AI Orchestrator",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add rate limiting middleware
|
||||
self.app.middleware("http")(rate_limit_middleware)
|
||||
|
||||
# Mount static files directory if it exists
|
||||
static_dir = Path(__file__).parent / "static"
|
||||
if static_dir.exists():
|
||||
self.app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
|
||||
|
||||
# Setup routes
|
||||
self._setup_routes()
|
||||
|
||||
    def _setup_routes(self):
        """Setup API routes.

        Registers all HTTP and WebSocket endpoints on ``self.app``. Handlers
        are nested closures so they can reach ``self.monitor`` and the
        module-level ``auth_manager``. When ``self.enable_auth`` is true,
        protected routes get a ``get_current_user`` (or ``require_admin``)
        dependency; otherwise the dependency list is empty.
        """

        # --- Authentication endpoints (public) ---
        @self.app.post("/api/auth/login", response_model=TokenResponse)
        async def login(request: LoginRequest):
            """Login and receive an access token."""
            user = auth_manager.authenticate_user(request.username, request.password)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Incorrect username or password",
                    headers={"WWW-Authenticate": "Bearer"},
                )

            # Token subject is the username; expiry is configured on the manager.
            access_token = auth_manager.create_access_token(
                data={"sub": user["username"]}
            )

            return TokenResponse(
                access_token=access_token,
                expires_in=auth_manager.access_token_expire_minutes * 60
            )

        # NOTE(review): registered as GET, but static/login.html issues a POST
        # to this URL when validating a stored token — that POST gets a 405 and
        # the client drops a still-valid token. Confirm and align one side.
        @self.app.get("/api/auth/verify")
        async def verify_token(current_user: Dict[str, Any] = Depends(get_current_user)):
            """Verify the current token is valid."""
            return {
                "valid": True,
                "username": current_user["username"],
                "is_admin": current_user["user"].get("is_admin", False)
            }

        # --- Public endpoints - HTML pages ---
        @self.app.get("/login.html")
        async def login_page():
            """Serve the login page."""
            html_file = Path(__file__).parent / "static" / "login.html"
            if html_file.exists():
                return FileResponse(html_file, media_type="text/html")
            else:
                return HTMLResponse(content="<h1>Login page not found</h1>", status_code=404)

        @self.app.get("/")
        async def index():
            """Serve the main dashboard."""
            html_file = Path(__file__).parent / "static" / "index.html"
            if html_file.exists():
                return FileResponse(html_file, media_type="text/html")
            else:
                # Return a basic HTML page if static file doesn't exist yet
                return HTMLResponse(content="""
                <!DOCTYPE html>
                <html>
                <head>
                    <title>Ralph Orchestrator Monitor</title>
                    <style>
                        body { font-family: Arial, sans-serif; margin: 20px; }
                        h1 { color: #333; }
                        .status { padding: 10px; margin: 10px 0; background: #f0f0f0; border-radius: 5px; }
                    </style>
                </head>
                <body>
                    <h1>Ralph Orchestrator Monitor</h1>
                    <div id="status" class="status">
                        <p>Web monitor is running. Dashboard file not found.</p>
                        <p>API Endpoints:</p>
                        <ul>
                            <li><a href="/api/status">/api/status</a> - System status</li>
                            <li><a href="/api/orchestrators">/api/orchestrators</a> - Active orchestrators</li>
                            <li><a href="/api/metrics">/api/metrics</a> - System metrics</li>
                            <li><a href="/docs">/docs</a> - API documentation</li>
                        </ul>
                    </div>
                </body>
                </html>
                """)

        # Create the shared auth dependency once; each protected route below
        # attaches it only when auth is enabled.
        auth_dependency = Depends(get_current_user) if self.enable_auth else None

        @self.app.get("/api/health")
        async def health_check():
            """Health check endpoint (always public, for load balancers/probes)."""
            return {
                "status": "healthy",
                "timestamp": datetime.now().isoformat()
            }

        @self.app.get("/api/status", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_status():
            """Get overall system status."""
            return {
                "status": "online",
                "timestamp": datetime.now().isoformat(),
                "active_orchestrators": len(self.monitor.active_orchestrators),
                "connected_clients": len(self.monitor.websocket_clients),
                "system_metrics": self.monitor.metrics_cache.get("system", {})
            }

        @self.app.get("/api/orchestrators", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_orchestrators():
            """Get all active orchestrators."""
            return {
                "orchestrators": self.monitor.get_all_orchestrators_status(),
                "count": len(self.monitor.active_orchestrators)
            }

        @self.app.get("/api/orchestrators/{orchestrator_id}", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_orchestrator(orchestrator_id: str):
            """Get specific orchestrator status."""
            # NOTE(review): local ``status`` shadows the imported ``status``
            # module within this handler; harmless here but easy to trip over.
            status = self.monitor.get_orchestrator_status(orchestrator_id)
            if not status:
                raise HTTPException(status_code=404, detail="Orchestrator not found")
            return status

        @self.app.get("/api/orchestrators/{orchestrator_id}/tasks", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_orchestrator_tasks(orchestrator_id: str):
            """Get task queue status for an orchestrator."""
            if orchestrator_id not in self.monitor.active_orchestrators:
                raise HTTPException(status_code=404, detail="Orchestrator not found")

            orchestrator = self.monitor.active_orchestrators[orchestrator_id]
            task_status = orchestrator.get_task_status()

            return {
                "orchestrator_id": orchestrator_id,
                "tasks": task_status
            }

        @self.app.post("/api/orchestrators/{orchestrator_id}/pause", dependencies=[auth_dependency] if self.enable_auth else [])
        async def pause_orchestrator(orchestrator_id: str):
            """Pause an orchestrator."""
            if orchestrator_id not in self.monitor.active_orchestrators:
                raise HTTPException(status_code=404, detail="Orchestrator not found")

            # "Pause" is implemented via the orchestrator's stop flag; resume
            # simply clears it again.
            orchestrator = self.monitor.active_orchestrators[orchestrator_id]
            orchestrator.stop_requested = True

            return {"status": "paused", "orchestrator_id": orchestrator_id}

        @self.app.post("/api/orchestrators/{orchestrator_id}/resume", dependencies=[auth_dependency] if self.enable_auth else [])
        async def resume_orchestrator(orchestrator_id: str):
            """Resume an orchestrator."""
            if orchestrator_id not in self.monitor.active_orchestrators:
                raise HTTPException(status_code=404, detail="Orchestrator not found")

            orchestrator = self.monitor.active_orchestrators[orchestrator_id]
            orchestrator.stop_requested = False

            return {"status": "resumed", "orchestrator_id": orchestrator_id}

        @self.app.get("/api/orchestrators/{orchestrator_id}/prompt", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_orchestrator_prompt(orchestrator_id: str):
            """Get the current prompt for an orchestrator."""
            if orchestrator_id not in self.monitor.active_orchestrators:
                raise HTTPException(status_code=404, detail="Orchestrator not found")

            orchestrator = self.monitor.active_orchestrators[orchestrator_id]
            prompt_file = orchestrator.prompt_file

            if not prompt_file.exists():
                raise HTTPException(status_code=404, detail="Prompt file not found")

            content = prompt_file.read_text()
            return {
                "orchestrator_id": orchestrator_id,
                "prompt_file": str(prompt_file),
                "content": content,
                "last_modified": prompt_file.stat().st_mtime
            }

        @self.app.post("/api/orchestrators/{orchestrator_id}/prompt", dependencies=[auth_dependency] if self.enable_auth else [])
        async def update_orchestrator_prompt(orchestrator_id: str, request: PromptUpdateRequest):
            """Update the prompt for an orchestrator.

            Writes the new content to the orchestrator's prompt file after
            taking a timestamped backup, then notifies the orchestrator and
            any connected WebSocket clients.
            """
            if orchestrator_id not in self.monitor.active_orchestrators:
                raise HTTPException(status_code=404, detail="Orchestrator not found")

            orchestrator = self.monitor.active_orchestrators[orchestrator_id]
            prompt_file = orchestrator.prompt_file

            try:
                # Create backup before updating (suffix: .<unix-ts>.backup)
                backup_file = prompt_file.with_suffix(f".{int(time.time())}.backup")
                if prompt_file.exists():
                    backup_file.write_text(prompt_file.read_text())

                # Write the new content
                prompt_file.write_text(request.content)

                # Notify the orchestrator of the update (optional hook)
                if hasattr(orchestrator, '_reload_prompt'):
                    orchestrator._reload_prompt()

                # Broadcast update to WebSocket clients
                await self.monitor._broadcast_to_clients({
                    "type": "prompt_updated",
                    "data": {
                        "orchestrator_id": orchestrator_id,
                        "timestamp": datetime.now().isoformat()
                    }
                })

                return {
                    "status": "success",
                    "orchestrator_id": orchestrator_id,
                    "backup_file": str(backup_file),
                    "timestamp": datetime.now().isoformat()
                }
            except Exception as e:
                logger.error(f"Error updating prompt: {e}")
                raise HTTPException(status_code=500, detail=f"Failed to update prompt: {str(e)}") from e

        @self.app.get("/api/metrics", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_metrics():
            """Get system metrics."""
            return self.monitor.metrics_cache

        @self.app.get("/api/history", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_history(limit: int = 50):
            """Get execution history from database.

            Args:
                limit: Maximum number of runs to return (default 50)

            NOTE(review): the DB path returns ``get_recent_runs(...)``
            directly, while the file fallback wraps results in
            ``{"history": [...]}`` and ignores ``limit`` (hard-coded last 50)
            — confirm clients handle both shapes.
            """
            try:
                # Get recent runs from database
                history = self.monitor.database.get_recent_runs(limit=limit)
                return history
            except Exception as e:
                logger.error(f"Error fetching history from database: {e}")
                # Fallback to file-based history if database fails
                metrics_dir = Path(".agent") / "metrics"
                history = []

                if metrics_dir.exists():
                    for metrics_file in sorted(metrics_dir.glob("metrics_*.json")):
                        try:
                            data = json.loads(metrics_file.read_text())
                            data["filename"] = metrics_file.name
                            history.append(data)
                        except Exception as e:
                            logger.error(f"Error reading metrics file {metrics_file}: {e}")

                return {"history": history[-50:]}  # Return last 50 entries

        @self.app.get("/api/history/{run_id}", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_run_details(run_id: int):
            """Get detailed information about a specific run.

            Args:
                run_id: ID of the run to retrieve
            """
            run_details = self.monitor.database.get_run_details(run_id)
            if not run_details:
                raise HTTPException(status_code=404, detail="Run not found")
            return run_details

        @self.app.get("/api/statistics", dependencies=[auth_dependency] if self.enable_auth else [])
        async def get_statistics():
            """Get database statistics."""
            return self.monitor.database.get_statistics()

        @self.app.post("/api/database/cleanup", dependencies=[auth_dependency] if self.enable_auth else [])
        async def cleanup_database(days: int = 30):
            """Clean up old records from the database.

            Args:
                days: Number of days of history to keep (default 30)
            """
            try:
                self.monitor.database.cleanup_old_records(days=days)
                return {"status": "success", "message": f"Cleaned up records older than {days} days"}
            except Exception as e:
                logger.error(f"Error cleaning up database: {e}")
                raise HTTPException(status_code=500, detail=str(e)) from e

        # --- Admin endpoints for user management ---
        @self.app.post("/api/admin/users", dependencies=[Depends(require_admin)] if self.enable_auth else [])
        async def add_user(username: str, password: str, is_admin: bool = False):
            """Add a new user (admin only)."""
            if auth_manager.add_user(username, password, is_admin):
                return {"status": "success", "message": f"User {username} created"}
            else:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="User already exists"
                )

        @self.app.delete("/api/admin/users/{username}", dependencies=[Depends(require_admin)] if self.enable_auth else [])
        async def remove_user(username: str):
            """Remove a user (admin only)."""
            if auth_manager.remove_user(username):
                return {"status": "success", "message": f"User {username} removed"}
            else:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Cannot remove user"
                )

        @self.app.post("/api/auth/change-password", dependencies=[auth_dependency] if self.enable_auth else [])
        async def change_password(
            old_password: str,
            new_password: str,
            current_user: Dict[str, Any] = Depends(get_current_user) if self.enable_auth else None
        ):
            """Change the current user's password."""
            # With auth disabled there is no current user, so the route is
            # effectively hidden (404) rather than failing on None access.
            if not self.enable_auth:
                raise HTTPException(status_code=404, detail="Authentication not enabled")

            # Verify old password before allowing the change
            user = auth_manager.authenticate_user(current_user["username"], old_password)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Incorrect old password"
                )

            # Update password
            if auth_manager.update_password(current_user["username"], new_password):
                return {"status": "success", "message": "Password updated"}
            else:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to update password"
                )

        @self.app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
            """WebSocket endpoint for real-time updates.

            Auth (when enabled) is checked via a ``token`` query parameter
            before the connection is accepted; policy-violation close code is
            used for both bad and missing tokens.
            """
            # Verify token if auth is enabled
            if self.enable_auth and token:
                try:
                    auth_manager.verify_token(token)
                except HTTPException:
                    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                    return
            elif self.enable_auth:
                # Auth is enabled but no token provided
                await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                return

            await websocket.accept()
            self.monitor.websocket_clients.append(websocket)

            # Send initial state so the client can render without waiting for
            # the first broadcast.
            await websocket.send_json({
                "type": "initial_state",
                "data": {
                    "orchestrators": self.monitor.get_all_orchestrators_status(),
                    "system_metrics": self.monitor.metrics_cache.get("system", {})
                }
            })

            # NOTE(review): the client is only removed from websocket_clients
            # on WebSocketDisconnect; any other exception would leak the entry.
            try:
                while True:
                    # Keep connection alive and handle incoming messages
                    data = await websocket.receive_text()
                    # Handle ping/pong or other commands if needed
                    if data == "ping":
                        await websocket.send_text("pong")
            except WebSocketDisconnect:
                self.monitor.websocket_clients.remove(websocket)
                logger.info("WebSocket client disconnected")
|
||||
|
||||
def run(self):
|
||||
"""Run the web server."""
|
||||
logger.info(f"Starting web monitor on {self.host}:{self.port}")
|
||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||
|
||||
async def arun(self):
|
||||
"""Run the web server asynchronously."""
|
||||
logger.info(f"Starting web monitor on {self.host}:{self.port}")
|
||||
config = uvicorn.Config(app=self.app, host=self.host, port=self.port)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
def register_orchestrator(self, orchestrator_id: str, orchestrator: RalphOrchestrator):
|
||||
"""Register an orchestrator with the monitor."""
|
||||
self.monitor.register_orchestrator(orchestrator_id, orchestrator)
|
||||
|
||||
def unregister_orchestrator(self, orchestrator_id: str):
|
||||
"""Unregister an orchestrator from the monitor."""
|
||||
self.monitor.unregister_orchestrator(orchestrator_id)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,320 @@
|
||||
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Ralph Orchestrator - Login</title>
    <style>
        /* Shared palette — mirrors the dashboard theme variables. */
        :root {
            --primary-color: #4a5568;
            --secondary-color: #718096;
            --success-color: #48bb78;
            --warning-color: #ed8936;
            --danger-color: #f56565;
            --background: #f7fafc;
            --surface: #ffffff;
            --text-primary: #2d3748;
            --text-secondary: #718096;
            --border-color: #e2e8f0;
        }

        /* Global reset */
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        /* Full-viewport gradient backdrop; card is centered with flexbox. */
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
            background: linear-gradient(135deg, var(--primary-color) 0%, var(--secondary-color) 100%);
            color: var(--text-primary);
            min-height: 100vh;
            display: flex;
            align-items: center;
            justify-content: center;
            padding: 20px;
        }

        .login-container {
            background: var(--surface);
            border-radius: 12px;
            box-shadow: 0 10px 40px rgba(0, 0, 0, 0.15);
            width: 100%;
            max-width: 400px;
            padding: 40px;
        }

        .login-header {
            text-align: center;
            margin-bottom: 30px;
        }

        .login-header h1 {
            color: var(--primary-color);
            font-size: 28px;
            margin-bottom: 10px;
        }

        .login-header p {
            color: var(--text-secondary);
            font-size: 14px;
        }

        .form-group {
            margin-bottom: 20px;
        }

        label {
            display: block;
            margin-bottom: 8px;
            color: var(--text-primary);
            font-weight: 500;
            font-size: 14px;
        }

        input[type="text"],
        input[type="password"] {
            width: 100%;
            padding: 12px;
            border: 1px solid var(--border-color);
            border-radius: 6px;
            font-size: 14px;
            transition: border-color 0.3s, box-shadow 0.3s;
            background: var(--surface);
            color: var(--text-primary);
        }

        input[type="text"]:focus,
        input[type="password"]:focus {
            outline: none;
            border-color: var(--primary-color);
            box-shadow: 0 0 0 3px rgba(74, 85, 104, 0.1);
        }

        .btn-login {
            width: 100%;
            padding: 12px;
            background: var(--primary-color);
            color: white;
            border: none;
            border-radius: 6px;
            font-size: 16px;
            font-weight: 600;
            cursor: pointer;
            transition: background-color 0.3s, transform 0.2s;
        }

        .btn-login:hover {
            background: #2d3748;
            transform: translateY(-1px);
        }

        .btn-login:active {
            transform: translateY(0);
        }

        .btn-login:disabled {
            background: var(--secondary-color);
            cursor: not-allowed;
            transform: none;
        }

        /* Feedback banners are hidden until JS toggles the .show class. */
        .error-message {
            background: rgba(245, 101, 101, 0.1);
            border: 1px solid var(--danger-color);
            color: var(--danger-color);
            padding: 10px;
            border-radius: 6px;
            margin-bottom: 20px;
            font-size: 14px;
            display: none;
        }

        .error-message.show {
            display: block;
        }

        .success-message {
            background: rgba(72, 187, 120, 0.1);
            border: 1px solid var(--success-color);
            color: var(--success-color);
            padding: 10px;
            border-radius: 6px;
            margin-bottom: 20px;
            font-size: 14px;
            display: none;
        }

        .success-message.show {
            display: block;
        }

        .form-footer {
            margin-top: 20px;
            text-align: center;
            color: var(--text-secondary);
            font-size: 12px;
        }

        /* Inline spinner shown inside the login button while a request runs. */
        .loading-spinner {
            display: inline-block;
            width: 16px;
            height: 16px;
            border: 2px solid rgba(255, 255, 255, 0.3);
            border-radius: 50%;
            border-top-color: white;
            animation: spin 0.6s linear infinite;
            margin-left: 8px;
            vertical-align: middle;
        }

        @keyframes spin {
            to { transform: rotate(360deg); }
        }

        .auth-info {
            background: rgba(74, 85, 104, 0.05);
            border-left: 3px solid var(--primary-color);
            padding: 12px;
            margin-top: 20px;
            font-size: 13px;
            color: var(--text-secondary);
        }

        .auth-info strong {
            color: var(--text-primary);
        }

        /* Narrow screens: tighter card padding, smaller heading. */
        @media (max-width: 480px) {
            .login-container {
                padding: 30px 20px;
            }

            .login-header h1 {
                font-size: 24px;
            }
        }
    </style>
</head>
|
||||
<body>
    <div class="login-container">
        <div class="login-header">
            <h1>🤖 Ralph Orchestrator</h1>
            <p>Monitoring Dashboard Login</p>
        </div>

        <!-- Feedback banners, toggled via the .show class by the page script -->
        <div id="errorMessage" class="error-message"></div>
        <div id="successMessage" class="success-message"></div>

        <!-- Submitted via fetch() to /api/auth/login; no page navigation -->
        <form id="loginForm">
            <div class="form-group">
                <label for="username">Username</label>
                <input type="text" id="username" name="username" required autofocus>
            </div>

            <div class="form-group">
                <label for="password">Password</label>
                <input type="password" id="password" name="password" required>
            </div>

            <button type="submit" class="btn-login" id="loginButton">
                Login
            </button>
        </form>

        <!-- NOTE(review): publishing the default password in the UI is a
             security smell — consider removing once deployments set
             RALPH_WEB_PASSWORD. -->
        <div class="auth-info">
            <strong>Default Credentials:</strong><br>
            Username: admin<br>
            Password: Set via RALPH_WEB_PASSWORD env variable<br>
            (Default: ralph-admin-2024)
        </div>

        <div class="form-footer">
            Secure access to Ralph Orchestrator monitoring
        </div>
    </div>
|
||||
<script>
|
||||
// Check if already authenticated
|
||||
const token = localStorage.getItem('ralph_auth_token');
|
||||
if (token) {
|
||||
// Verify token is still valid
|
||||
fetch('/api/auth/verify', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${token}`
|
||||
}
|
||||
}).then(response => {
|
||||
if (response.ok) {
|
||||
// Token is valid, redirect to dashboard
|
||||
window.location.href = '/';
|
||||
} else {
|
||||
// Token is invalid, remove it
|
||||
localStorage.removeItem('ralph_auth_token');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Handle login form submission
|
||||
document.getElementById('loginForm').addEventListener('submit', async (e) => {
|
||||
e.preventDefault();
|
||||
|
||||
const username = document.getElementById('username').value;
|
||||
const password = document.getElementById('password').value;
|
||||
const loginButton = document.getElementById('loginButton');
|
||||
const errorMessage = document.getElementById('errorMessage');
|
||||
const successMessage = document.getElementById('successMessage');
|
||||
|
||||
// Reset messages
|
||||
errorMessage.classList.remove('show');
|
||||
successMessage.classList.remove('show');
|
||||
|
||||
// Disable button and show loading
|
||||
loginButton.disabled = true;
|
||||
loginButton.innerHTML = 'Logging in<span class="loading-spinner"></span>';
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/auth/login', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({ username, password })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok) {
|
||||
// Store token
|
||||
localStorage.setItem('ralph_auth_token', data.access_token);
|
||||
|
||||
// Show success message
|
||||
successMessage.textContent = 'Login successful! Redirecting...';
|
||||
successMessage.classList.add('show');
|
||||
|
||||
// Redirect to dashboard
|
||||
setTimeout(() => {
|
||||
window.location.href = '/';
|
||||
}, 1000);
|
||||
} else {
|
||||
// Show error
|
||||
errorMessage.textContent = data.detail || 'Login failed. Please check your credentials.';
|
||||
errorMessage.classList.add('show');
|
||||
|
||||
// Re-enable button
|
||||
loginButton.disabled = false;
|
||||
loginButton.textContent = 'Login';
|
||||
}
|
||||
} catch (error) {
|
||||
// Show error
|
||||
errorMessage.textContent = 'Connection error. Please try again.';
|
||||
errorMessage.classList.add('show');
|
||||
|
||||
// Re-enable button
|
||||
loginButton.disabled = false;
|
||||
loginButton.textContent = 'Login';
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user