Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,25 @@
# ABOUTME: Ralph Orchestrator package for AI agent orchestration
# ABOUTME: Implements the Ralph Wiggum technique with multi-tool support
"""Ralph Orchestrator - Simple AI agent orchestration."""
__version__ = "1.2.3"
from .orchestrator import RalphOrchestrator
from .metrics import Metrics, CostTracker, IterationStats
from .error_formatter import ClaudeErrorFormatter, ErrorMessage
from .verbose_logger import VerboseLogger
from .output import DiffStats, DiffFormatter, RalphConsole
__all__ = [
"RalphOrchestrator",
"Metrics",
"CostTracker",
"IterationStats",
"ClaudeErrorFormatter",
"ErrorMessage",
"VerboseLogger",
"DiffStats",
"DiffFormatter",
"RalphConsole",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
# ABOUTME: Tool adapter interfaces and implementations
# ABOUTME: Provides unified interface for Claude, Q Chat, Gemini, ACP, and other tools
"""Tool adapters for Ralph Orchestrator."""
from .base import ToolAdapter, ToolResponse
from .claude import ClaudeAdapter
from .qchat import QChatAdapter
from .kiro import KiroAdapter
from .gemini import GeminiAdapter
from .acp import ACPAdapter
from .acp_handlers import ACPHandlers, PermissionRequest, PermissionResult, Terminal
__all__ = [
"ToolAdapter",
"ToolResponse",
"ClaudeAdapter",
"QChatAdapter",
"KiroAdapter",
"GeminiAdapter",
"ACPAdapter",
"ACPHandlers",
"PermissionRequest",
"PermissionResult",
"Terminal",
]

View File

@@ -0,0 +1,984 @@
# ABOUTME: ACP Adapter for Agent Client Protocol integration
# ABOUTME: Provides subprocess-based communication with ACP-compliant agents like Gemini CLI
"""ACP (Agent Client Protocol) adapter for Ralph Orchestrator.
This adapter enables Ralph to use any ACP-compliant agent (like Gemini CLI)
as a backend for task execution. It manages the subprocess lifecycle,
handles the initialization handshake, and routes session messages.
"""
import asyncio
import logging
import os
import shutil
import signal
import threading
from typing import Optional
from .base import ToolAdapter, ToolResponse
from .acp_client import ACPClient, ACPClientError
from .acp_models import ACPAdapterConfig, ACPSession, UpdatePayload
from .acp_handlers import ACPHandlers
from ..output.console import RalphConsole
logger = logging.getLogger(__name__)
# ACP Protocol version this adapter supports (integer per spec)
ACP_PROTOCOL_VERSION = 1
class ACPAdapter(ToolAdapter):
"""Adapter for ACP-compliant agents like Gemini CLI.
Manages subprocess lifecycle, initialization handshake, and session
message routing for Agent Client Protocol communication.
Attributes:
agent_command: Command to spawn the agent (default: gemini).
agent_args: Additional arguments for agent command.
timeout: Request timeout in seconds.
permission_mode: How to handle permission requests.
"""
_TOOL_FIELD_ALIASES = {
"toolName": ("toolName", "tool_name", "name", "tool"),
"toolCallId": ("toolCallId", "tool_call_id", "id"),
"arguments": ("arguments", "args", "parameters", "params", "input"),
"status": ("status",),
"result": ("result",),
"error": ("error",),
}
def __init__(
    self,
    agent_command: str = "gemini",
    agent_args: Optional[list[str]] = None,
    timeout: int = 300,
    permission_mode: str = "auto_approve",
    permission_allowlist: Optional[list[str]] = None,
    verbose: bool = False,
) -> None:
    """Store configuration, build helpers and hook shutdown signals.

    Args:
        agent_command: Executable used to spawn the agent.
        agent_args: Extra command-line arguments for the agent.
        timeout: Per-request timeout in seconds.
        permission_mode: Mode string handed to ACPHandlers.
        permission_allowlist: Patterns handed to ACPHandlers.
        verbose: Default for verbose console streaming.
    """
    # Caller-supplied configuration (falsy list arguments become fresh lists).
    self.agent_command = agent_command
    self.agent_args = agent_args if agent_args else []
    self.timeout = timeout
    self.permission_mode = permission_mode
    self.permission_allowlist = permission_allowlist if permission_allowlist else []
    self.verbose = verbose
    # Verbosity in effect for the request currently being processed.
    self._current_verbose = verbose

    # Connection and session state; populated by _initialize().
    self._client: Optional[ACPClient] = None
    self._session: Optional[ACPSession] = None
    self._session_id: Optional[str] = None
    self._initialized = False

    # Shutdown coordination and the signal handlers we displace.
    self._lock = threading.Lock()
    self._shutdown_requested = False
    self._original_sigint = None
    self._original_sigterm = None

    # Console used by the streaming helpers.
    self._console = RalphConsole()

    # Handlers that answer permission / fs / terminal requests.
    self._handlers = ACPHandlers(
        permission_mode=permission_mode,
        permission_allowlist=self.permission_allowlist,
        on_permission_log=self._log_permission,
    )

    # The parent initializer runs after agent_command is set because it
    # triggers check_availability().
    super().__init__("acp")

    self._register_signal_handlers()
@classmethod
def from_config(cls, config: ACPAdapterConfig) -> "ACPAdapter":
    """Build an adapter from an ACPAdapterConfig.

    Args:
        config: Object exposing agent_command, agent_args, timeout,
            permission_mode and permission_allowlist.

    Returns:
        A new instance constructed from those five settings.
    """
    settings = {
        "agent_command": config.agent_command,
        "agent_args": config.agent_args,
        "timeout": config.timeout,
        "permission_mode": config.permission_mode,
        "permission_allowlist": config.permission_allowlist,
    }
    return cls(**settings)
def check_availability(self) -> bool:
    """Report whether the configured agent executable can be located.

    Returns:
        True when ``shutil.which`` resolves ``self.agent_command``,
        False otherwise.
    """
    resolved_path = shutil.which(self.agent_command)
    return resolved_path is not None
def _register_signal_handlers(self) -> None:
    """Install ``_signal_handler`` for SIGINT and SIGTERM.

    The handlers that were installed before are remembered on the
    instance so they can be restored and chained to later. A ValueError
    from ``signal.signal`` is logged as a warning instead of propagating.
    """
    try:
        previous_int = signal.signal(signal.SIGINT, self._signal_handler)
        self._original_sigint = previous_int
        previous_term = signal.signal(signal.SIGTERM, self._signal_handler)
        self._original_sigterm = previous_term
    except ValueError as e:
        logger.warning("Cannot register signal handlers (not in main thread): %s. Graceful shutdown via Ctrl+C will not work.", e)
def _restore_signal_handlers(self) -> None:
    """Reinstall the SIGINT/SIGTERM handlers saved at registration time.

    A saved value of None means nothing was recorded for that signal and
    it is left untouched. ValueError/TypeError raised while restoring
    are logged as warnings.
    """
    saved_slots = (
        (signal.SIGINT, "_original_sigint"),
        (signal.SIGTERM, "_original_sigterm"),
    )
    try:
        for signum, attr_name in saved_slots:
            previous = getattr(self, attr_name)
            if previous is None:
                continue
            signal.signal(signum, previous)
    except (ValueError, TypeError) as e:
        logger.warning("Failed to restore signal handlers: %s", e)
def _signal_handler(self, signum: int, frame) -> None:
    """React to SIGINT/SIGTERM: flag shutdown, stop the agent, chain on.

    Args:
        signum: Number of the signal being delivered.
        frame: Stack frame that was interrupted.
    """
    # Record the shutdown request under the adapter lock.
    with self._lock:
        self._shutdown_requested = True

    # Stop the agent subprocess before handing control back.
    self.kill_subprocess_sync()

    # Choose the handler that was installed before ours for this signal.
    if signum == signal.SIGINT:
        previous = self._original_sigint
    else:
        previous = self._original_sigterm

    # NOTE(review): non-callable dispositions such as signal.SIG_DFL are
    # skipped here, so the default action is not re-triggered.
    if not (previous and callable(previous)):
        return
    previous(signum, frame)
def kill_subprocess_sync(self, grace_period: float = 2.0) -> None:
    """Synchronously stop the agent subprocess.

    Sends a graceful terminate first, waits up to ``grace_period``
    seconds for the process to report an exit status, then force-kills
    it. All exceptions are logged at debug level and swallowed so the
    method can be used from signal handlers and ``__del__``.

    Args:
        grace_period: Seconds to wait after ``terminate()`` before
            calling ``kill()`` (default: 2.0).
    """
    import time

    client = self._client
    process = client._process if client else None
    if not process:
        return

    def exit_status():
        # asyncio.subprocess.Process exposes no poll(); only Popen-like
        # objects do. Fall back to the returncode attribute so a missing
        # poll() no longer aborts the wait with AttributeError.
        # NOTE(review): an asyncio process likely only refreshes
        # returncode from its event loop, so when that loop is blocked
        # the full grace period elapses before kill() -- confirm.
        poller = getattr(process, "poll", None)
        if callable(poller):
            return poller()
        return process.returncode

    try:
        if process.returncode is not None:
            # Already exited; nothing to stop.
            return

        # Try graceful termination first.
        process.terminate()

        # Monotonic clock: the deadline is immune to wall-clock jumps.
        deadline = time.monotonic() + grace_period
        while time.monotonic() < deadline:
            if exit_status() is not None:
                return
            time.sleep(0.01)  # Brief sleep to avoid busy-wait

        # Grace period expired: force kill.
        try:
            process.kill()
            # Brief wait to let the kill take effect, then query once
            # more so Popen-like objects can collect the exit status.
            time.sleep(0.1)
            exit_status()
        except Exception as e:
            logger.debug("Exception during subprocess kill: %s", e)
    except Exception as e:
        logger.debug("Exception during subprocess kill: %s", e)
async def _initialize(self) -> None:
    """Start the agent subprocess and complete the ACP handshake.

    Steps: spawn an ACPClient, register the notification and request
    callbacks, send ``initialize``, validate the reply, send
    ``session/new`` and remember the returned session id. Does nothing
    when the adapter is already initialized.

    Raises:
        ACPClientError: On timeout, on a reply without
            ``protocolVersion``, or on a reply without ``sessionId``.
    """
    if self._initialized:
        return

    launch_args = list(self.agent_args)

    # The Gemini CLI needs extra flags: ACP mode itself, auto-approval of
    # its internal tool executions, and an explicit native-tool list.
    # write_file and run_shell_command are deliberately not listed (they
    # have bugs in ACP mode); Gemini should fall back to ACP's
    # fs/write_text_file and terminal/create instead.
    if os.path.basename(self.agent_command) == "gemini":
        gemini_defaults = (
            (
                "--experimental-acp",
                "Auto-adding --experimental-acp flag for Gemini CLI",
                (),
            ),
            (
                "--yolo",
                "Auto-adding --yolo flag for Gemini CLI tool execution",
                (),
            ),
            (
                "--allowed-tools",
                "Auto-adding --allowed-tools for Gemini CLI native tools",
                (
                    "list_directory",
                    "read_many_files",
                    "read_file",
                    "web_fetch",
                    "google_web_search",
                ),
            ),
        )
        for flag, notice, extra_values in gemini_defaults:
            if flag in launch_args:
                continue
            logger.info(notice)
            launch_args.append(flag)
            launch_args.extend(extra_values)

    # Spawn the agent and wire up the callbacks.
    self._client = ACPClient(
        command=self.agent_command,
        args=launch_args,
        timeout=self.timeout,
    )
    await self._client.start()
    self._client.on_notification(self._handle_notification)
    self._client.on_request(self._handle_request)

    try:
        # Handshake: advertise protocol version and client capabilities.
        handshake_params = {
            "protocolVersion": ACP_PROTOCOL_VERSION,
            "clientCapabilities": {
                "fs": {
                    "readTextFile": True,
                    "writeTextFile": True,
                },
                "terminal": True,
            },
            "clientInfo": {
                "name": "ralph-orchestrator",
                "title": "Ralph Orchestrator",
                "version": "1.2.3",
            },
        }
        init_response = await asyncio.wait_for(
            self._client.send_request("initialize", handshake_params),
            timeout=self.timeout,
        )
        if "protocolVersion" not in init_response:
            raise ACPClientError("Invalid initialize response: missing protocolVersion")

        # Open a session rooted at the current working directory with no
        # MCP servers configured.
        session_params = {
            "cwd": os.getcwd(),
            "mcpServers": [],
        }
        session_response = await asyncio.wait_for(
            self._client.send_request("session/new", session_params),
            timeout=self.timeout,
        )

        self._session_id = session_response.get("sessionId")
        if not self._session_id:
            raise ACPClientError("Invalid session/new response: missing sessionId")

        # Tracker that accumulates updates for this session.
        self._session = ACPSession(session_id=self._session_id)
        self._initialized = True
    except asyncio.TimeoutError:
        # Stop the client, then surface the timeout as an ACP error.
        await self._client.stop()
        raise ACPClientError("Initialization timed out")
    except Exception:
        # Any other failure: stop the client and re-raise unchanged.
        await self._client.stop()
        raise
def _handle_notification(self, method: str, params: dict) -> None:
    """Turn a ``session/update`` notification into an UpdatePayload.

    Two shapes are accepted:
      * flat:   {"kind": "agent_message_chunk", "content": "..."}
      * nested: {"update": {"sessionUpdate": "...", "content": {...}}}

    The nested shape is flattened first, gathering tool-call fields from
    the update, its content and any embedded ``toolCall``/``tool_call``/
    ``tool`` dicts. The payload is streamed to the console (always in
    verbose mode, otherwise only for ``tool_call``) and then handed to
    the session tracker.

    Args:
        method: Notification method name.
        params: Notification parameters.
    """
    # Other methods, or updates with no active session, are ignored.
    if method != "session/update" or not self._session:
        return

    if "update" not in params:
        # Flat format: already in the shape UpdatePayload expects.
        payload = UpdatePayload.from_dict(params)
        payload._raw = params
    else:
        # Nested format (Gemini).
        update = params["update"]
        kind = update.get("sessionUpdate", "")
        content_obj = update.get("content")
        flat_params = {"kind": kind, "content": None}
        # Bare "name"/"id" keys are only accepted as tool aliases for
        # tool-call updates.
        allow_name_id = kind in ("tool_call", "tool_call_update")

        def merge(source: dict) -> None:
            self._merge_tool_fields(flat_params, source, allow_name_id=allow_name_id)

        def merge_embedded(container: dict) -> None:
            # Tool details may sit under toolCall/tool_call and/or tool.
            embedded_call = container.get("toolCall") or container.get("tool_call")
            if isinstance(embedded_call, dict):
                merge(embedded_call)
            embedded_tool = container.get("tool")
            if isinstance(embedded_tool, dict):
                merge(embedded_tool)

        if isinstance(content_obj, dict):
            # Text content arrives as an object carrying a "text" key.
            if "text" in content_obj:
                flat_params["content"] = content_obj.get("text", "")
        elif isinstance(content_obj, list):
            for entry in content_obj:
                if isinstance(entry, dict):
                    merge(entry)
                    merge_embedded(entry)
        else:
            flat_params["content"] = str(content_obj) if content_obj else ""

        # Merge order matters: the first non-empty value for a field wins.
        merge(update)
        if isinstance(content_obj, dict):
            merge(content_obj)
            merge_embedded(content_obj)
        merge_embedded(update)

        payload = UpdatePayload.from_dict(flat_params)
        payload._raw = update
        payload._raw_flat = flat_params

    # Stream to console if verbose; tool calls are always shown.
    if self._current_verbose:
        self._stream_update(payload, show_details=True)
    elif payload.kind == "tool_call":
        self._stream_update(payload, show_details=False)

    self._session.process_update(payload)
def _merge_tool_fields(
    self,
    target: dict,
    source: dict,
    *,
    allow_name_id: bool = False,
) -> None:
    """Copy tool-call fields from ``source`` into ``target`` via aliases.

    For every canonical field in ``_TOOL_FIELD_ALIASES`` the first alias
    present in ``source`` with a non-empty value is stored in ``target``
    under the canonical key. Fields already holding a non-empty value in
    ``target`` are left alone.

    Args:
        target: Flat dict being populated (mutated in place).
        source: Dict that may carry tool fields under alias keys.
        allow_name_id: When False, the generic "name"/"id" aliases are
            ignored for toolName/toolCallId.
    """
    name_id_fields = ("toolName", "toolCallId")
    for canonical, aliases in self._TOOL_FIELD_ALIASES.items():
        if canonical in name_id_fields and not allow_name_id:
            aliases = tuple(
                alias for alias in aliases if alias not in ("name", "id")
            )

        # An existing non-empty value wins over anything in source.
        if target.get(canonical) not in (None, ""):
            continue

        for key in aliases:
            if key not in source:
                continue
            value = source[key]
            if canonical == "toolName" and key == "tool":
                # "tool" may be a descriptor dict or a plain name string;
                # anything else is discarded.
                if isinstance(value, dict):
                    value = value.get("name") or value.get("toolName") or value.get("tool_name")
                elif not isinstance(value, str):
                    value = None
            if not (value is None or value == ""):
                target[canonical] = value
                break
def _format_agent_label(self) -> str:
    """Return the agent command, followed by its arguments if any."""
    if self.agent_args:
        return " ".join((self.agent_command, *self.agent_args))
    return self.agent_command
def _format_payload_value(self, value: object, limit: int = 200) -> str:
    """Render ``value`` as text no longer than ``limit`` characters.

    None becomes an empty string. Text longer than ``limit`` is cut to
    ``limit - 3`` characters and suffixed with "...".
    """
    if value is None:
        return ""
    text = str(value)
    return text if len(text) <= limit else text[: limit - 3] + "..."
def _format_payload_error(self, error: object) -> str:
    """Extract a readable error string from ACP error payloads.

    For dict payloads the text is assembled from the first truthy value
    among "message"/"error"/"detail", an optional "code", and -- only
    when no message is present -- "data". Anything else (including a
    dict that yields no parts) is rendered as a whole.

    Args:
        error: Error payload; may be None, a dict or any other object.

    Returns:
        Error text limited to 200 characters, or "" when error is None.
    """
    if error is None:
        return ""
    if isinstance(error, dict):
        message = error.get("message") or error.get("error") or error.get("detail")
        code = error.get("code")
        data = error.get("data")
        parts = []
        if message:
            # The message may be a non-string (e.g. {"error": 500} or a
            # nested dict); stringify it so the join below cannot raise
            # TypeError.
            parts.append(str(message))
        if code is not None:
            parts.append(f"code={code}")
        if data and not message:
            parts.append(str(data))
        if parts:
            return self._format_payload_value(" ".join(parts), limit=200)
    return self._format_payload_value(error, limit=200)
def _get_raw_payload(self, payload: UpdatePayload) -> dict | None:
    """Return the ``_raw`` dict attached to ``payload``, or None.

    None is returned when the attribute is missing or is not a dict.
    """
    candidate = getattr(payload, "_raw", None)
    if isinstance(candidate, dict):
        return candidate
    return None
def _extract_tool_field(self, raw: dict | None, key: str) -> object:
    """Look up ``key`` in a raw update, then in its embedded tool dicts.

    The top level is tried first, followed by the "toolCall",
    "tool_call" and "tool" sub-dicts in that order. None and "" count as
    absent.

    Returns:
        The first non-empty value found, or None.
    """
    if not isinstance(raw, dict):
        return None

    absent = (None, "")
    direct = raw.get(key)
    if direct not in absent:
        return direct

    for container_key in ("toolCall", "tool_call", "tool"):
        container = raw.get(container_key)
        if not isinstance(container, dict) or key not in container:
            continue
        candidate = container.get(key)
        if candidate not in absent:
            return candidate
    return None
def _extract_tool_name_from_meta(self, raw: dict | None) -> str | None:
    """Find a tool name inside the ``_meta`` section of a raw update.

    The agent-specific sub-dicts ("codex", "claudeCode", "agent", "acp")
    are searched first, in that order, then ``_meta`` itself. Within each
    the keys "toolName", "tool_name", "name", "tool" are tried in order.

    Returns:
        The first non-empty string found, or None.
    """
    if not isinstance(raw, dict):
        return None
    meta = raw.get("_meta")
    if not isinstance(meta, dict):
        return None

    name_keys = ("toolName", "tool_name", "name", "tool")
    # Agent-specific sections take priority over the top level of _meta.
    sections = [
        section
        for section in (meta.get(k) for k in ("codex", "claudeCode", "agent", "acp"))
        if isinstance(section, dict)
    ]
    sections.append(meta)

    for section in sections:
        for name_key in name_keys:
            candidate = section.get(name_key)
            if isinstance(candidate, str) and candidate:
                return candidate
    return None
def _extract_tool_response(self, raw: dict | None) -> object:
    """Return the ``toolResponse`` stored in a raw update's ``_meta``.

    The agent-specific sub-dicts ("codex", "claudeCode", "agent", "acp")
    are checked first, in that order, then ``_meta`` itself. The value is
    returned as soon as the key is present, whatever it holds.

    Returns:
        The stored tool response, or None when no section has the key.
    """
    if not isinstance(raw, dict):
        return None
    meta = raw.get("_meta")
    if not isinstance(meta, dict):
        return None

    sections = [
        section
        for section in (meta.get(k) for k in ("codex", "claudeCode", "agent", "acp"))
        if isinstance(section, dict)
    ]
    sections.append(meta)

    for section in sections:
        if "toolResponse" in section:
            return section.get("toolResponse")
    return None
def _stream_update(self, payload: UpdatePayload, show_details: bool = True) -> None:
    """Stream session update to console.

    Dispatches on ``payload.kind``:
      * ``agent_message_chunk``: print the message text.
      * ``agent_thought_chunk``: print the thought text dimmed.
      * ``tool_call``: print a tool-call banner (name, id, agent).
      * ``tool_call_update``: print the status of a tool call.
    Any other kind prints nothing.

    Args:
        payload: The update payload to stream.
        show_details: Include detailed info (arguments, results, progress).
            When False, ``tool_call_update`` payloads print nothing.
    """
    kind = payload.kind
    if kind == "agent_message_chunk":
        # Stream agent output text
        if payload.content:
            self._console.print_message(payload.content)
    elif kind == "agent_thought_chunk":
        # Stream agent internal reasoning (dimmed)
        if payload.content:
            # Use the console's rich-style printer when it is available,
            # otherwise fall back to plain print().
            if self._console.console:
                self._console.console.print(
                    f"[dim italic]{payload.content}[/dim italic]",
                    end="",
                )
            else:
                print(payload.content, end="")
    elif kind == "tool_call":
        # Show tool call start
        tool_name = payload.tool_name
        raw_update = self._get_raw_payload(payload)
        meta_tool_name = self._extract_tool_name_from_meta(raw_update)
        title = self._extract_tool_field(raw_update, "title")
        # From here on `kind` is rebound to the tool's own "kind" field
        # from the raw update, not the payload kind dispatched on above.
        kind = self._extract_tool_field(raw_update, "kind")
        raw_input = self._extract_tool_field(raw_update, "rawInput")
        if raw_input is None:
            raw_input = self._extract_tool_field(raw_update, "input")
        # Name fallback chain: payload, _meta, title, tool kind, "unknown".
        tool_name = tool_name or meta_tool_name or title or kind or "unknown"
        tool_id = payload.tool_call_id or "unknown"
        self._console.print_separator()
        self._console.print_status(f"TOOL CALL: {tool_name}", style="cyan bold")
        # Only the first 12 characters of the id are displayed.
        self._console.print_info(f"ID: {tool_id[:12]}...")
        self._console.print_info(f"Agent: {self._format_agent_label()}")
        if show_details:
            if title and title != tool_name:
                self._console.print_info(f"Title: {title}")
            if kind:
                self._console.print_info(f"Kind: {kind}")
            if tool_name == "unknown":
                # No name could be resolved: dump the (truncated) raw update.
                raw_str = self._format_payload_value(raw_update, limit=300)
                if raw_str:
                    self._console.print_info(f"Update: {raw_str}")
            if payload.arguments or raw_input:
                # Parsed arguments take precedence over the raw input.
                input_value = payload.arguments or raw_input
                if isinstance(input_value, dict):
                    self._console.print_info("Arguments:")
                    for key, value in input_value.items():
                        value_str = str(value)
                        # Each argument value is capped at 100 characters.
                        if len(value_str) > 100:
                            value_str = value_str[:97] + "..."
                        self._console.print_info(f" - {key}: {value_str}")
                else:
                    input_str = self._format_payload_value(input_value, limit=300)
                    if input_str:
                        self._console.print_info(f"Input: {input_str}")
    elif kind == "tool_call_update":
        if not show_details:
            return
        # Show tool call status update
        tool_id = payload.tool_call_id or "unknown"
        status = payload.status or "unknown"
        tool_name = payload.tool_name
        tool_args = None
        tool_call = None
        # Look up the tool call tracked by the session to fill in the
        # name/arguments that the update itself may not repeat.
        if self._session and payload.tool_call_id:
            tool_call = self._session.get_tool_call(payload.tool_call_id)
            if tool_call:
                tool_name = tool_name or tool_call.tool_name
                tool_args = tool_call.arguments or None
        raw_update = self._get_raw_payload(payload)
        meta_tool_name = self._extract_tool_name_from_meta(raw_update)
        title = self._extract_tool_field(raw_update, "title")
        # As in the tool_call branch, `kind` now holds the tool's "kind" field.
        kind = self._extract_tool_field(raw_update, "kind")
        raw_input = self._extract_tool_field(raw_update, "rawInput")
        raw_output = self._extract_tool_field(raw_update, "rawOutput")
        tool_name = tool_name or meta_tool_name or title
        display_name = tool_name or kind or "unknown"
        if display_name == "unknown":
            status_label = f"Tool call {tool_id[:12]}..."
        else:
            status_label = f"Tool {display_name} ({tool_id[:12]}...)"
        if status == "completed":
            self._console.print_success(
                f"{status_label} completed"
            )
            # Result fallback chain: payload, tracked tool call, raw
            # output, then the toolResponse stored under _meta.
            result_value = payload.result
            if result_value is None and tool_call:
                result_value = tool_call.result
            if result_value is None:
                result_value = raw_output
            if result_value is None:
                result_value = self._extract_tool_response(raw_update)
            result_str = self._format_payload_value(result_value)
            if result_str:
                self._console.print_info(f"Result: {result_str}")
        elif status == "failed":
            self._console.print_error(
                f"{status_label} failed"
            )
            if display_name == "unknown":
                raw_str = self._format_payload_value(raw_update, limit=300)
                if raw_str:
                    self._console.print_info(f"Update: {raw_str}")
            # Error text fallback chain: payload.error, payload.result,
            # raw output, then the toolResponse stored under _meta.
            error_str = self._format_payload_error(payload.error)
            if not error_str and payload.result is not None:
                error_str = self._format_payload_value(payload.result)
            if not error_str and raw_output is not None:
                error_str = self._format_payload_value(raw_output)
            if not error_str:
                error_str = self._format_payload_value(
                    self._extract_tool_response(raw_update)
                )
            if error_str:
                self._console.print_error(f"Error: {error_str}")
            # Show what the failed call was invoked with.
            if tool_args or raw_input:
                if tool_args is None:
                    tool_args = raw_input
                self._console.print_info("Arguments:")
                if isinstance(tool_args, dict):
                    for key, value in tool_args.items():
                        value_str = str(value)
                        if len(value_str) > 100:
                            value_str = value_str[:97] + "..."
                        self._console.print_info(f" - {key}: {value_str}")
                else:
                    arg_str = self._format_payload_value(tool_args, limit=300)
                    if arg_str:
                        self._console.print_info(f" - {arg_str}")
        elif status == "running":
            self._console.print_status(
                f"{status_label} running",
                style="yellow",
            )
            progress_value = payload.result or payload.content
            if progress_value is None:
                progress_value = raw_output
            progress_str = self._format_payload_value(progress_value, limit=200)
            if progress_str:
                self._console.print_info(f"Progress: {progress_str}")
        else:
            # Any other status string is echoed verbatim.
            self._console.print_status(
                f"{status_label} {status}",
                style="yellow",
            )
        # NOTE(review): the nesting of the two prints below could not be
        # confirmed; they are assumed to run for every status rather than
        # only in the final else branch -- confirm.
        if title and title != display_name:
            self._console.print_info(f"Title: {title}")
        if kind:
            self._console.print_info(f"Kind: {kind}")
def _handle_request(self, method: str, params: dict) -> dict:
    """Route a request coming from the agent to its handler.

    Supported methods:
      - session/request_permission
      - fs/read_text_file, fs/write_text_file
      - terminal/create, terminal/output, terminal/wait_for_exit,
        terminal/kill, terminal/release

    Args:
        method: Request method name.
        params: Request parameters.

    Returns:
        The handler's result dict, or an ``error`` dict with code -32601
        for an unrecognised method.
    """
    logger.info("ACP REQUEST: method=%s", method)

    # Permission requests go through the adapter's own wrapper.
    if method == "session/request_permission":
        return self._handle_permission_request(params)

    # File and terminal operations map directly onto ACPHandlers methods;
    # their raw result is returned as-is.
    handler_names = {
        "fs/read_text_file": "handle_read_file",
        "fs/write_text_file": "handle_write_file",
        "terminal/create": "handle_terminal_create",
        "terminal/output": "handle_terminal_output",
        "terminal/wait_for_exit": "handle_terminal_wait_for_exit",
        "terminal/kill": "handle_terminal_kill",
        "terminal/release": "handle_terminal_release",
    }
    handler_name = handler_names.get(method) if isinstance(method, str) else None
    if handler_name is not None:
        return getattr(self._handlers, handler_name)(params)

    # Unknown request - log and return error
    logger.warning("Unknown ACP request method: %s with params: %s", method, params)
    return {"error": {"code": -32601, "message": f"Method not found: {method}"}}
def _handle_permission_request(self, params: dict) -> dict:
    """Answer a permission request by delegating to ACPHandlers.

    Args:
        params: Permission request parameters.

    Returns:
        Whatever ``ACPHandlers.handle_request_permission`` returns.
    """
    decision = self._handlers.handle_request_permission(params)
    return decision
def _log_permission(self, message: str) -> None:
    """Write a permission decision to the module logger at INFO level.

    Passed to ACPHandlers as its ``on_permission_log`` callback.

    Args:
        message: Text describing the decision.
    """
    logger.info(message)
def get_permission_history(self) -> list:
    """Return the permission history kept by the ACP handlers.

    Returns:
        The value of ``ACPHandlers.get_history()``.
    """
    history = self._handlers.get_history()
    return history
def get_permission_stats(self) -> dict:
    """Summarise how many permission requests were approved and denied.

    Returns:
        Dict with the keys ``approved_count`` and ``denied_count``.
    """
    approved = self._handlers.get_approved_count()
    denied = self._handlers.get_denied_count()
    return {"approved_count": approved, "denied_count": denied}
async def _execute_prompt(self, prompt: str, **kwargs) -> ToolResponse:
    """Send one ``session/prompt`` request and package the outcome.

    Session updates that arrive while the request is in flight are
    processed by ``_handle_notification``; the output they accumulate on
    the session is what ends up in the returned ToolResponse.

    Args:
        prompt: The prompt text to send.
        **kwargs: Optional ``verbose`` flag overriding ``self.verbose``
            for this call.

    Returns:
        ToolResponse marked successful unless the agent reports the
        "error" stop reason or the request times out.
    """
    # Per-call verbosity; also stored on the instance so the notification
    # handler sees it while this request is running.
    verbose = kwargs.get("verbose", self.verbose)
    self._current_verbose = verbose

    # Start from a clean session tracker for this prompt.
    if self._session:
        self._session.reset()

    if verbose:
        self._console.print_header(f"ACP AGENT ({self.agent_command})")
        self._console.print_status("Processing prompt...")

    def collected_output():
        return self._session.output if self._session else ""

    def base_metadata() -> dict:
        return {
            "tool": "acp",
            "agent": self.agent_command,
            "session_id": self._session_id,
        }

    # The prompt is sent as a single text content block.
    prompt_blocks = [{"type": "text", "text": prompt}]

    try:
        request_params = {
            "sessionId": self._session_id,
            "prompt": prompt_blocks,
        }
        response = await asyncio.wait_for(
            self._client.send_request("session/prompt", request_params),
            timeout=self.timeout,
        )

        stop_reason = response.get("stopReason", "unknown")

        if stop_reason == "error":
            error_msg = response.get("error", {}).get("message", "Unknown error from agent")
            if verbose:
                self._console.print_separator()
                self._console.print_error(f"Agent error: {error_msg}")
            return ToolResponse(
                success=False,
                output=collected_output(),
                error=error_msg,
                metadata={**base_metadata(), "stop_reason": stop_reason},
            )

        # Normal completion.
        output = collected_output()
        if verbose:
            self._console.print_separator()
            tool_count = len(self._session.tool_calls) if self._session else 0
            self._console.print_success(f"Agent completed (tools: {tool_count})")
        return ToolResponse(
            success=True,
            output=output,
            metadata={
                **base_metadata(),
                "stop_reason": stop_reason,
                "tool_calls_count": len(self._session.tool_calls) if self._session else 0,
                "has_thoughts": bool(self._session.thoughts) if self._session else False,
            },
        )
    except asyncio.TimeoutError:
        if verbose:
            self._console.print_separator()
            self._console.print_error(f"Timeout after {self.timeout}s")
        # Whatever output was streamed before the timeout is still returned.
        return ToolResponse(
            success=False,
            output=collected_output(),
            error=f"Prompt execution timed out after {self.timeout} seconds",
            metadata=base_metadata(),
        )
async def _shutdown(self) -> None:
    """Tear down the ACP connection and reset adapter state.

    Kills every terminal tracked by the handlers (a failure is logged and
    does not stop the loop), stops the client if there is one, then
    clears the initialized flag and the session fields.
    """
    handlers = self._handlers
    if handlers:
        # Snapshot the ids before iterating.
        for terminal_id in tuple(handlers._terminals.keys()):
            try:
                handlers.handle_terminal_kill({"terminalId": terminal_id})
            except Exception as e:
                logger.warning("Failed to kill terminal %s: %s", terminal_id, e)

    if self._client:
        await self._client.stop()
        self._client = None

    self._initialized = False
    self._session_id = None
    self._session = None
def execute(self, prompt: str, **kwargs) -> ToolResponse:
    """Run a prompt to completion from synchronous code.

    Drives ``aexecute`` with ``asyncio.run``. Any exception is turned
    into a failed ToolResponse.

    Args:
        prompt: The prompt to execute.
        **kwargs: Forwarded to ``aexecute``.

    Returns:
        ToolResponse with the execution result.
    """
    if self.available:
        try:
            return asyncio.run(self.aexecute(prompt, **kwargs))
        except Exception as e:
            return ToolResponse(
                success=False,
                output="",
                error=str(e),
            )

    return ToolResponse(
        success=False,
        output="",
        error=f"ACP adapter not available: {self.agent_command} not found",
    )
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
    """Run a prompt through the agent, initializing on first use.

    Args:
        prompt: The prompt to execute.
        **kwargs: Forwarded to ``_execute_prompt``.

    Returns:
        ToolResponse with the execution result. Failures are reported in
        the response; ACPClientError messages are prefixed "ACP error: ".
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error=f"ACP adapter not available: {self.agent_command} not found",
        )

    try:
        # Perform the handshake on the first call.
        if not self._initialized:
            await self._initialize()

        # Add the orchestration instructions before sending.
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
        return await self._execute_prompt(enhanced_prompt, **kwargs)
    except Exception as e:
        failure_text = f"ACP error: {e}" if isinstance(e, ACPClientError) else str(e)
        return ToolResponse(
            success=False,
            output="",
            error=failure_text,
        )
def estimate_cost(self, prompt: str) -> float:
    """Return the estimated cost of running ``prompt``.

    ACP exposes no billing information, so the estimate is always zero.

    Args:
        prompt: The prompt to estimate (not inspected).

    Returns:
        0.0
    """
    no_cost = 0.0
    return no_cost
def __del__(self) -> None:
    """Best-effort cleanup when the adapter is garbage-collected.

    Restores the signal handlers saved at registration and kills the
    agent subprocess if a client is present. Tolerates instances whose
    ``__init__`` raised before all attributes were assigned.
    """
    try:
        self._restore_signal_handlers()
    except AttributeError:
        # __init__ failed before the saved-handler attributes were set,
        # so there is nothing to restore. Ignored deliberately so the
        # finalizer does not raise.
        pass

    # _client may be missing on a partially constructed instance.
    client = getattr(self, "_client", None)
    if client:
        try:
            self.kill_subprocess_sync()
        except Exception as e:
            logger.debug("Exception during cleanup in __del__: %s", e)

View File

@@ -0,0 +1,352 @@
# ABOUTME: ACPClient manages subprocess lifecycle for ACP agents
# ABOUTME: Handles async message routing, request tracking, and graceful shutdown
"""ACPClient subprocess manager for ACP (Agent Client Protocol)."""
import asyncio
import logging
import os
from typing import Any, Callable, Optional
from .acp_protocol import ACPProtocol, MessageType
logger = logging.getLogger(__name__)
# Default asyncio StreamReader line limit (in bytes) for ACP agent stdout/stderr.
# Some ACP agents can emit large single-line JSON-RPC frames (e.g., big tool payloads).
# The asyncio default (~64KiB) can raise LimitOverrunError and break the ACP session.
DEFAULT_ACP_STREAM_LIMIT = 8 * 1024 * 1024 # 8 MiB
class ACPClientError(Exception):
    """Error type raised when an ACPClient operation fails."""
class ACPClient:
"""Manages subprocess lifecycle and async message routing for ACP agents.
Spawns an agent subprocess, handles JSON-RPC message serialization,
routes responses to pending requests, and invokes callbacks for
notifications and incoming requests.
Attributes:
command: The command to spawn the agent.
args: Additional command-line arguments.
timeout: Request timeout in seconds.
"""
def __init__(
    self,
    command: str,
    args: Optional[list[str]] = None,
    timeout: int = 300,
    stream_limit: Optional[int] = None,
) -> None:
    """Initialize ACPClient.

    Args:
        command: The command to spawn the agent (e.g., "gemini").
        args: Additional command-line arguments.
        timeout: Request timeout in seconds (default: 300).
        stream_limit: Line limit in bytes for the subprocess stream readers.
            When None, the RALPH_ACP_STREAM_LIMIT environment variable is
            used if it holds a positive integer; otherwise
            DEFAULT_ACP_STREAM_LIMIT applies.
    """
    self.command = command
    self.args = args or []
    self.timeout = timeout

    # Limit used by asyncio for readline()/readuntil(); must be > max ACP frame line length.
    env_limit = os.environ.get("RALPH_ACP_STREAM_LIMIT")
    if stream_limit is not None:
        self.stream_limit = stream_limit
    elif env_limit:
        try:
            parsed_limit = int(env_limit)
            # asyncio rejects non-positive limits when the subprocess is
            # spawned; treat them like any other invalid value here.
            if parsed_limit <= 0:
                raise ValueError("stream limit must be positive")
        except ValueError:
            logger.warning(
                "Invalid RALPH_ACP_STREAM_LIMIT=%r; using default %d",
                env_limit,
                DEFAULT_ACP_STREAM_LIMIT,
            )
            parsed_limit = DEFAULT_ACP_STREAM_LIMIT
        self.stream_limit = parsed_limit
    else:
        self.stream_limit = DEFAULT_ACP_STREAM_LIMIT

    self._protocol = ACPProtocol()
    self._process: Optional[asyncio.subprocess.Process] = None
    self._read_task: Optional[asyncio.Task] = None
    self._write_lock = asyncio.Lock()

    # Pending requests: id -> Future
    self._pending_requests: dict[int, asyncio.Future] = {}

    # Notification handlers
    self._notification_handlers: list[Callable[[str, dict], None]] = []

    # Request handlers (for incoming requests from agent)
    self._request_handlers: list[Callable[[str, dict], Any]] = []
@property
def is_running(self) -> bool:
"""Check if subprocess is running.
Returns:
True if subprocess is running, False otherwise.
"""
return self._process is not None and self._process.returncode is None
async def start(self) -> None:
    """Spawn the agent subprocess and begin reading from it.

    The child is launched with piped stdin/stdout/stderr, using
    ``self.stream_limit`` as the stream limit, and a background task
    running ``_read_loop`` is created.

    Raises:
        RuntimeError: If already running.
        FileNotFoundError: If command not found.
    """
    if self.is_running:
        raise RuntimeError("ACPClient is already running")

    pipe = asyncio.subprocess.PIPE
    argv = [self.command] + self.args
    self._process = await asyncio.create_subprocess_exec(
        *argv,
        stdin=pipe,
        stdout=pipe,
        stderr=pipe,
        limit=self.stream_limit,
    )

    # Route incoming messages in the background from now on.
    self._read_task = asyncio.create_task(self._read_loop())
async def stop(self) -> None:
    """Shut the agent subprocess down.

    Order of operations: cancel the read loop (waiting up to 0.5s),
    terminate the process (waiting up to 2s, then kill and wait 0.5s more),
    drop the process/task references, and cancel every pending request.
    Does nothing when the subprocess is not running.
    """
    if not self.is_running:
        return

    # Cancel read loop first
    reader = self._read_task
    if reader and not reader.done():
        reader.cancel()
        try:
            await asyncio.wait_for(reader, timeout=0.5)
        except asyncio.CancelledError:
            logger.debug("Read task cancelled during shutdown")
        except asyncio.TimeoutError:
            logger.warning("Read task cancellation timed out")

    # Terminate subprocess with 2 second timeout
    process = self._process
    if process:
        try:
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=2.0)
            except asyncio.TimeoutError:
                # Force kill if graceful termination fails
                logger.warning("Process did not terminate gracefully, killing")
                process.kill()
                try:
                    await asyncio.wait_for(process.wait(), timeout=0.5)
                except asyncio.TimeoutError:
                    logger.error("Process did not die after kill signal")
        except ProcessLookupError:
            logger.debug("Process already terminated")

    self._process = None
    self._read_task = None

    # Cancel all pending requests
    for pending in self._pending_requests.values():
        if pending.done():
            continue
        pending.cancel()
    self._pending_requests.clear()
async def _read_loop(self) -> None:
    """Continuously read stdout and route messages.

    Reads newline-delimited JSON-RPC messages from subprocess stdout and
    passes each non-empty line to ``_handle_message``. A failure while
    decoding or handling a single line is logged and that line is skipped,
    so one bad message does not tear down the whole session. Failures of
    the stream itself (``readline``) still end the loop.

    When the loop exits for any reason, every pending request is failed
    with ``ACPClientError``.
    """
    if not self._process or not self._process.stdout:
        return

    try:
        while self.is_running:
            line = await self._process.stdout.readline()
            if not line:
                break  # EOF: the subprocess closed stdout

            try:
                message_str = line.decode().strip()
                if message_str:
                    await self._handle_message(message_str)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # Isolate per-message failures (bad encoding, handler or
                # write errors) instead of killing the read loop.
                logger.error(
                    "Failed to process ACP message (skipped): %s", e, exc_info=True
                )
    except asyncio.CancelledError:
        pass  # Expected during shutdown
    except Exception as e:
        logger.error("ACP read loop failed: %s", e, exc_info=True)
    finally:
        # Fail all pending requests when read loop exits (subprocess died or cancelled)
        for future in self._pending_requests.values():
            if not future.done():
                future.set_exception(ACPClientError("Agent subprocess terminated"))
        self._pending_requests.clear()
async def _handle_message(self, message_str: str) -> None:
    """Handle a received JSON-RPC message.

    Routes the parsed message by type:
    - RESPONSE / ERROR: resolve or fail the matching pending request.
    - NOTIFICATION: invoke every notification handler (failures are logged).
    - REQUEST: invoke the first registered request handler and write its
      result (or an error) back to the agent. If no handler is registered,
      a JSON-RPC "Method not found" (-32601) error is returned so the agent
      is never left waiting for a response.

    Args:
        message_str: Raw JSON string.
    """
    parsed = self._protocol.parse_message(message_str)
    msg_type = parsed.get("type")

    if msg_type == MessageType.RESPONSE:
        # Route to pending request
        future = self._pending_requests.pop(parsed.get("id"), None)
        if future is not None and not future.done():
            future.set_result(parsed.get("result"))

    elif msg_type == MessageType.ERROR:
        # Route error to pending request
        future = self._pending_requests.pop(parsed.get("id"), None)
        if future is not None and not future.done():
            error = parsed.get("error")
            if isinstance(error, dict):
                error_msg = error.get("message", "Unknown error")
            else:
                # Tolerate a malformed (non-object) error payload.
                error_msg = str(error) if error else "Unknown error"
            future.set_exception(ACPClientError(error_msg))

    elif msg_type == MessageType.NOTIFICATION:
        # Invoke notification handlers
        method = parsed.get("method", "")
        params = parsed.get("params", {})
        for handler in self._notification_handlers:
            try:
                handler(method, params)
            except Exception as e:
                logger.error("Notification handler failed for method=%s: %s", method, e, exc_info=True)

    elif msg_type == MessageType.REQUEST:
        # Invoke request handler and send response
        request_id = parsed.get("id")
        method = parsed.get("method", "")
        params = parsed.get("params", {})

        if not self._request_handlers:
            response = self._protocol.create_error_response(
                request_id, -32601, f"Method not found: {method}"
            )
        else:
            handler = self._request_handlers[0]  # Only first handler responds
            try:
                result = handler(method, params)
                # Await whatever awaitable the handler produced; this also
                # covers callables that are not coroutine functions.
                if asyncio.iscoroutine(result) or asyncio.isfuture(result):
                    result = await result

                # Check if handler returned an error (dict with "error" key)
                if isinstance(result, dict) and "error" in result:
                    error_info = result["error"]
                    if isinstance(error_info, dict):
                        error_code = error_info.get("code", -32603)
                        error_msg = error_info.get("message", str(error_info))
                    else:
                        error_code = -32603
                        error_msg = str(error_info)
                    response = self._protocol.create_error_response(
                        request_id, error_code, error_msg
                    )
                else:
                    response = self._protocol.create_response(request_id, result)
            except Exception as e:
                # Handler (or response construction) failed: report internal error
                response = self._protocol.create_error_response(
                    request_id, -32603, str(e)
                )

        await self._write_message(response)
async def _write_message(self, message: str) -> None:
"""Write a JSON-RPC message to subprocess stdin.
Args:
message: JSON string to write.
Raises:
RuntimeError: If not running.
"""
if not self.is_running or not self._process or not self._process.stdin:
raise RuntimeError("ACPClient is not running")
async with self._write_lock:
self._process.stdin.write((message + "\n").encode())
await self._process.stdin.drain()
async def _do_send(self, request_id: int, message: str) -> None:
    """Write a request and fail its pending future if the write fails.

    Args:
        request_id: The request ID.
        message: The JSON-RPC message to send.
    """
    try:
        await self._write_message(message)
    except Exception as e:
        pending = self._pending_requests.pop(request_id, None)
        if not pending or pending.done():
            return
        pending.set_exception(ACPClientError(f"Failed to send request: {e}"))
def send_request(
    self, method: str, params: dict[str, Any]
) -> asyncio.Future[Any]:
    """Send a JSON-RPC request and return Future for response.

    Must be called while an event loop is running. The write happens in a
    background task; write failures are reported through the returned
    future.

    Args:
        method: The RPC method name.
        params: The request parameters.

    Returns:
        Future that resolves with the response result.
    """
    request_id, message = self._protocol.create_request(method, params)

    # Create future for response
    loop = asyncio.get_running_loop()
    future: asyncio.Future[Any] = loop.create_future()
    self._pending_requests[request_id] = future

    # The event loop keeps only weak references to tasks, so hold a strong
    # reference until the send finishes; otherwise the task may be
    # garbage-collected before the message is written.
    send_tasks = getattr(self, "_send_tasks", None)
    if send_tasks is None:
        send_tasks = self._send_tasks = set()

    # Schedule write with error handling
    task = asyncio.create_task(self._do_send(request_id, message))
    send_tasks.add(task)
    task.add_done_callback(send_tasks.discard)

    return future
async def send_notification(
    self, method: str, params: dict[str, Any]
) -> None:
    """Build and write a JSON-RPC notification (no response expected).

    Args:
        method: The notification method name.
        params: The notification parameters.
    """
    await self._write_message(
        self._protocol.create_notification(method, params)
    )
def on_notification(
    self, handler: Callable[[str, dict], None]
) -> None:
    """Add ``handler`` to the list of notification handlers.

    Args:
        handler: Callback invoked with (method, params) for each notification.
    """
    handlers = self._notification_handlers
    handlers.append(handler)
def on_request(
    self, handler: Callable[[str, dict], Any]
) -> None:
    """Add ``handler`` to the list of handlers for requests sent by the agent.

    The handler's return value is used as the response result. It may be
    sync or async.

    Args:
        handler: Callback invoked with (method, params), returns result.
    """
    handlers = self._request_handlers
    handlers.append(handler)

View File

@@ -0,0 +1,889 @@
# ABOUTME: ACP handlers for permission requests and file/terminal operations
# ABOUTME: Provides permission_mode handling (auto_approve, deny_all, allowlist, interactive)
# ABOUTME: Implements fs/read_text_file and fs/write_text_file handlers with security
# ABOUTME: Implements terminal/* handlers for command execution
"""ACP Handlers for permission requests and agent-to-host operations.
This module provides the ACPHandlers class which manages permission requests
from ACP-compliant agents and handles file operations. It supports:
Permission modes:
- auto_approve: Approve all requests automatically
- deny_all: Deny all requests
- allowlist: Only approve requests matching configured patterns
- interactive: Prompt user for each request (requires terminal)
File operations:
- fs/read_text_file: Read file content with security validation
- fs/write_text_file: Write file content with security validation
Terminal operations:
- terminal/create: Create a new terminal with command
- terminal/output: Read output from a terminal
- terminal/wait_for_exit: Wait for terminal process to exit
- terminal/kill: Kill a terminal process
- terminal/release: Release terminal resources
"""
import fnmatch
import logging
import re
import subprocess
import sys
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
logger = logging.getLogger(__name__)
@dataclass
class Terminal:
    """Represents a terminal subprocess.

    Attributes:
        id: Unique identifier for the terminal.
        process: The subprocess.Popen instance.
        output_buffer: Accumulated output from stdout/stderr.
    """

    id: str
    process: subprocess.Popen
    output_buffer: str = ""
    # Per-pipe incremental decoders (keyed by file descriptor) so multi-byte
    # characters split across two reads are decoded correctly.
    _decoders: dict = field(default_factory=dict, init=False, repr=False, compare=False)

    @property
    def is_running(self) -> bool:
        """Check if the process is still running."""
        return self.process.poll() is None

    @property
    def exit_code(self) -> Optional[int]:
        """Get the exit code if process has exited."""
        return self.process.poll()

    def read_output(self) -> str:
        """Read any available output without blocking.

        Bytes are read straight from the pipe file descriptors with
        ``os.read`` after ``select`` reports readiness. Reading through the
        text wrapper (``stream.read(4096)``) would keep reading until 4096
        characters or EOF had arrived and could therefore block on a
        long-running process that has produced only a little output.

        Returns:
            New output since last read (stdout first, then stderr).
        """
        import codecs
        import os
        import select

        pieces: list[str] = []

        for stream in (self.process.stdout, self.process.stderr):
            if stream is None:
                continue
            try:
                fd = stream.fileno()
                decoder = self._decoders.get(fd)
                if decoder is None:
                    encoding = getattr(stream, "encoding", None) or "utf-8"
                    decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
                    self._decoders[fd] = decoder

                while True:
                    ready, _, _ = select.select([fd], [], [], 0)
                    if not ready:
                        break
                    data = os.read(fd, 4096)
                    if not data:
                        # EOF: flush any partial character held by the decoder.
                        tail = decoder.decode(b"", final=True)
                        if tail:
                            pieces.append(tail)
                        break
                    pieces.append(decoder.decode(data))
            except (OSError, ValueError) as e:
                # ValueError: the stream is already closed. Keep going so a
                # failure on stdout does not hide stderr output.
                logger.debug("Error reading terminal output: %s", e)
                continue

        new_output = "".join(pieces)
        self.output_buffer += new_output
        return new_output

    def kill(self) -> None:
        """Kill the subprocess.

        Sends terminate first and escalates to kill if the process has not
        exited within one second. Does nothing if it already exited.
        """
        if self.is_running:
            self.process.terminate()
            try:
                self.process.wait(timeout=1.0)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()

    def wait(self, timeout: Optional[float] = None) -> int:
        """Wait for the process to exit.

        Args:
            timeout: Maximum time to wait in seconds.

        Returns:
            Exit code of the process.

        Raises:
            subprocess.TimeoutExpired: If timeout is reached.
        """
        return self.process.wait(timeout=timeout)
@dataclass
class PermissionRequest:
    """A permission request received from an agent.

    Attributes:
        operation: The operation being requested (e.g., 'fs/read_text_file').
        path: Optional path for filesystem operations.
        command: Optional command for terminal operations.
        arguments: Full arguments dict from the request.
    """

    operation: str
    path: Optional[str] = None
    command: Optional[str] = None
    arguments: dict = field(default_factory=dict)

    @classmethod
    def from_params(cls, params: dict) -> "PermissionRequest":
        """Build a PermissionRequest from raw request parameters.

        The 'operation', 'path' and 'command' keys are picked out of
        ``params`` (operation defaults to an empty string), and ``params``
        itself is kept as ``arguments``.

        Args:
            params: Permission request parameters from agent.

        Returns:
            Parsed PermissionRequest instance.
        """
        lookup = params.get
        operation = lookup("operation", "")
        path = lookup("path")
        command = lookup("command")
        return cls(operation=operation, path=path, command=command, arguments=params)
@dataclass
class PermissionResult:
    """Outcome of evaluating a permission request.

    Attributes:
        approved: Whether the request was approved.
        reason: Optional reason for the decision.
        mode: Permission mode that made the decision.
    """

    approved: bool
    reason: Optional[str] = None
    mode: str = "unknown"

    def to_dict(self) -> dict:
        """Serialize the decision; only the 'approved' flag is included.

        Returns:
            Dict with 'approved' key for ACP response.
        """
        payload = {"approved": self.approved}
        return payload
class ACPHandlers:
"""Handles ACP permission requests with configurable modes.
Supports four permission modes:
- auto_approve: Always approve (useful for trusted environments)
- deny_all: Always deny (useful for testing)
- allowlist: Only approve operations matching configured patterns
- interactive: Prompt user for each request
Attributes:
permission_mode: Current permission mode.
allowlist: List of allowed operation patterns.
on_permission_log: Optional callback for logging decisions.
"""
# Valid permission modes
VALID_MODES = ("auto_approve", "deny_all", "allowlist", "interactive")
def __init__(
self,
permission_mode: str = "auto_approve",
permission_allowlist: Optional[list[str]] = None,
on_permission_log: Optional[Callable[[str], None]] = None,
) -> None:
"""Initialize ACPHandlers.
Args:
permission_mode: Permission handling mode (default: auto_approve).
permission_allowlist: List of allowed operation patterns for allowlist mode.
on_permission_log: Optional callback for logging permission decisions.
Raises:
ValueError: If permission_mode is not valid.
"""
if permission_mode not in self.VALID_MODES:
raise ValueError(
f"Invalid permission_mode: {permission_mode}. "
f"Must be one of: {', '.join(self.VALID_MODES)}"
)
self.permission_mode = permission_mode
self.allowlist = permission_allowlist or []
self.on_permission_log = on_permission_log
# Track permission history for debugging
self._history: list[tuple[PermissionRequest, PermissionResult]] = []
# Track active terminals
self._terminals: dict[str, Terminal] = {}
def handle_request_permission(self, params: dict) -> dict:
    """Handle a permission request from an agent.

    Evaluates the request under the current permission mode, logs and
    records the decision, and returns the ACP outcome structure. The agent
    handles tool execution after receiving permission.

    When approved, the selected option is the first "allow" option in
    ``params["options"]``. Both option shapes are recognised: the legacy
    ``{"type": "allow", "id": ...}`` form and the ACP schema form
    ``{"kind": "allow_once" | "allow_always", "optionId": ...}``. If no
    allow option exists, the first option is used unless it is explicitly
    a reject option; with nothing usable, "proceed_once" is returned.

    Args:
        params: Permission request parameters including options list.

    Returns:
        Dict with outcome.outcome ("selected" or "cancelled") and, when
        selected, outcome.optionId.
    """
    request = PermissionRequest.from_params(params)
    result = self._evaluate_permission(request)

    # Log the decision
    self._log_decision(request, result)

    # Store in history
    self._history.append((request, result))

    if not result.approved:
        # Permission denied - return cancelled outcome
        return {"outcome": {"outcome": "cancelled"}}

    # Ignore malformed (non-dict) entries rather than crashing on them.
    options = [opt for opt in (params.get("options") or []) if isinstance(opt, dict)]

    def option_id(option: dict):
        return option.get("id") or option.get("optionId")

    def is_allow(option: dict) -> bool:
        return option.get("type") == "allow" or str(option.get("kind", "")).startswith("allow")

    def is_reject(option: dict) -> bool:
        return str(option.get("kind", "")).startswith("reject")

    # Find first "allow" option to use as optionId
    selected_option_id = next(
        (option_id(opt) for opt in options if is_allow(opt)), None
    )

    # Fallback to the first option, but never approve by picking a reject option
    if not selected_option_id and options and not is_reject(options[0]):
        selected_option_id = option_id(options[0])

    # Default if no usable option was provided
    if not selected_option_id:
        selected_option_id = "proceed_once"

    # Return raw result (client wraps in JSON-RPC response)
    return {
        "outcome": {
            "outcome": "selected",
            "optionId": selected_option_id,
        }
    }
def _evaluate_permission(self, request: PermissionRequest) -> PermissionResult:
    """Decide on a permission request according to ``self.permission_mode``.

    Args:
        request: The permission request to evaluate.

    Returns:
        PermissionResult with decision and reason.
    """
    mode = self.permission_mode

    # Modes whose verdict does not depend on the request at all.
    fixed_verdicts = (
        ("auto_approve", True, "auto_approve mode"),
        ("deny_all", False, "deny_all mode"),
    )
    for name, approved, reason in fixed_verdicts:
        if mode == name:
            return PermissionResult(approved=approved, reason=reason, mode=name)

    if mode == "allowlist":
        return self._evaluate_allowlist(request)
    if mode == "interactive":
        return self._evaluate_interactive(request)

    # Fallback - should not reach here
    return PermissionResult(approved=False, reason="unknown mode", mode="unknown")
def _evaluate_allowlist(self, request: PermissionRequest) -> PermissionResult:
    """Approve the request only if its operation matches an allowlist entry.

    Patterns can be:
    - Exact match: 'fs/read_text_file'
    - Glob pattern: 'fs/*' (matches any fs operation)
    - Regex pattern: '/^terminal\\/.*$/' (surrounded by slashes)

    Args:
        request: The permission request to evaluate.

    Returns:
        PermissionResult with decision.
    """
    operation = request.operation
    first_match = next(
        (
            candidate
            for candidate in self.allowlist
            if self._matches_pattern(operation, candidate)
        ),
        None,
    )

    if first_match is None:
        return PermissionResult(
            approved=False,
            reason="no matching allowlist pattern",
            mode="allowlist",
        )
    return PermissionResult(
        approved=True,
        reason=f"matches allowlist pattern: {first_match}",
        mode="allowlist",
    )
def _matches_pattern(self, operation: str, pattern: str) -> bool:
"""Check if an operation matches a pattern.
Args:
operation: The operation name to check.
pattern: Pattern to match against.
Returns:
True if operation matches pattern.
"""
# Check for regex pattern (surrounded by slashes)
if pattern.startswith("/") and pattern.endswith("/"):
try:
regex_pattern = pattern[1:-1]
return bool(re.match(regex_pattern, operation))
except re.error as e:
logger.warning("Invalid regex pattern '%s' in permission allowlist: %s", pattern, e)
return False
# Check for glob pattern
if "*" in pattern or "?" in pattern:
return fnmatch.fnmatch(operation, pattern)
# Exact match
return operation == pattern
def _evaluate_interactive(self, request: PermissionRequest) -> PermissionResult:
    """Ask the user on the terminal whether to approve the request.

    The request is denied when stdin is not a TTY, when input is
    interrupted (EOF or Ctrl-C), or when the answer is anything other
    than "y"/"yes".

    Args:
        request: The permission request to evaluate.

    Returns:
        PermissionResult with user's decision.
    """

    def denied(reason: str) -> PermissionResult:
        return PermissionResult(approved=False, reason=reason, mode="interactive")

    # Check if we have a terminal
    if not sys.stdin.isatty():
        return denied("no terminal available for interactive mode")

    # Format the prompt
    prompt = self._format_interactive_prompt(request)

    try:
        print(prompt, file=sys.stderr)
        answer = input("[y/N]: ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        return denied("input interrupted")

    if answer in ("y", "yes"):
        return PermissionResult(
            approved=True,
            reason="user approved",
            mode="interactive",
        )
    return denied("user denied")
def _format_interactive_prompt(self, request: PermissionRequest) -> str:
"""Format an interactive permission prompt.
Args:
request: The permission request to display.
Returns:
Formatted prompt string.
"""
lines = [
"",
"=" * 60,
f"Permission Request: {request.operation}",
"=" * 60,
]
if request.path:
lines.append(f" Path: {request.path}")
if request.command:
lines.append(f" Command: {request.command}")
# Add other arguments
for key, value in request.arguments.items():
if key not in ("operation", "path", "command"):
lines.append(f" {key}: {value}")
lines.extend([
"=" * 60,
"Approve this operation?",
])
return "\n".join(lines)
def _log_decision(
self, request: PermissionRequest, result: PermissionResult
) -> None:
"""Log a permission decision.
Args:
request: The permission request.
result: The permission decision.
"""
if self.on_permission_log:
status = "APPROVED" if result.approved else "DENIED"
message = (
f"Permission {status}: {request.operation} "
f"[mode={result.mode}, reason={result.reason}]"
)
self.on_permission_log(message)
def get_history(self) -> list[tuple[PermissionRequest, PermissionResult]]:
    """Return a shallow copy of the recorded permission decisions.

    Returns:
        List of (request, result) tuples.
    """
    return list(self._history)
def clear_history(self) -> None:
    """Remove every recorded permission decision (in place)."""
    del self._history[:]
def get_approved_count(self) -> int:
    """Count the recorded decisions that were approvals.

    Returns:
        Number of approved permission requests.
    """
    approvals = [outcome for _, outcome in self._history if outcome.approved]
    return len(approvals)
def get_denied_count(self) -> int:
    """Count the recorded decisions that were denials.

    Returns:
        Number of denied permission requests.
    """
    denials = [outcome for _, outcome in self._history if not outcome.approved]
    return len(denials)
# =========================================================================
# File Operation Handlers
# =========================================================================
def handle_read_file(self, params: dict) -> dict:
    """Handle fs/read_text_file request from agent.

    Reads a UTF-8 text file. The only path validation performed is that
    the path must be absolute; it is then resolved (symlinks followed).
    NOTE(review): no root-directory restriction is applied here - confirm
    whether one is enforced elsewhere.

    Args:
        params: Request parameters with 'path' key.

    Returns:
        Dict with 'content' on success ('content' None and 'exists' False
        when the file does not exist), or 'error' on failure.
    """

    def failure(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    path_str = params.get("path")
    if not path_str:
        return failure(-32602, "Missing required parameter: path")

    try:
        requested = Path(path_str)

        # Security: require absolute path
        if not requested.is_absolute():
            return failure(-32602, f"Path must be absolute: {path_str}")

        # Resolve symlinks and normalize
        target = requested.resolve()

        # A missing file is not an error: agents may probe for existence.
        if not target.exists():
            return {"content": None, "exists": False}

        # Directories and other non-files cannot be read as text.
        if not target.is_file():
            return failure(-32002, f"Path is not a file: {path_str}")

        return {"content": target.read_text(encoding="utf-8")}
    except PermissionError:
        return failure(-32003, f"Permission denied: {path_str}")
    except UnicodeDecodeError:
        return failure(-32004, f"File is not valid UTF-8 text: {path_str}")
    except OSError as e:
        return failure(-32000, f"Failed to read file: {e}")
def handle_write_file(self, params: dict) -> dict:
    """Handle fs/write_text_file request from agent.

    Writes UTF-8 text to a file, creating parent directories as needed.
    The path must be absolute and must not be an existing directory.
    NOTE(review): no root-directory restriction is applied here - confirm
    whether one is enforced elsewhere.

    Args:
        params: Request parameters with 'path' and 'content' keys.
            'content' must be a string (an empty string is allowed).

    Returns:
        Dict with 'success: True' on success, or 'error' on failure.
    """
    path_str = params.get("path")
    content = params.get("content")

    if not path_str:
        return {"error": {"code": -32602, "message": "Missing required parameter: path"}}

    if content is None:
        return {"error": {"code": -32602, "message": "Missing required parameter: content"}}

    # write_text() raises TypeError for non-str data; report it as an
    # invalid-params error instead of letting the exception escape.
    if not isinstance(content, str):
        return {"error": {"code": -32602, "message": "content must be a string"}}

    try:
        # Resolve the path
        path = Path(path_str)

        # Security: require absolute path
        if not path.is_absolute():
            return {
                "error": {
                    "code": -32602,
                    "message": f"Path must be absolute: {path_str}",
                }
            }

        # Resolve symlinks and normalize
        resolved_path = path.resolve()

        # Check if path exists and is a directory (can't write to directory)
        if resolved_path.exists() and resolved_path.is_dir():
            return {
                "error": {
                    "code": -32002,
                    "message": f"Path is a directory: {path_str}",
                }
            }

        # Create parent directories if needed
        resolved_path.parent.mkdir(parents=True, exist_ok=True)

        # Write file content
        resolved_path.write_text(content, encoding="utf-8")

        return {"success": True}

    except PermissionError:
        return {
            "error": {
                "code": -32003,
                "message": f"Permission denied: {path_str}",
            }
        }
    except OSError as e:
        return {
            "error": {
                "code": -32000,
                "message": f"Failed to write file: {e}",
            }
        }
# =========================================================================
# Terminal Operation Handlers
# =========================================================================
def handle_terminal_create(self, params: dict) -> dict:
    """Handle terminal/create request from agent.

    Creates a new terminal subprocess for command execution. The command
    is executed directly (no shell), with stdout/stderr piped in text mode
    and stdin connected to the null device.

    Args:
        params: Request parameters with 'command' (non-empty list of
            strings) and optional 'cwd' (working directory).

    Returns:
        Dict with 'terminalId' on success, or 'error' on failure.
    """
    command = params.get("command")

    if command is None:
        return {
            "error": {
                "code": -32602,
                "message": "Missing required parameter: command",
            }
        }

    # Validate the elements too: a non-string argument would otherwise raise
    # an uncaught TypeError from Popen.
    if not isinstance(command, list) or not all(
        isinstance(part, str) for part in command
    ):
        return {
            "error": {
                "code": -32602,
                "message": "command must be a list of strings",
            }
        }

    if len(command) == 0:
        return {
            "error": {
                "code": -32602,
                "message": "command list cannot be empty",
            }
        }

    cwd = params.get("cwd")

    try:
        # Create subprocess with pipes for stdout/stderr
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            stdin=subprocess.DEVNULL,
            cwd=cwd,
            text=True,
            bufsize=0,
        )

        # Generate unique terminal ID
        terminal_id = str(uuid.uuid4())

        # Create terminal instance
        terminal = Terminal(id=terminal_id, process=process)
        self._terminals[terminal_id] = terminal

        return {"terminalId": terminal_id}

    except FileNotFoundError:
        return {
            "error": {
                "code": -32001,
                "message": f"Command not found: {command[0]}",
            }
        }
    except PermissionError:
        return {
            "error": {
                "code": -32003,
                "message": f"Permission denied executing: {command[0]}",
            }
        }
    except OSError as e:
        return {
            "error": {
                "code": -32000,
                "message": f"Failed to create terminal: {e}",
            }
        }
def handle_terminal_output(self, params: dict) -> dict:
    """Handle terminal/output request from agent.

    Pulls any new output into the terminal's buffer and returns the whole
    accumulated buffer, not just the new part.

    Args:
        params: Request parameters with 'terminalId'.

    Returns:
        Dict with 'output' and 'done' on success, or 'error' on failure.
    """

    def failure(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    terminal_id = params.get("terminalId")
    if not terminal_id:
        return failure(-32602, "Missing required parameter: terminalId")

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return failure(-32001, f"Terminal not found: {terminal_id}")

    # Read any new output
    terminal.read_output()

    output = terminal.output_buffer
    finished = not terminal.is_running
    return {"output": output, "done": finished}
def handle_terminal_wait_for_exit(self, params: dict) -> dict:
    """Handle terminal/wait_for_exit request from agent.

    Blocks until the terminal process exits (or the optional timeout
    elapses), then pulls the remaining output into the terminal's buffer.
    NOTE(review): output is only drained after the wait returns; a child
    that fills its stdout/stderr pipe first may stall until the timeout -
    confirm whether callers poll terminal/output in between.

    Args:
        params: Request parameters with 'terminalId' and optional 'timeout'.

    Returns:
        Dict with 'exitCode' on success, or 'error' on failure/timeout.
    """

    def failure(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    terminal_id = params.get("terminalId")
    if not terminal_id:
        return failure(-32602, "Missing required parameter: terminalId")

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return failure(-32001, f"Terminal not found: {terminal_id}")

    timeout = params.get("timeout")
    try:
        exit_code = terminal.wait(timeout=timeout)
        # Read any remaining output
        terminal.read_output()
    except subprocess.TimeoutExpired:
        return failure(-32000, f"Wait timed out after {timeout}s")
    return {"exitCode": exit_code}
def handle_terminal_kill(self, params: dict) -> dict:
    """Handle terminal/kill request from agent.

    Kills the terminal's process; the terminal stays registered.

    Args:
        params: Request parameters with 'terminalId'.

    Returns:
        Dict with 'success: True' on success, or 'error' on failure.
    """

    def failure(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    terminal_id = params.get("terminalId")
    if not terminal_id:
        return failure(-32602, "Missing required parameter: terminalId")

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return failure(-32001, f"Terminal not found: {terminal_id}")

    terminal.kill()
    return {"success": True}
def handle_terminal_release(self, params: dict) -> dict:
    """Handle terminal/release request from agent.

    Unregisters the terminal, first killing its process if it is still
    running. A failure to kill is logged and does not prevent the release.

    Args:
        params: Request parameters with 'terminalId'.

    Returns:
        Dict with 'success: True' on success, or 'error' on failure.
    """

    def failure(code: int, message: str) -> dict:
        return {"error": {"code": code, "message": message}}

    terminal_id = params.get("terminalId")
    if not terminal_id:
        return failure(-32602, "Missing required parameter: terminalId")

    terminal = self._terminals.get(terminal_id)
    if not terminal:
        return failure(-32001, f"Terminal not found: {terminal_id}")

    # Kill the process if still running before releasing
    if terminal.is_running:
        try:
            terminal.kill()
        except Exception as e:
            logger.warning("Failed to kill terminal %s during release: %s", terminal_id, e)

    # Clean up the terminal
    self._terminals.pop(terminal_id)
    return {"success": True}

View File

@@ -0,0 +1,536 @@
# ABOUTME: Typed dataclasses for ACP (Agent Client Protocol) messages
# ABOUTME: Defines request/response types, session state, and configuration
"""Typed data models for ACP messages and session state."""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from ralph_orchestrator.main import AdapterConfig
# Get logger for this module
_logger = logging.getLogger(__name__)
class UpdateKind(str, Enum):
    """String-valued kinds of session update payloads."""

    # Text output from the agent
    AGENT_MESSAGE_CHUNK = "agent_message_chunk"
    # Internal reasoning
    AGENT_THOUGHT_CHUNK = "agent_thought_chunk"
    # Agent requesting tool execution
    TOOL_CALL = "tool_call"
    # Status update for a tool call
    TOOL_CALL_UPDATE = "tool_call_update"
    # Agent's execution plan
    PLAN = "plan"
class ToolCallStatus(str, Enum):
    """String-valued states a tool call can be in."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class PermissionMode(str, Enum):
    """String-valued permission modes for ACP operations."""

    AUTO_APPROVE = "auto_approve"
    DENY_ALL = "deny_all"
    ALLOWLIST = "allowlist"
    INTERACTIVE = "interactive"
@dataclass
class ACPRequest:
    """A JSON-RPC 2.0 request message.

    Attributes:
        id: Request identifier for matching response.
        method: The RPC method to invoke.
        params: Method parameters.
    """

    id: int
    method: str
    params: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPRequest":
        """Build an ACPRequest from a decoded message dict.

        Args:
            data: Dict with id, method, and optional params.

        Returns:
            ACPRequest instance.

        Raises:
            KeyError: If id or method is missing.
        """
        request_id = data["id"]
        method = data["method"]
        return cls(id=request_id, method=method, params=data.get("params", {}))
@dataclass
class ACPNotification:
    """A JSON-RPC 2.0 notification (no response expected).

    Attributes:
        method: The notification method.
        params: Method parameters.
    """

    method: str
    params: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPNotification":
        """Build an ACPNotification from a decoded message dict.

        Args:
            data: Dict with method and optional params.

        Returns:
            ACPNotification instance.
        """
        method = data["method"]
        return cls(method=method, params=data.get("params", {}))
@dataclass
class ACPResponse:
    """A JSON-RPC 2.0 success response.

    Attributes:
        id: Request identifier this response matches.
        result: The response result data.
    """

    id: int
    result: Any

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPResponse":
        """Build an ACPResponse from a decoded message dict.

        A missing 'result' key yields ``result=None``.

        Args:
            data: Dict with id and result.

        Returns:
            ACPResponse instance.
        """
        response_id = data["id"]
        return cls(id=response_id, result=data.get("result"))
@dataclass
class ACPErrorObject:
    """A JSON-RPC 2.0 error object.

    Attributes:
        code: Error code (negative integers for standard errors).
        message: Human-readable error message.
        data: Optional additional error data.
    """

    code: int
    message: str
    data: Any = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPErrorObject":
        """Build an ACPErrorObject from a decoded error dict.

        'code' and 'message' are required keys; 'data' is optional.

        Args:
            data: Dict with code, message, and optional data.

        Returns:
            ACPErrorObject instance.
        """
        code = data["code"]
        message = data["message"]
        return cls(code=code, message=message, data=data.get("data"))
@dataclass
class ACPError:
    """A JSON-RPC 2.0 error response.

    Attributes:
        id: Request identifier this error matches.
        error: The error object with code and message.
    """

    id: int
    error: ACPErrorObject

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPError":
        """Build an ACPError from a decoded message dict.

        The nested 'error' dict is parsed with ``ACPErrorObject.from_dict``.

        Args:
            data: Dict with id and error object.

        Returns:
            ACPError instance.
        """
        response_id = data["id"]
        error_object = ACPErrorObject.from_dict(data["error"])
        return cls(id=response_id, error=error_object)
@dataclass
class UpdatePayload:
"""Payload for session/update notifications.
Handles different update kinds:
- agent_message_chunk: Text output from agent
- agent_thought_chunk: Internal reasoning (verbose mode)
- tool_call: Agent requesting tool execution
- tool_call_update: Status update for tool call
- plan: Agent's execution plan
Attributes:
kind: The update type (see UpdateKind enum for valid values).
content: Text content (for message/thought chunks).
tool_name: Name of tool being called.
tool_call_id: Unique identifier for tool call.
arguments: Tool call arguments.
status: Tool call status (see ToolCallStatus enum for valid values).
result: Tool call result data.
error: Tool call error message.
"""
kind: str # Valid values: UpdateKind enum members
content: Optional[str] = None
tool_name: Optional[str] = None
tool_call_id: Optional[str] = None
arguments: Optional[dict[str, Any]] = None
status: Optional[str] = None
result: Optional[Any] = None
error: Optional[str] = None
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "UpdatePayload":
"""Parse UpdatePayload from dict.
Handles camelCase to snake_case conversion for ACP fields and, for
tool-call kinds, falls back to alternate key spellings when the
camelCase key is absent.
Args:
data: Dict with kind and kind-specific fields.
Returns:
UpdatePayload instance.
Raises:
KeyError: If "kind" is missing from data.
"""
# "kind" is the only mandatory key; everything else is read with .get().
kind = data["kind"]
# The camelCase spellings are tried first.
tool_name = data.get("toolName")
tool_call_id = data.get("toolCallId")
arguments = data.get("arguments")
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the fallback checks below relative to this kind test should be confirmed.
if kind in (UpdateKind.TOOL_CALL, UpdateKind.TOOL_CALL_UPDATE):
if tool_name is None:
# `or` chaining means an empty-string "tool_name" also falls through to "name".
tool_name = data.get("tool_name") or data.get("name")
# A "tool" key is accepted as the name only when it holds a string.
if tool_name is None and isinstance(data.get("tool"), str):
tool_name = data.get("tool")
if tool_call_id is None:
tool_call_id = data.get("tool_call_id") or data.get("id")
# Argument aliases are consulted for new tool calls only, not for updates.
if kind == UpdateKind.TOOL_CALL and arguments is None:
arguments = data.get("args") or data.get("parameters") or data.get("params")
return cls(
kind=kind,
content=data.get("content"),
tool_name=tool_name,
tool_call_id=tool_call_id,
arguments=arguments,
status=data.get("status"),
result=data.get("result"),
error=data.get("error"),
)
@dataclass
class SessionUpdate:
    """A session/update notification paired with its parsed payload.

    Attributes:
        method: Notification method name; expected to be "session/update".
        payload: The parsed update payload.
    """

    method: str
    payload: UpdatePayload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionUpdate":
        """Build a SessionUpdate from a decoded notification dict.

        Args:
            data: Mapping holding "method" and "params".

        Returns:
            The constructed SessionUpdate.

        Raises:
            KeyError: If "method" or "params" is absent from data.
        """
        method_name = data["method"]
        # The notification's params are the update payload itself
        parsed_payload = UpdatePayload.from_dict(data["params"])
        return cls(method=method_name, payload=parsed_payload)
@dataclass
class ToolCall:
    """Record of one tool execution inside a session.

    Attributes:
        tool_call_id: Unique identifier of the call.
        tool_name: Name of the tool invoked.
        arguments: Arguments handed to the tool.
        status: Current status (pending/running/completed/failed).
        result: Result data once the call has completed.
        error: Error message if the call failed.
    """

    tool_call_id: str
    tool_name: str
    arguments: dict[str, Any]
    status: str = "pending"
    result: Optional[Any] = None
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolCall":
        """Build a ToolCall from a dict that uses camelCase ACP keys.

        Args:
            data: Mapping holding "toolCallId" and "toolName", plus optional
                "arguments", "status", "result" and "error".

        Returns:
            The constructed ToolCall.

        Raises:
            KeyError: If "toolCallId" or "toolName" is absent from data.
        """
        call_id = data["toolCallId"]
        name = data["toolName"]
        return cls(
            call_id,
            name,
            data.get("arguments", {}),
            status=data.get("status", "pending"),
            result=data.get("result"),
            error=data.get("error"),
        )
@dataclass
class ACPSession:
    """Mutable accumulator for everything an agent emits during one prompt.

    Collects output text, thought text and tool calls so the final
    response can be assembled afterwards.

    Attributes:
        session_id: Unique session identifier.
        output: Agent output text gathered so far.
        thoughts: Agent thought text gathered so far (verbose).
        tool_calls: Tool calls seen in this session, in arrival order.
    """

    session_id: str
    output: str = ""
    thoughts: str = ""
    tool_calls: list[ToolCall] = field(default_factory=list)

    def append_output(self, text: str) -> None:
        """Add a chunk to the accumulated output.

        Args:
            text: Chunk of output text.
        """
        self.output = self.output + text

    def append_thought(self, text: str) -> None:
        """Add a chunk to the accumulated thoughts.

        Args:
            text: Chunk of thought text.
        """
        self.thoughts = self.thoughts + text

    def add_tool_call(self, tool_call: ToolCall) -> None:
        """Start tracking a tool call.

        Args:
            tool_call: The ToolCall to record.
        """
        self.tool_calls.append(tool_call)

    def get_tool_call(self, tool_call_id: str) -> Optional[ToolCall]:
        """Look up a tracked tool call by its identifier.

        Args:
            tool_call_id: Identifier to search for.

        Returns:
            The first tracked ToolCall with that id, or None if there is none.
        """
        matches = (call for call in self.tool_calls if call.tool_call_id == tool_call_id)
        return next(matches, None)

    def process_update(self, payload: UpdatePayload) -> None:
        """Apply one session update to the accumulated state.

        Dispatches on payload.kind; kinds not handled here are ignored.

        Args:
            payload: The update payload to apply.
        """
        update_kind = payload.kind

        if update_kind == "agent_message_chunk":
            if payload.content:
                self.append_output(payload.content)
            return

        if update_kind == "agent_thought_chunk":
            if payload.content:
                self.append_thought(payload.content)
            return

        if update_kind == "tool_call":
            # Missing identifiers/arguments are recorded as empty values
            self.add_tool_call(
                ToolCall(
                    tool_call_id=payload.tool_call_id or "",
                    tool_name=payload.tool_name or "",
                    arguments=payload.arguments or {},
                )
            )
            return

        if update_kind != "tool_call_update" or not payload.tool_call_id:
            return

        tracked = self.get_tool_call(payload.tool_call_id)
        if not tracked:
            # Updates for calls that were never recorded are dropped
            return
        if payload.status:
            tracked.status = payload.status
        # Falsy results (0, "", {}) are still stored; only None is skipped
        if payload.result is not None:
            tracked.result = payload.result
        if payload.error:
            tracked.error = payload.error

    def reset(self) -> None:
        """Clear accumulated data ahead of a new prompt, keeping session_id."""
        self.output = ""
        self.thoughts = ""
        self.tool_calls = []
@dataclass
class ACPAdapterConfig:
    """Configuration for the ACP adapter.

    Attributes:
        agent_command: Command to spawn the agent (default: gemini).
        agent_args: Additional arguments for agent command.
        timeout: Request timeout in seconds.
        permission_mode: How to handle permission requests.
            - auto_approve: Approve all requests.
            - deny_all: Deny all requests.
            - allowlist: Check against permission_allowlist.
            - interactive: Prompt user for each request.
        permission_allowlist: Patterns to allow in allowlist mode.
    """

    agent_command: str = "gemini"
    agent_args: list[str] = field(default_factory=list)
    timeout: int = 300
    permission_mode: str = "auto_approve"
    permission_allowlist: list[str] = field(default_factory=list)

    @staticmethod
    def _setting(source: dict[str, Any], key: str, default: Any) -> Any:
        """Return source[key], or default when the key is absent or None.

        A plain ``source.get(key, default)`` lets an explicit None through
        (e.g. a config key written with no value), which would put None
        into fields declared as str/int/list.
        """
        value = source.get(key)
        return default if value is None else value

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ACPAdapterConfig":
        """Parse ACPAdapterConfig from dict.

        Keys that are missing or explicitly None take the default value.

        Args:
            data: Configuration dict.

        Returns:
            ACPAdapterConfig instance.
        """
        return cls(
            agent_command=cls._setting(data, "agent_command", "gemini"),
            agent_args=cls._setting(data, "agent_args", []),
            timeout=cls._setting(data, "timeout", 300),
            permission_mode=cls._setting(data, "permission_mode", "auto_approve"),
            permission_allowlist=cls._setting(data, "permission_allowlist", []),
        )

    @classmethod
    def from_adapter_config(cls, adapter_config: "AdapterConfig") -> "ACPAdapterConfig":
        """Create ACPAdapterConfig from AdapterConfig with env var overrides.

        Extracts ACP-specific settings from AdapterConfig.tool_permissions
        and applies environment variable overrides. Settings that are
        missing or None fall back to defaults; the timeout falls back to
        adapter_config.timeout.

        Environment Variables:
            RALPH_ACP_AGENT: Override agent_command
            RALPH_ACP_PERMISSION_MODE: Override permission_mode
            RALPH_ACP_TIMEOUT: Override timeout (integer)

        Invalid override values are logged as warnings and ignored.

        Args:
            adapter_config: General adapter configuration.

        Returns:
            ACPAdapterConfig with ACP-specific settings.
        """
        # Start with tool_permissions or empty dict
        tool_perms = adapter_config.tool_permissions or {}

        # Get base values from tool_permissions
        agent_command = cls._setting(tool_perms, "agent_command", "gemini")
        agent_args = cls._setting(tool_perms, "agent_args", [])
        timeout = cls._setting(tool_perms, "timeout", adapter_config.timeout)
        permission_mode = cls._setting(tool_perms, "permission_mode", "auto_approve")
        permission_allowlist = cls._setting(tool_perms, "permission_allowlist", [])

        # Apply environment variable overrides (empty values are ignored)
        if env_agent := os.environ.get("RALPH_ACP_AGENT"):
            agent_command = env_agent

        if env_mode := os.environ.get("RALPH_ACP_PERMISSION_MODE"):
            # A tuple keeps the order stable in the warning below
            valid_modes = ("auto_approve", "deny_all", "allowlist", "interactive")
            if env_mode in valid_modes:
                permission_mode = env_mode
            else:
                _logger.warning(
                    "Invalid RALPH_ACP_PERMISSION_MODE value '%s'. Valid modes: %s. Using default: %s",
                    env_mode,
                    ", ".join(valid_modes),
                    permission_mode,
                )

        if env_timeout := os.environ.get("RALPH_ACP_TIMEOUT"):
            try:
                timeout = int(env_timeout)
            except ValueError:
                _logger.warning(
                    "Invalid RALPH_ACP_TIMEOUT value '%s' - must be integer. Using default: %d",
                    env_timeout,
                    timeout,
                )

        return cls(
            agent_command=agent_command,
            agent_args=agent_args,
            timeout=timeout,
            permission_mode=permission_mode,
            permission_allowlist=permission_allowlist,
        )

View File

@@ -0,0 +1,214 @@
# ABOUTME: JSON-RPC 2.0 protocol handler for ACP (Agent Client Protocol)
# ABOUTME: Handles message serialization, parsing, and protocol state
"""JSON-RPC 2.0 protocol handling for ACP."""
import json
from enum import Enum, auto
from typing import Any
class MessageType(Enum):
    """Classification assigned to an incoming JSON-RPC 2.0 message."""

    # Carries both an id and a method
    REQUEST = auto()
    # Carries a method but no id
    NOTIFICATION = auto()
    # Carries an id and a result
    RESPONSE = auto()
    # Carries an id and an error
    ERROR = auto()
    # The raw text could not be decoded as JSON
    PARSE_ERROR = auto()
    # Decoded, but not a valid JSON-RPC message
    INVALID = auto()
class ACPErrorCodes:
    """Numeric error codes used in JSON-RPC 2.0 / ACP error responses."""

    # --- Codes defined by the JSON-RPC 2.0 specification ---
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603

    # --- Codes specific to ACP ---
    PERMISSION_DENIED = -32001
    FILE_NOT_FOUND = -32002
    FILE_ACCESS_ERROR = -32003
    TERMINAL_ERROR = -32004
class ACPProtocol:
    """JSON-RPC 2.0 protocol handler for ACP.

    Handles serialization and deserialization of JSON-RPC messages
    for the Agent Client Protocol.

    Attributes:
        _request_id: Auto-incrementing request ID counter.
    """

    JSONRPC_VERSION = "2.0"

    def __init__(self) -> None:
        """Initialize protocol handler with request ID counter at 0."""
        self._request_id: int = 0

    def create_request(self, method: str, params: dict[str, Any]) -> tuple[int, str]:
        """Create a JSON-RPC 2.0 request message.

        Args:
            method: The RPC method name (e.g., "session/prompt").
            params: The request parameters.

        Returns:
            Tuple of (request_id, json_string) for tracking and sending.
            IDs start at 1 and increase by one per request.
        """
        self._request_id += 1
        request_id = self._request_id
        message = {
            "jsonrpc": self.JSONRPC_VERSION,
            "id": request_id,
            "method": method,
            "params": params,
        }
        return request_id, json.dumps(message)

    def create_notification(self, method: str, params: dict[str, Any]) -> str:
        """Create a JSON-RPC 2.0 notification message (no id, no response expected).

        Args:
            method: The RPC method name.
            params: The notification parameters.

        Returns:
            JSON string of the notification.
        """
        message = {
            "jsonrpc": self.JSONRPC_VERSION,
            "method": method,
            "params": params,
        }
        return json.dumps(message)

    def parse_message(self, data: str) -> dict[str, Any]:
        """Parse an incoming JSON-RPC 2.0 message.

        Determines the message type and validates structure. Malformed
        input is reported through the returned dict rather than raised.

        Args:
            data: Raw JSON string to parse.

        Returns:
            Dict with 'type' key indicating MessageType and parsed fields.
            On error, includes 'error' key with description.
        """
        # Try to parse JSON
        try:
            message = json.loads(data)
        except json.JSONDecodeError as e:
            return {
                "type": MessageType.PARSE_ERROR,
                "error": f"JSON parse error: {e}",
            }

        # Valid JSON is not necessarily an object: arrays, numbers, strings
        # and null all decode fine but have no .get(), so reject them here
        # instead of failing with AttributeError below.
        if not isinstance(message, dict):
            return {
                "type": MessageType.INVALID,
                "error": (
                    "Invalid JSON-RPC message: expected a JSON object, "
                    f"got {type(message).__name__}"
                ),
            }

        # Validate jsonrpc version field
        if message.get("jsonrpc") != self.JSONRPC_VERSION:
            return {
                "type": MessageType.INVALID,
                "error": f"Invalid or missing jsonrpc field. Expected '2.0', got '{message.get('jsonrpc')}'",
            }

        # Determine message type based on fields
        has_id = "id" in message
        has_method = "method" in message
        has_result = "result" in message
        has_error = "error" in message

        # Order matters: error and result are checked before method, so a
        # message carrying both is classified as a response.
        if has_error and has_id:
            # Error response
            return {
                "type": MessageType.ERROR,
                "id": message["id"],
                "error": message["error"],
            }
        elif has_result and has_id:
            # Success response
            return {
                "type": MessageType.RESPONSE,
                "id": message["id"],
                "result": message["result"],
            }
        elif has_method and has_id:
            # Request (has id, expects response)
            return {
                "type": MessageType.REQUEST,
                "id": message["id"],
                "method": message["method"],
                "params": message.get("params", {}),
            }
        elif has_method and not has_id:
            # Notification (no id, no response expected)
            return {
                "type": MessageType.NOTIFICATION,
                "method": message["method"],
                "params": message.get("params", {}),
            }
        else:
            return {
                "type": MessageType.INVALID,
                "error": "Invalid JSON-RPC message structure",
            }

    def create_response(self, request_id: int, result: Any) -> str:
        """Create a JSON-RPC 2.0 success response.

        Args:
            request_id: The ID from the original request.
            result: The result data to return.

        Returns:
            JSON string of the response.
        """
        message = {
            "jsonrpc": self.JSONRPC_VERSION,
            "id": request_id,
            "result": result,
        }
        return json.dumps(message)

    def create_error_response(
        self,
        request_id: int,
        code: int,
        message: str,
        data: Any = None,
    ) -> str:
        """Create a JSON-RPC 2.0 error response.

        Args:
            request_id: The ID from the original request.
            code: The error code (use ACPErrorCodes constants).
            message: Human-readable error message.
            data: Optional additional error data.

        Returns:
            JSON string of the error response.
        """
        error_obj: dict[str, Any] = {
            "code": code,
            "message": message,
        }
        # "data" is omitted entirely when not supplied
        if data is not None:
            error_obj["data"] = data
        response = {
            "jsonrpc": self.JSONRPC_VERSION,
            "id": request_id,
            "error": error_obj,
        }
        return json.dumps(response)

View File

@@ -0,0 +1,189 @@
# ABOUTME: Abstract base class for tool adapters
# ABOUTME: Defines the interface all tool adapters must implement
"""Base adapter interface for AI tools."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Dict, Any
from pathlib import Path
import asyncio
@dataclass
class ToolResponse:
    """Outcome of a single tool execution.

    Attributes:
        success: Whether the execution succeeded.
        output: Text produced by the tool.
        error: Error description, if any.
        tokens_used: Token count, if reported.
        cost: Cost of the execution, if reported.
        metadata: Extra key/value data; never None after construction.
    """

    success: bool
    output: str
    error: Optional[str] = None
    tokens_used: Optional[int] = None
    cost: Optional[float] = None
    metadata: Dict[str, Any] = None

    def __post_init__(self):
        # None is the "not supplied" sentinel; each instance then gets its own dict
        self.metadata = {} if self.metadata is None else self.metadata
class ToolAdapter(ABC):
    """Abstract base class for tool adapters.

    Subclasses implement check_availability() and execute(); the async and
    file-based entry points here are built on top of those two.
    """

    def __init__(self, name: str, config=None):
        """Store the adapter name and config, then probe availability.

        Args:
            name: Adapter name, used in the string representation.
            config: Adapter configuration object. When falsy, a default
                object with enabled/timeout/max_retries/args/env attributes
                is created.
        """
        self.name = name
        self.config = config or type('Config', (), {
            'enabled': True, 'timeout': 300, 'max_retries': 3,
            'args': [], 'env': {}
        })()
        self.completion_promise: Optional[str] = None
        # Runs the subclass hook, so subclasses must not rely on attributes
        # they set after calling super().__init__()
        self.available = self.check_availability()

    @abstractmethod
    def check_availability(self) -> bool:
        """Check if the tool is available and properly configured."""
        pass

    @abstractmethod
    def execute(self, prompt: str, **kwargs) -> "ToolResponse":
        """Execute the tool with the given prompt."""
        pass

    async def aexecute(self, prompt: str, **kwargs) -> "ToolResponse":
        """Async execute the tool with the given prompt.

        Default implementation runs the sync execute() in a worker thread.
        Subclasses can override for native async support.
        """
        # to_thread binds to the running loop and forwards kwargs directly
        return await asyncio.to_thread(self.execute, prompt, **kwargs)

    def execute_with_file(self, prompt_file: Path, **kwargs) -> "ToolResponse":
        """Execute the tool with a prompt read from a file.

        Args:
            prompt_file: Path of the file holding the prompt text.
            **kwargs: Forwarded to execute().

        Returns:
            The execute() result, or a failed ToolResponse when the file
            does not exist.
        """
        if not prompt_file.exists():
            return ToolResponse(
                success=False,
                output="",
                error=f"Prompt file {prompt_file} not found"
            )
        # Read as UTF-8 explicitly so the sync and async paths decode the
        # same file identically, whatever the platform's locale encoding
        prompt = prompt_file.read_text(encoding='utf-8')
        return self.execute(prompt, **kwargs)

    async def aexecute_with_file(self, prompt_file: Path, **kwargs) -> "ToolResponse":
        """Async execute the tool with a prompt read from a file.

        Args:
            prompt_file: Path of the file holding the prompt text.
            **kwargs: Forwarded to aexecute().

        Returns:
            The aexecute() result, or a failed ToolResponse when the file
            does not exist.
        """
        if not prompt_file.exists():
            return ToolResponse(
                success=False,
                output="",
                error=f"Prompt file {prompt_file} not found"
            )
        # Use asyncio.to_thread to avoid blocking the event loop with file I/O
        prompt = await asyncio.to_thread(prompt_file.read_text, encoding='utf-8')
        return await self.aexecute(prompt, **kwargs)

    def estimate_cost(self, prompt: str) -> float:
        """Estimate the cost of executing this prompt."""
        # Default implementation - subclasses can override
        return 0.0

    def _enhance_prompt_with_instructions(self, prompt: str, completion_promise: Optional[str] = None) -> str:
        """Enhance prompt with orchestration context and instructions.

        The instruction preamble is prepended only when none of its marker
        phrases already occur in the prompt, and the completion-promise
        section is appended only when the promise text is not already there.

        Args:
            prompt: The original prompt
            completion_promise: Optional string that signals task completion.
                If not provided, uses self.completion_promise.

        Returns:
            Enhanced prompt with orchestration instructions
        """
        # Resolve completion promise: argument takes precedence, then instance state
        promise = completion_promise or self.completion_promise

        # Check if instructions already exist in the prompt
        instruction_markers = [
            "ORCHESTRATION CONTEXT:",
            "IMPORTANT INSTRUCTIONS:",
            "Implement only ONE small, focused task"
        ]

        # If any marker exists, assume instructions are already present
        instructions_present = any(marker in prompt for marker in instruction_markers)

        enhanced_prompt = prompt
        if not instructions_present:
            # Add orchestration context and instructions
            orchestration_instructions = """
ORCHESTRATION CONTEXT:
You are running within the Ralph Orchestrator loop. This system will call you repeatedly
for multiple iterations until the overall task is complete. Each iteration is a separate
execution where you should make incremental progress.
The final output must be well-tested, documented, and production ready.
IMPORTANT INSTRUCTIONS:
1. Implement only ONE small, focused task from this prompt per iteration.
- Each iteration is independent - focus on a single atomic change
- The orchestrator will handle calling you again for the next task
- Mark subtasks complete as you finish them
- You must commit your changes after each iteration, for checkpointing.
2. Use the .agent/workspace/ directory for any temporary files or workspaces if not already instructed in the prompt.
3. Follow this workflow for implementing features:
- Explore: Research and understand the codebase
- Plan: Design your implementation approach
- Implement: Use Test-Driven Development (TDD) - write tests first, then code
- Commit: Commit your changes with clear messages
4. When you complete a subtask, document it in the prompt file so the next iteration knows what's done.
5. For maximum efficiency, whenever you need to perform multiple independent operations, invoke all relevant tools simultaneously rather than sequentially.
6. If you create any temporary new files, scripts, or helper files for iteration, clean up these files by removing them at the end of the task.
## Agent Scratchpad
Before starting your work, check if .agent/scratchpad.md exists in the current working directory.
If it does, read it to understand what was accomplished in previous iterations and continue from there.
At the end of your iteration, update .agent/scratchpad.md with:
- What you accomplished this iteration
- What remains to be done
- Any important context or decisions made
- Current blockers or issues (if any)
Do NOT restart from scratch if the scratchpad shows previous progress. Continue where the previous iteration left off.
Create the .agent/ directory if it doesn't exist.
---
ORIGINAL PROMPT:
"""
            enhanced_prompt = orchestration_instructions + prompt

        # Inject completion promise if requested and not present.
        # A plain substring check is used: the goal is only to avoid
        # duplicating the specific promise text.
        if promise and promise not in enhanced_prompt:
            promise_section = f"""
## Completion Promise
When you have completed the task, output this exact line:
{promise}
"""
            enhanced_prompt += promise_section

        return enhanced_prompt

    def __str__(self) -> str:
        return f"{self.name} (available: {self.available})"

View File

@@ -0,0 +1,616 @@
# ABOUTME: Claude SDK adapter implementation
# ABOUTME: Provides integration with Anthropic's Claude via Python SDK
# ABOUTME: Supports inheriting user's Claude Code settings (MCP servers, CLAUDE.md, etc.)
"""Claude SDK adapter for Ralph Orchestrator."""
import asyncio
import logging
import os
import signal
from typing import Optional
from .base import ToolAdapter, ToolResponse
from ..error_formatter import ClaudeErrorFormatter
from ..output.console import RalphConsole
# Setup logging
logger = logging.getLogger(__name__)
try:
from claude_agent_sdk import ClaudeAgentOptions, query
CLAUDE_SDK_AVAILABLE = True
except ImportError:
# Fallback to old package name for backwards compatibility
try:
from claude_code_sdk import ClaudeCodeOptions as ClaudeAgentOptions, query
CLAUDE_SDK_AVAILABLE = True
except ImportError:
CLAUDE_SDK_AVAILABLE = False
query = None
ClaudeAgentOptions = None
class ClaudeAdapter(ToolAdapter):
"""Adapter for Claude using the Python SDK."""
# Default max buffer size: 10MB (handles large screenshots from chrome-devtools-mcp)
DEFAULT_MAX_BUFFER_SIZE = 10 * 1024 * 1024
# Default model: Claude Opus 4.5 (most intelligent model)
DEFAULT_MODEL = "claude-opus-4-5-20251101"
# Model pricing (per million tokens)
MODEL_PRICING = {
"claude-opus-4-5-20251101": {"input": 5.0, "output": 25.0},
"claude-sonnet-4-5-20250929": {"input": 3.0, "output": 15.0},
"claude-haiku-4-5-20251001": {"input": 1.0, "output": 5.0},
# Legacy models
"claude-3-opus": {"input": 15.0, "output": 75.0},
"claude-3-sonnet": {"input": 3.0, "output": 15.0},
"claude-3-haiku": {"input": 0.25, "output": 1.25},
}
def __init__(self, verbose: bool = False, max_buffer_size: int = None,
inherit_user_settings: bool = True, cli_path: str = None,
model: str = None):
super().__init__("claude")
self.sdk_available = CLAUDE_SDK_AVAILABLE
self._system_prompt = None
self._allowed_tools = None
self._disallowed_tools = None
self._enable_all_tools = False
self._enable_web_search = True # Enable WebSearch by default
self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
self.verbose = verbose
# Enable loading user's Claude Code settings (including MCP servers) by default
self._inherit_user_settings = inherit_user_settings
# Optional path to user's Claude Code CLI (uses bundled CLI if not specified)
self._cli_path = cli_path
self._model = model or self.DEFAULT_MODEL
self._subprocess_pid: Optional[int] = None
self._console = RalphConsole()
def check_availability(self) -> bool:
"""Check if Claude SDK is available and properly configured."""
# Claude Code SDK works without API key - it uses the local environment
return CLAUDE_SDK_AVAILABLE
def configure(self,
system_prompt: Optional[str] = None,
allowed_tools: Optional[list] = None,
disallowed_tools: Optional[list] = None,
enable_all_tools: bool = False,
enable_web_search: bool = True,
inherit_user_settings: Optional[bool] = None,
cli_path: Optional[str] = None,
model: Optional[str] = None):
"""Configure the Claude adapter with custom options.
Args:
system_prompt: Custom system prompt for Claude
allowed_tools: List of allowed tools for Claude to use (if None and enable_all_tools=True, all tools are enabled)
disallowed_tools: List of disallowed tools
enable_all_tools: If True and allowed_tools is None, enables all native Claude tools
enable_web_search: If True, explicitly enables WebSearch tool (default: True)
inherit_user_settings: If True, load user's Claude Code settings including MCP servers (default: True)
cli_path: Path to user's Claude Code CLI (uses bundled CLI if not specified)
model: Model to use (default: claude-opus-4-5-20251101)
"""
self._system_prompt = system_prompt
self._allowed_tools = allowed_tools
self._disallowed_tools = disallowed_tools
self._enable_all_tools = enable_all_tools
self._enable_web_search = enable_web_search
# Update user settings inheritance if specified
if inherit_user_settings is not None:
self._inherit_user_settings = inherit_user_settings
# Update CLI path if specified
if cli_path is not None:
self._cli_path = cli_path
# Update model if specified
if model is not None:
self._model = model
# If web search is enabled and we have an allowed tools list, add WebSearch to it
if enable_web_search and allowed_tools is not None and 'WebSearch' not in allowed_tools:
self._allowed_tools = allowed_tools + ['WebSearch']
def execute(self, prompt: str, **kwargs) -> ToolResponse:
"""Execute Claude with the given prompt synchronously.
This is a blocking wrapper around the async implementation.
"""
try:
# Create new event loop if needed
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(self.aexecute(prompt, **kwargs))
else:
# If loop is already running, schedule as task
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.aexecute(prompt, **kwargs))
return future.result()
except Exception as e:
# Use error formatter for user-friendly error messages
error_msg = ClaudeErrorFormatter.format_error_from_exception(
iteration=kwargs.get('iteration', 0),
exception=e
)
return ToolResponse(
success=False,
output="",
error=str(error_msg)
)
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
"""Execute Claude with the given prompt asynchronously."""
if not self.available:
logger.debug("Claude SDK not available")
return ToolResponse(
success=False,
output="",
error="Claude SDK is not available"
)
try:
# Get configuration from kwargs or use defaults
prompt_file = kwargs.get('prompt_file', 'PROMPT.md')
# Build options for Claude Code
options_dict = {}
# Set system prompt with orchestration context
system_prompt = kwargs.get('system_prompt', self._system_prompt)
if not system_prompt:
# Create a default system prompt with orchestration context
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
system_prompt = (
f"You are helping complete a task. "
f"The task is described in the file '{prompt_file}'. "
f"Please edit this file directly to add your solution and progress updates."
)
# Use the enhanced prompt as the main prompt
prompt = enhanced_prompt
else:
# If custom system prompt provided, still enhance the main prompt
prompt = self._enhance_prompt_with_instructions(prompt)
options_dict['system_prompt'] = system_prompt
# Set tool restrictions if provided
# If enable_all_tools is True and no allowed_tools specified, don't set any restrictions
enable_all_tools = kwargs.get('enable_all_tools', self._enable_all_tools)
enable_web_search = kwargs.get('enable_web_search', self._enable_web_search)
allowed_tools = kwargs.get('allowed_tools', self._allowed_tools)
disallowed_tools = kwargs.get('disallowed_tools', self._disallowed_tools)
# Add WebSearch to allowed tools if web search is enabled
if enable_web_search and allowed_tools is not None and 'WebSearch' not in allowed_tools:
allowed_tools = allowed_tools + ['WebSearch']
# Only set tool restrictions if we're not enabling all tools or if specific tools are provided
if not enable_all_tools or allowed_tools:
if allowed_tools:
options_dict['allowed_tools'] = allowed_tools
if disallowed_tools:
options_dict['disallowed_tools'] = disallowed_tools
# If enable_all_tools is True and no allowed_tools, Claude will have access to all native tools
if enable_all_tools and not allowed_tools:
if self.verbose:
logger.debug("Enabling all native Claude tools (including WebSearch)")
# Set permission mode - default to bypassPermissions for smoother operation
permission_mode = kwargs.get('permission_mode', 'bypassPermissions')
options_dict['permission_mode'] = permission_mode
if self.verbose:
logger.debug(f"Permission mode: {permission_mode}")
# Set current working directory to ensure files are created in the right place
import os
cwd = kwargs.get('cwd', os.getcwd())
options_dict['cwd'] = cwd
if self.verbose:
logger.debug(f"Working directory: {cwd}")
# Set max buffer size for handling large responses (e.g., screenshots)
max_buffer_size = kwargs.get('max_buffer_size', self._max_buffer_size)
options_dict['max_buffer_size'] = max_buffer_size
if self.verbose:
logger.debug(f"Max buffer size: {max_buffer_size} bytes")
# Configure setting sources to inherit user's Claude Code configuration
# This enables MCP servers, CLAUDE.md files, and other user settings
inherit_user_settings = kwargs.get('inherit_user_settings', self._inherit_user_settings)
if inherit_user_settings:
# Load user, project, and local settings (includes MCP servers)
options_dict['setting_sources'] = ['user', 'project', 'local']
if self.verbose:
logger.debug("Inheriting user's Claude Code settings (MCP servers, CLAUDE.md, etc.)")
# Optional: use user's installed Claude Code CLI instead of bundled
cli_path = kwargs.get('cli_path', self._cli_path)
if cli_path:
options_dict['cli_path'] = cli_path
if self.verbose:
logger.debug(f"Using custom Claude CLI: {cli_path}")
# Set model - defaults to Opus 4.5
model = kwargs.get('model', self._model)
options_dict['model'] = model
if self.verbose:
logger.debug(f"Using model: {model}")
# Create options
options = ClaudeAgentOptions(**options_dict)
# Log request details if verbose
if self.verbose:
logger.debug("Claude SDK Request:")
logger.debug(f" Prompt length: {len(prompt)} characters")
logger.debug(f" System prompt: {system_prompt}")
if allowed_tools:
logger.debug(f" Allowed tools: {allowed_tools}")
if disallowed_tools:
logger.debug(f" Disallowed tools: {disallowed_tools}")
# Collect all response chunks
output_chunks = []
tokens_used = 0
chunk_count = 0
# Use one-shot query for simpler execution
if self.verbose:
logger.debug("Starting Claude SDK query...")
self._console.print_header("CLAUDE PROCESSING")
async for message in query(prompt=prompt, options=options):
chunk_count += 1
msg_type = type(message).__name__
if self.verbose:
print(f"\n[DEBUG: Received {msg_type}]", flush=True)
logger.debug(f"Received message type: {msg_type}")
# Handle different message types
if msg_type == 'AssistantMessage':
# Extract content from AssistantMessage
if hasattr(message, 'content') and message.content:
for content_block in message.content:
block_type = type(content_block).__name__
if hasattr(content_block, 'text'):
# TextBlock
text = content_block.text
output_chunks.append(text)
if self.verbose and text:
self._console.print_message(text)
logger.debug(f"Received assistant text: {len(text)} characters")
elif block_type == 'ToolUseBlock':
if self.verbose:
tool_name = getattr(content_block, 'name', 'unknown')
tool_id = getattr(content_block, 'id', 'unknown')
tool_input = getattr(content_block, 'input', {})
self._console.print_separator()
self._console.print_status(f"TOOL USE: {tool_name}", style="cyan bold")
self._console.print_info(f"ID: {tool_id[:12]}...")
if tool_input:
self._console.print_info("Input Parameters:")
for key, value in tool_input.items():
value_str = str(value)
if len(value_str) > 100:
value_str = value_str[:97] + "..."
self._console.print_info(f" - {key}: {value_str}")
logger.debug(f"Tool use detected: {tool_name} (id: {tool_id[:8]}...)")
if hasattr(content_block, 'input'):
logger.debug(f" Tool input: {content_block.input}")
else:
if self.verbose:
logger.debug(f"Unknown content block type: {block_type}")
elif msg_type == 'ResultMessage':
# ResultMessage contains final result and usage stats
if hasattr(message, 'result'):
# Don't append result - it's usually a duplicate of assistant message
if self.verbose:
logger.debug(f"Result message received: {len(str(message.result))} characters")
# Extract token usage from ResultMessage
if hasattr(message, 'usage'):
usage = message.usage
if isinstance(usage, dict):
tokens_used = usage.get('input_tokens', 0) + usage.get('output_tokens', 0)
else:
tokens_used = getattr(usage, 'total_tokens', 0)
if self.verbose:
logger.debug(f"Token usage: {tokens_used} tokens")
elif msg_type == 'SystemMessage':
# SystemMessage is initialization data, skip it
if self.verbose:
logger.debug("System initialization message received")
elif msg_type == 'UserMessage':
if self.verbose:
logger.debug("User message (tool result) received")
if hasattr(message, 'content'):
content = message.content
if isinstance(content, list):
for content_item in content:
if hasattr(content_item, '__class__'):
item_type = content_item.__class__.__name__
if item_type == 'ToolResultBlock':
tool_use_id = getattr(content_item, 'tool_use_id', 'unknown')
result_content = getattr(content_item, 'content', None)
is_error = getattr(content_item, 'is_error', False)
self._console.print_separator()
self._console.print_status("TOOL RESULT", style="yellow bold")
self._console.print_info(f"For Tool ID: {tool_use_id[:12]}...")
if is_error:
self._console.print_error("Status: ERROR")
else:
self._console.print_success("Status: Success")
if result_content:
self._console.print_info("Output:")
if isinstance(result_content, str):
if len(result_content) > 500:
self._console.print_message(f" {result_content[:497]}...")
else:
self._console.print_message(f" {result_content}")
elif isinstance(result_content, list):
for item in result_content[:3]:
self._console.print_info(f" - {item}")
if len(result_content) > 3:
self._console.print_info(f" ... and {len(result_content) - 3} more items")
elif msg_type == 'ToolResultMessage':
if self.verbose:
logger.debug("Tool result message received")
self._console.print_separator()
self._console.print_status("TOOL RESULT MESSAGE", style="yellow bold")
if hasattr(message, 'tool_use_id'):
self._console.print_info(f"Tool ID: {message.tool_use_id[:12]}...")
if hasattr(message, 'content'):
content = message.content
if content:
self._console.print_info("Content:")
if isinstance(content, str):
if len(content) > 500:
self._console.print_message(f" {content[:497]}...")
else:
self._console.print_message(f" {content}")
elif isinstance(content, list):
for item in content[:3]:
self._console.print_info(f" - {item}")
if len(content) > 3:
self._console.print_info(f" ... and {len(content) - 3} more items")
if hasattr(message, 'is_error') and message.is_error:
self._console.print_error("Error: True")
elif hasattr(message, 'text'):
chunk_text = message.text
output_chunks.append(chunk_text)
if self.verbose:
self._console.print_message(chunk_text)
logger.debug(f"Received text chunk {chunk_count}: {len(chunk_text)} characters")
elif isinstance(message, str):
output_chunks.append(message)
if self.verbose:
self._console.print_message(message)
logger.debug(f"Received string chunk {chunk_count}: {len(message)} characters")
else:
if self.verbose:
logger.debug(f"Unknown message type {msg_type}: {message}")
# Combine output
output = ''.join(output_chunks)
# End streaming section if verbose
if self.verbose:
self._console.print_separator()
# Always log the output we're about to return
logger.debug(f"Claude adapter returning {len(output)} characters of output")
if output:
logger.debug(f"Output preview: {output[:200]}...")
# Calculate cost if we have token count (using model-specific pricing)
cost = self._calculate_cost(tokens_used, model) if tokens_used > 0 else None
# Log response details if verbose
if self.verbose:
logger.debug("Claude SDK Response:")
logger.debug(f" Output length: {len(output)} characters")
logger.debug(f" Chunks received: {chunk_count}")
if tokens_used > 0:
logger.debug(f" Tokens used: {tokens_used}")
if cost:
logger.debug(f" Estimated cost: ${cost:.4f}")
logger.debug(f"Response preview: {output[:500]}..." if len(output) > 500 else f"Response: {output}")
return ToolResponse(
success=True,
output=output,
tokens_used=tokens_used if tokens_used > 0 else None,
cost=cost,
metadata={"model": model}
)
except asyncio.TimeoutError as e:
# Use error formatter for user-friendly timeout message
error_msg = ClaudeErrorFormatter.format_error_from_exception(
iteration=kwargs.get('iteration', 0),
exception=e
)
logger.warning(f"Claude SDK request timed out: {error_msg.message}")
return ToolResponse(
success=False,
output="",
error=str(error_msg)
)
except Exception as e:
# Check if this is a user-initiated cancellation (SIGINT = exit code -2)
error_str = str(e)
if "exit code -2" in error_str or "exit code: -2" in error_str:
logger.debug("Claude execution cancelled by user")
return ToolResponse(
success=False,
output="",
error="Execution cancelled by user"
)
# Use error formatter for user-friendly error messages
error_msg = ClaudeErrorFormatter.format_error_from_exception(
iteration=kwargs.get('iteration', 0),
exception=e
)
# Log additional debugging information for "Command failed" errors
if "Command failed with exit code 1" in error_str:
logger.error(f"Claude CLI command failed - this may indicate:")
logger.error(" 1. Claude CLI not properly installed or configured")
logger.error(" 2. Missing API key or authentication issues")
logger.error(" 3. Network connectivity problems")
logger.error(" 4. Insufficient permissions")
logger.error(f" Full error: {error_str}")
# Try to provide more specific guidance
try:
import subprocess
result = subprocess.run(['claude', '--version'], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
logger.error(f" Claude CLI version: {result.stdout.strip()}")
else:
logger.error(f" Claude CLI check failed: {result.stderr}")
except Exception as check_e:
logger.error(f" Could not check Claude CLI: {check_e}")
logger.error(f"Claude SDK error: {error_msg.message}", exc_info=True)
return ToolResponse(
success=False,
output="",
error=str(error_msg)
)
def _calculate_cost(self, tokens: Optional[int], model: str = None) -> Optional[float]:
    """Estimate the USD cost of a request from its combined token count.

    Only a single total is available here, so it is divided with a fixed
    30% input / 70% output heuristic before the per-million-token rates
    from ``MODEL_PRICING`` are applied.

    Args:
        tokens: Total tokens used (input + output combined).
        model: Model ID whose pricing applies; the adapter's configured
            model is used when this is omitted.

    Returns:
        Estimated cost in USD, or None when ``tokens`` is None or 0.
    """
    if not tokens:
        return None

    model_id = model or self._model
    price_table = self.MODEL_PRICING

    # Models missing from the table are priced like the default model.
    rate_key = model_id if model_id in price_table else self.DEFAULT_MODEL
    rates = price_table[rate_key]

    # Approximate split for agent work; each share is truncated to whole
    # tokens because separate input/output counts are not always reported.
    prompt_share = int(tokens * 0.3)
    completion_share = int(tokens * 0.7)

    return (prompt_share / 1_000_000) * rates["input"] + (completion_share / 1_000_000) * rates["output"]
def estimate_cost(self, prompt: str, model: str = None) -> float:
    """Give a rough USD cost estimate for sending ``prompt``.

    Args:
        prompt: The prompt text to estimate cost for.
        model: Model ID to use for pricing (defaults to configured model).

    Returns:
        Estimated cost in USD; 0.0 when no cost could be computed.
    """
    # Heuristic: about four characters per token.
    approx_tokens = int(len(prompt) / 4)
    estimate = self._calculate_cost(approx_tokens, model)
    return estimate or 0.0
def kill_subprocess_sync(self) -> None:
    """
    Stop the tracked subprocess synchronously.

    Sends SIGTERM through os.kill(), pauses for 10 ms, then sends SIGKILL.
    Nothing is awaited, so the method is intended to be callable from a
    signal handler. Any OSError from os.kill() (process already gone, not
    permitted, ...) is ignored, and the stored PID is always cleared once
    an attempt has been made. Does nothing when no PID is recorded.
    """
    if not self._subprocess_pid:
        return

    try:
        # Polite request first.
        os.kill(self._subprocess_pid, signal.SIGTERM)

        # Keep the grace period tiny; this may run inside a signal handler.
        import time
        try:
            time.sleep(0.01)
        except Exception:
            pass  # A failed sleep must not prevent the forced kill below

        # Forced kill; raises ProcessLookupError if SIGTERM already worked.
        os.kill(self._subprocess_pid, signal.SIGKILL)
    except OSError:
        # ProcessLookupError / PermissionError are OSError subclasses:
        # best effort only, the process may be gone or owned by someone else.
        pass
    finally:
        self._subprocess_pid = None
async def _cleanup_transport(self) -> None:
    """Kill the tracked subprocess from async code without risking a hang.

    Mirrors the synchronous kill path: SIGTERM, a 10 ms pause (itself capped
    at 50 ms), then SIGKILL. OSError from os.kill() is ignored and the stored
    PID is cleared afterwards. Does nothing when no PID is recorded, e.g.
    because the signal handler already killed the process.
    """
    if not self._subprocess_pid:
        return

    try:
        os.kill(self._subprocess_pid, signal.SIGTERM)

        # Bounded pause so cleanup can never stall on the sleep itself.
        try:
            await asyncio.wait_for(asyncio.sleep(0.01), timeout=0.05)
        except asyncio.TimeoutError:
            pass  # Carry on to the forced kill regardless

        # Raises ProcessLookupError when the process has already exited.
        os.kill(self._subprocess_pid, signal.SIGKILL)
    except OSError:
        # Covers ProcessLookupError and PermissionError: best-effort cleanup.
        pass
    finally:
        self._subprocess_pid = None

View File

@@ -0,0 +1,120 @@
# ABOUTME: Gemini CLI adapter implementation
# ABOUTME: Provides fallback integration with Google's Gemini AI
"""Gemini CLI adapter for Ralph Orchestrator."""
import subprocess
from typing import Optional
from .base import ToolAdapter, ToolResponse
class GeminiAdapter(ToolAdapter):
    """Adapter for Gemini CLI tool.

    Each request is a single blocking ``gemini -p <prompt>`` subprocess call
    whose stdout becomes the response output.
    """

    # Model name reported in response metadata when the caller chose none.
    DEFAULT_MODEL = "gemini-2.5-pro"

    def __init__(self):
        # Set before super().__init__(): presumably the base initializer
        # probes availability, which reads self.command - TODO confirm
        # against ToolAdapter.
        self.command = "gemini"
        super().__init__("gemini")

    def check_availability(self) -> bool:
        """Check if Gemini CLI is available.

        Returns:
            True when ``<command> --version`` exits with status 0. False when
            it exits non-zero, takes longer than 5 seconds, or cannot be
            started at all.
        """
        try:
            result = subprocess.run(
                [self.command, "--version"],
                capture_output=True,
                timeout=5,
                text=True
            )
        except (subprocess.TimeoutExpired, OSError):
            # OSError covers FileNotFoundError (not installed) as well as
            # PermissionError and similar (present but not runnable).
            return False
        return result.returncode == 0

    def execute(self, prompt: str, **kwargs) -> ToolResponse:
        """Execute Gemini with the given prompt.

        Args:
            prompt: Task text; it is passed through
                ``_enhance_prompt_with_instructions`` before being sent.
            **kwargs: Optional settings read by this method:
                model: value forwarded as ``--model``.
                output_format: value forwarded as ``--output``.
                timeout: seconds to wait for the CLI (default 300).

        Returns:
            A ToolResponse. ``success`` is True only when the CLI exits with
            status 0; timeouts and unexpected exceptions are reported through
            ``error`` rather than raised.
        """
        if not self.available:
            return ToolResponse(
                success=False,
                output="",
                error="Gemini CLI is not available"
            )

        model = kwargs.get("model")

        try:
            # Enhance prompt with orchestration instructions
            enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

            # Build command
            cmd = [self.command]
            if model:
                cmd.extend(["--model", model])
            cmd.extend(["-p", enhanced_prompt])
            if kwargs.get("output_format"):
                cmd.extend(["--output", kwargs["output_format"]])

            # Execute command
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=kwargs.get("timeout", 300)  # 5 minute default
            )

            if result.returncode != 0:
                return ToolResponse(
                    success=False,
                    output=result.stdout,
                    error=result.stderr or "Gemini command failed"
                )

            # Extract token count if available
            tokens = self._extract_token_count(result.stderr)
            return ToolResponse(
                success=True,
                output=result.stdout,
                tokens_used=tokens,
                cost=self._calculate_cost(tokens),
                # Fall back whenever no model flag was sent, including an
                # explicit model=None, so metadata matches the command run.
                metadata={"model": model or self.DEFAULT_MODEL}
            )

        except subprocess.TimeoutExpired:
            return ToolResponse(
                success=False,
                output="",
                error="Gemini command timed out"
            )
        except Exception as e:
            return ToolResponse(
                success=False,
                output="",
                error=str(e)
            )

    def _extract_token_count(self, stderr: str) -> Optional[int]:
        """Extract token count from Gemini output.

        Currently a placeholder: parsing depends on Gemini's output format
        and is not implemented, so this always returns None.
        """
        return None

    def _calculate_cost(self, tokens: Optional[int]) -> Optional[float]:
        """Calculate estimated cost based on tokens.

        Args:
            tokens: Total tokens used, or None when unknown.

        Returns:
            None when ``tokens`` is None or 0; 0.0 below one million tokens;
            otherwise the tokens beyond one million charged at 0.001 per
            1,000.
        """
        if not tokens:
            return None

        # Gemini has free tier up to 1M tokens
        if tokens < 1_000_000:
            return 0.0

        # After free tier: $0.001 per 1K tokens (approximate)
        excess_tokens = tokens - 1_000_000
        cost_per_1k = 0.001
        return (excess_tokens / 1000) * cost_per_1k

    def estimate_cost(self, prompt: str) -> float:
        """Estimate cost for the prompt.

        Uses the rough rule of one token per four characters.
        """
        # Truncate to int so _calculate_cost receives the whole-token count
        # its signature declares rather than a fractional value.
        estimated_tokens = int(len(prompt) / 4)
        return self._calculate_cost(estimated_tokens) or 0.0

View File

@@ -0,0 +1,562 @@
# ABOUTME: Kiro CLI adapter implementation
# ABOUTME: Provides integration with kiro-cli chat command for AI interactions
"""Kiro CLI adapter for Ralph Orchestrator."""
import subprocess
import os
import sys
import signal
import threading
import asyncio
import time
try:
import fcntl # Unix-only
except ModuleNotFoundError:
fcntl = None
from .base import ToolAdapter, ToolResponse
from ..logging_config import RalphLogger
# Get logger for this module
logger = RalphLogger.get_logger(RalphLogger.ADAPTER_KIRO)
class KiroAdapter(ToolAdapter):
    """Adapter for Kiro CLI tool.

    Runs ``<command> chat`` as a child process, collects its output and
    forwards SIGINT/SIGTERM to it. Behaviour is configured through the
    ``RALPH_KIRO_*`` environment variables read in ``__init__``.
    """

    # Seconds between "still running" progress reports in execute().
    _PROGRESS_INTERVAL = 30
    # Seconds of silence after which a progress report flags a possible hang.
    _STALL_THRESHOLD = 60

    def __init__(self):
        # Get configuration from environment variables
        self.command = os.getenv("RALPH_KIRO_COMMAND", "kiro-cli")
        self.default_timeout = int(os.getenv("RALPH_KIRO_TIMEOUT", "600"))
        self.default_prompt_file = os.getenv("RALPH_KIRO_PROMPT_FILE", "PROMPT.md")
        self.trust_all_tools = os.getenv("RALPH_KIRO_TRUST_TOOLS", "true").lower() == "true"
        self.no_interactive = os.getenv("RALPH_KIRO_NO_INTERACTIVE", "true").lower() == "true"

        # Initialize signal handler attributes before calling super()
        self._original_sigint = None
        self._original_sigterm = None

        super().__init__("kiro")

        self.current_process = None
        self.shutdown_requested = False

        # Reentrant on purpose: the signal handler runs on the main thread
        # and takes this lock, possibly while the interrupted code (or an
        # interrupted handler for the other signal) already holds it. A
        # plain Lock would deadlock in that situation.
        self._lock = threading.RLock()

        # Register signal handlers to propagate shutdown to subprocess
        self._register_signal_handlers()

        logger.info(f"Kiro adapter initialized - Command: {self.command}, "
                    f"Default timeout: {self.default_timeout}s, "
                    f"Trust tools: {self.trust_all_tools}")

    def _register_signal_handlers(self):
        """Register signal handlers and store originals.

        signal.signal() raises ValueError outside the main thread; in that
        case no handlers are installed and the adapter still works, just
        without signal forwarding.
        """
        try:
            self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
            self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
        except ValueError as e:
            logger.debug(f"Signal handlers not registered: {e}")
            return
        logger.debug("Signal handlers registered for SIGINT and SIGTERM")

    def _restore_signal_handlers(self):
        """Restore original signal handlers (best effort).

        Deliberately silent: this runs from __del__, which may execute on a
        non-main thread (where signal.signal() raises ValueError) or while
        the interpreter is shutting down.
        """
        try:
            if getattr(self, "_original_sigint", None) is not None:
                signal.signal(signal.SIGINT, self._original_sigint)
            if getattr(self, "_original_sigterm", None) is not None:
                signal.signal(signal.SIGTERM, self._original_sigterm)
        except ValueError:
            pass

    def _signal_handler(self, signum, frame):
        """Handle shutdown signals and terminate running subprocess.

        Sets ``shutdown_requested`` and stops whichever child is currently
        tracked. ``current_process`` may be a subprocess.Popen (from
        execute) or an asyncio subprocess (from aexecute); the latter has no
        poll()/blocking wait(), so it is only asked to terminate.
        """
        with self._lock:
            self.shutdown_requested = True
            process = self.current_process

        if not process:
            return

        if not hasattr(process, "poll"):
            # asyncio subprocess: returncode stays None while it is running.
            if process.returncode is None:
                logger.warning(f"Received signal {signum}, terminating Kiro process...")
                try:
                    process.terminate()
                except ProcessLookupError:
                    pass  # Exited between the check and the signal
            return

        if process.poll() is None:
            logger.warning(f"Received signal {signum}, terminating Kiro process...")
            try:
                process.terminate()
                process.wait(timeout=3)
                logger.debug("Process terminated gracefully")
            except subprocess.TimeoutExpired:
                logger.warning("Force killing Kiro process...")
                process.kill()
                try:
                    process.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    logger.warning("Process may still be running after force kill")

    def check_availability(self) -> bool:
        """Check if kiro-cli (or q) is available.

        Side effect: when the default ``kiro-cli`` is missing but ``q``
        exists, ``self.command`` is switched to ``q``.
        """
        # First check for configured command (defaults to kiro-cli)
        if self._check_command(self.command):
            logger.debug(f"Kiro command '{self.command}' available")
            return True

        # If kiro-cli not found, try fallback to 'q'
        if self.command == "kiro-cli":
            if self._check_command("q"):
                logger.info("kiro-cli not found, falling back to 'q'")
                self.command = "q"
                return True

        logger.warning(f"Kiro command '{self.command}' (and fallback) not found")
        return False

    def _check_command(self, cmd: str) -> bool:
        """Check if a specific command exists (via ``which``)."""
        try:
            result = subprocess.run(
                ["which", cmd],
                capture_output=True,
                timeout=5,
                text=True
            )
            return result.returncode == 0
        except (subprocess.TimeoutExpired, FileNotFoundError):
            return False

    def _build_effective_prompt(self, prompt: str, prompt_file: str) -> str:
        """Wrap the enhanced prompt in instructions to edit ``prompt_file``."""
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
        return (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

    def _build_command(self, effective_prompt: str) -> list:
        """Build the chat argv, honouring the configured CLI flags."""
        cmd = [self.command, "chat"]
        if self.no_interactive:
            cmd.append("--no-interactive")
        if self.trust_all_tools:
            cmd.append("--trust-all-tools")
        cmd.append(effective_prompt)
        return cmd

    def _terminate_process(self, process, verbose: bool = False) -> None:
        """Stop a Popen child: terminate, wait 3 s, then kill and wait 2 s.

        Never raises for a child that refuses to die or is already gone; the
        outcome is only logged.
        """
        try:
            process.terminate()
            process.wait(timeout=3)
            return
        except subprocess.TimeoutExpired:
            logger.warning("Graceful termination failed, force killing process")
            if verbose:
                print("Graceful termination failed, force killing process", file=sys.stderr)
        except OSError as e:
            logger.debug(f"Could not terminate process: {e}")
            return

        try:
            process.kill()
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            logger.warning("Process may still be running after kill")
            if verbose:
                print("Warning: Process may still be running after kill", file=sys.stderr)
        except OSError as e:
            logger.debug(f"Could not kill process: {e}")

    def execute(self, prompt: str, **kwargs) -> ToolResponse:
        """Execute kiro-cli chat with the given prompt.

        Args:
            prompt: Task text to send.
            **kwargs: Optional settings read by this method:
                verbose: echo progress and child output to stderr
                    (default True).
                prompt_file: file the CLI is told to edit (default
                    ``self.default_prompt_file``).
                timeout: seconds before the child is stopped (default
                    ``self.default_timeout``).

        Returns:
            A ToolResponse. ``success`` is True only when the child exits
            with status 0. Shutdown requests, timeouts and unexpected
            exceptions are reported through ``error`` rather than raised.
        """
        if not self.available:
            return ToolResponse(
                success=False,
                output="",
                error="Kiro CLI is not available"
            )

        # Resolved before the try block so the error handler can rely on it.
        verbose = kwargs.get("verbose", True)
        process = None

        try:
            # Get the prompt file path from kwargs if available
            prompt_file = kwargs.get("prompt_file", self.default_prompt_file)

            logger.info(f"Executing Kiro chat - Prompt file: {prompt_file}, Verbose: {verbose}")

            # Tell the CLI explicitly to edit the prompt file
            effective_prompt = self._build_effective_prompt(prompt, prompt_file)
            cmd = self._build_command(effective_prompt)

            logger.debug(f"Command constructed: {' '.join(cmd)}")

            timeout = kwargs.get("timeout", self.default_timeout)

            if verbose:
                logger.info(f"Starting {self.command} chat command...")
                logger.info(f"Command: {' '.join(cmd)}")
                logger.info(f"Working directory: {os.getcwd()}")
                logger.info(f"Timeout: {timeout} seconds")
                print("-" * 60, file=sys.stderr)

            # Use Popen for real-time output streaming
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                cwd=os.getcwd(),
                bufsize=0,  # Unbuffered to prevent deadlock
                universal_newlines=True
            )

            # Set process reference with lock
            with self._lock:
                self.current_process = process

            # Make pipes non-blocking to prevent deadlock
            self._make_non_blocking(process.stdout)
            self._make_non_blocking(process.stderr)

            # Collect output while streaming
            stdout_lines = []
            stderr_lines = []

            start_time = time.time()
            last_output_time = start_time
            # Elapsed-time mark for the next progress report, so each
            # interval is reported once instead of on every loop pass.
            next_progress_report = self._PROGRESS_INTERVAL

            while True:
                # Check for shutdown signal first with lock
                with self._lock:
                    shutdown = self.shutdown_requested

                if shutdown:
                    if verbose:
                        print("Shutdown requested, terminating process...", file=sys.stderr)
                    self._terminate_process(process, verbose)
                    return ToolResponse(
                        success=False,
                        output="".join(stdout_lines),
                        error="Process terminated due to shutdown signal"
                    )

                elapsed_time = time.time() - start_time

                # Periodic progress report
                if elapsed_time >= next_progress_report:
                    next_progress_report = (
                        int(elapsed_time // self._PROGRESS_INTERVAL) + 1
                    ) * self._PROGRESS_INTERVAL
                    logger.debug(f"Kiro still running... elapsed: {elapsed_time:.1f}s / {timeout}s")

                    # Check if the process seems stuck (no output for a while)
                    time_since_output = time.time() - last_output_time
                    if time_since_output > self._STALL_THRESHOLD:
                        logger.info(f"No output received for {time_since_output:.1f}s, Kiro might be stuck")

                    if verbose:
                        print(f"Kiro still running... elapsed: {elapsed_time:.1f}s / {timeout}s", file=sys.stderr)

                # Check for timeout
                if elapsed_time > timeout:
                    logger.warning(f"Command timed out after {elapsed_time:.2f} seconds")
                    if verbose:
                        print(f"Command timed out after {elapsed_time:.2f} seconds", file=sys.stderr)

                    self._terminate_process(process, verbose)

                    # Try to capture any remaining output after termination
                    try:
                        remaining_stdout = process.stdout.read()
                        remaining_stderr = process.stderr.read()
                        if remaining_stdout:
                            stdout_lines.append(remaining_stdout)
                        if remaining_stderr:
                            stderr_lines.append(remaining_stderr)
                    except Exception as e:
                        logger.warning(f"Could not read remaining output after timeout: {e}")
                        if verbose:
                            print(f"Warning: Could not read remaining output after timeout: {e}", file=sys.stderr)

                    return ToolResponse(
                        success=False,
                        output="".join(stdout_lines),
                        error=f"Kiro command timed out after {elapsed_time:.2f} seconds"
                    )

                # Check if process is still running
                if process.poll() is not None:
                    # Process finished, read remaining output
                    remaining_stdout = process.stdout.read()
                    remaining_stderr = process.stderr.read()

                    if remaining_stdout:
                        stdout_lines.append(remaining_stdout)
                        if verbose:
                            print(f"{remaining_stdout}", end="", file=sys.stderr)

                    if remaining_stderr:
                        stderr_lines.append(remaining_stderr)
                        if verbose:
                            print(f"{remaining_stderr}", end="", file=sys.stderr)

                    break

                # Read available data without blocking
                try:
                    stdout_data = self._read_available(process.stdout)
                    if stdout_data:
                        stdout_lines.append(stdout_data)
                        last_output_time = time.time()
                        if verbose:
                            print(stdout_data, end="", file=sys.stderr)

                    stderr_data = self._read_available(process.stderr)
                    if stderr_data:
                        stderr_lines.append(stderr_data)
                        last_output_time = time.time()
                        if verbose:
                            print(stderr_data, end="", file=sys.stderr)
                except BlockingIOError:
                    # Nothing could be transferred right now; retry next pass
                    pass

                # Small sleep to prevent busy waiting; kept outside the try
                # so the BlockingIOError path is throttled as well.
                time.sleep(0.01)

            # Get final return code
            returncode = process.poll()

            execution_time = time.time() - start_time
            logger.info(f"Process completed - Return code: {returncode}, Execution time: {execution_time:.2f}s")

            if verbose:
                print("-" * 60, file=sys.stderr)
                print(f"Process completed with return code: {returncode}", file=sys.stderr)
                print(f"Total execution time: {execution_time:.2f} seconds", file=sys.stderr)

            # Combine output
            full_stdout = "".join(stdout_lines)
            full_stderr = "".join(stderr_lines)

            if returncode == 0:
                logger.debug(f"Kiro chat succeeded - Output length: {len(full_stdout)} chars")
                return ToolResponse(
                    success=True,
                    output=full_stdout,
                    metadata={
                        "tool": "kiro chat",
                        "execution_time": execution_time,
                        "verbose": verbose,
                        "return_code": returncode
                    }
                )

            logger.warning(f"Kiro chat failed - Return code: {returncode}, Error: {full_stderr[:200]}")
            return ToolResponse(
                success=False,
                output=full_stdout,
                error=full_stderr or f"Kiro command failed with code {returncode}"
            )

        except Exception as e:
            logger.exception(f"Exception during Kiro chat execution: {str(e)}")
            if verbose:
                print(f"Exception occurred: {str(e)}", file=sys.stderr)

            # Do not leave an orphaned child behind once its reference is
            # dropped below.
            if process is not None and process.poll() is None:
                self._terminate_process(process, verbose)

            return ToolResponse(
                success=False,
                output="",
                error=str(e)
            )
        finally:
            # Clean up process reference on every exit path
            with self._lock:
                self.current_process = None

    def _make_non_blocking(self, pipe):
        """Make a pipe non-blocking to prevent deadlock.

        No-op when ``pipe`` is falsy or fcntl is unavailable (the import is
        optional at module level).
        """
        if not pipe or fcntl is None:
            return

        try:
            fd = pipe.fileno()
            if isinstance(fd, int) and fd >= 0:
                flags = fcntl.fcntl(fd, fcntl.F_GETFL)
                fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
        except (AttributeError, ValueError, OSError):
            # In tests or when pipe doesn't support fileno()
            pass

    def _read_available(self, pipe):
        """Read available data from a non-blocking pipe.

        Returns:
            Up to 4096 characters, or "" when there is no pipe, no data, or
            the read fails with an OSError.
        """
        if not pipe:
            return ""

        try:
            # read() may yield None or "" when nothing is ready; normalise
            # both to an empty string.
            return pipe.read(4096) or ""
        except OSError:
            # Would block or no data available
            return ""

    async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
        """Native async execution using asyncio subprocess.

        Accepts the same ``verbose``, ``prompt_file`` and ``timeout`` kwargs
        as execute() and builds the same command line. Output is collected
        in one piece after the child exits rather than streamed.
        """
        if not self.available:
            return ToolResponse(
                success=False,
                output="",
                error="Kiro CLI is not available"
            )

        # Resolved before the try block so the error handler uses the same
        # default as the rest of the method.
        verbose = kwargs.get("verbose", True)

        try:
            prompt_file = kwargs.get("prompt_file", self.default_prompt_file)
            timeout = kwargs.get("timeout", self.default_timeout)

            logger.info(f"Executing Kiro chat async - Prompt file: {prompt_file}, Timeout: {timeout}s")

            # Same prompt and flags as the synchronous path, so the
            # RALPH_KIRO_NO_INTERACTIVE / RALPH_KIRO_TRUST_TOOLS settings
            # are honoured here too.
            effective_prompt = self._build_effective_prompt(prompt, prompt_file)
            cmd = self._build_command(effective_prompt)

            logger.debug(f"Starting async Kiro chat command: {' '.join(cmd)}")
            if verbose:
                print(f"Starting {self.command} chat command (async)...", file=sys.stderr)
                print(f"Command: {' '.join(cmd)}", file=sys.stderr)
                print("-" * 60, file=sys.stderr)

            # Create async subprocess
            process = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=os.getcwd()
            )

            # Set process reference with lock
            with self._lock:
                self.current_process = process

            try:
                # Wait for completion with timeout
                stdout_data, stderr_data = await asyncio.wait_for(
                    process.communicate(),
                    timeout=timeout
                )

                # Decode leniently: one invalid byte in tool output should
                # not turn a finished run into a failure.
                stdout = stdout_data.decode("utf-8", errors="replace") if stdout_data else ""
                stderr = stderr_data.decode("utf-8", errors="replace") if stderr_data else ""

                if verbose and stdout:
                    print(stdout, file=sys.stderr)
                if verbose and stderr:
                    print(stderr, file=sys.stderr)

                # Check return code
                if process.returncode == 0:
                    logger.debug(f"Async Kiro chat succeeded - Output length: {len(stdout)} chars")
                    return ToolResponse(
                        success=True,
                        output=stdout,
                        metadata={
                            "tool": "kiro chat",
                            "verbose": verbose,
                            "async": True,
                            "return_code": process.returncode
                        }
                    )

                logger.warning(f"Async Kiro chat failed - Return code: {process.returncode}")
                return ToolResponse(
                    success=False,
                    output=stdout,
                    error=stderr or f"Kiro chat failed with code {process.returncode}"
                )

            except asyncio.TimeoutError:
                # Timeout occurred
                logger.warning(f"Async Kiro chat timed out after {timeout} seconds")
                if verbose:
                    print(f"Async Kiro chat timed out after {timeout} seconds", file=sys.stderr)

                # Try to terminate process
                try:
                    process.terminate()
                    await asyncio.wait_for(process.wait(), timeout=3)
                except (asyncio.TimeoutError, ProcessLookupError):
                    try:
                        process.kill()
                        await process.wait()
                    except ProcessLookupError:
                        pass

                return ToolResponse(
                    success=False,
                    output="",
                    error=f"Kiro command timed out after {timeout} seconds"
                )

            finally:
                # Clean up process reference
                with self._lock:
                    self.current_process = None

        except Exception as e:
            logger.exception(f"Async execution error: {str(e)}")
            if verbose:
                print(f"Async execution error: {str(e)}", file=sys.stderr)
            return ToolResponse(
                success=False,
                output="",
                error=str(e)
            )

    def estimate_cost(self, prompt: str) -> float:
        """Kiro chat cost estimation (if applicable)."""
        # Return 0 for now
        return 0.0

    def __del__(self):
        """Cleanup on deletion."""
        # Restore original signal handlers
        self._restore_signal_handlers()

        # __init__ may have failed part-way, so every attribute is optional.
        lock = getattr(self, "_lock", None)
        if lock is not None:
            with lock:
                process = getattr(self, "current_process", None)
        else:
            process = getattr(self, "current_process", None)

        if not process:
            return

        try:
            if hasattr(process, "poll"):
                # Sync process
                if process.poll() is None:
                    process.terminate()
                    process.wait(timeout=1)
            # Async process - can't do much in __del__
        except Exception as e:
            # Best-effort cleanup during interpreter shutdown
            logger.debug(f"Cleanup warning in __del__: {type(e).__name__}: {e}")

View File

@@ -0,0 +1,579 @@
# ABOUTME: Q Chat adapter implementation for q CLI tool (DEPRECATED)
# ABOUTME: Provides integration with q chat command for AI interactions
# ABOUTME: NOTE: Consider using KiroAdapter instead - Q CLI rebranded to Kiro CLI
"""Q Chat adapter for Ralph Orchestrator.
.. deprecated::
The QChatAdapter is deprecated in favor of KiroAdapter.
Amazon Q Developer CLI has been rebranded to Kiro CLI (v1.20+).
Use `-a kiro` instead of `-a qchat` or `-a q` for new projects.
Migration guide:
- Config paths changed: ~/.aws/amazonq/ -> ~/.kiro/
- MCP servers: ~/.kiro/settings/mcp.json
- Project files: .kiro/ folder
"""
import subprocess
import os
import sys
import signal
import threading
import asyncio
import time
import warnings
try:
import fcntl # Unix-only
except ModuleNotFoundError:
fcntl = None
from .base import ToolAdapter, ToolResponse
from ..logging_config import RalphLogger
# Get logger for this module
logger = RalphLogger.get_logger(RalphLogger.ADAPTER_QCHAT)
class QChatAdapter(ToolAdapter):
"""Adapter for Q Chat CLI tool.
.. deprecated::
Use :class:`KiroAdapter` instead. The Amazon Q Developer CLI has been
rebranded to Kiro CLI. This adapter is maintained for backwards
compatibility only.
"""
def __init__(self):
    """Configure the adapter from ``RALPH_QCHAT_*`` environment variables.

    Emits a DeprecationWarning, initialises process-tracking state and
    registers SIGINT/SIGTERM handlers that forward shutdown to the child.
    """
    # Emit deprecation warning
    warnings.warn(
        "QChatAdapter is deprecated. Use KiroAdapter instead. "
        "Amazon Q Developer CLI has been rebranded to Kiro CLI (v1.20+). "
        "Run with '-a kiro' instead of '-a qchat'.",
        DeprecationWarning,
        stacklevel=2
    )

    # Get configuration from environment variables
    self.command = os.getenv("RALPH_QCHAT_COMMAND", "q")
    self.default_timeout = int(os.getenv("RALPH_QCHAT_TIMEOUT", "600"))
    self.default_prompt_file = os.getenv("RALPH_QCHAT_PROMPT_FILE", "PROMPT.md")
    self.trust_all_tools = os.getenv("RALPH_QCHAT_TRUST_TOOLS", "true").lower() == "true"
    self.no_interactive = os.getenv("RALPH_QCHAT_NO_INTERACTIVE", "true").lower() == "true"

    # Initialize signal handler attributes before calling super()
    self._original_sigint = None
    self._original_sigterm = None

    super().__init__("qchat")

    self.current_process = None
    self.shutdown_requested = False

    # Reentrant on purpose: the signal handler takes this lock on the main
    # thread and is registered for two signals, so it can run while the
    # interrupted code (or an interrupted handler) already holds the lock.
    # A plain Lock would deadlock in that situation.
    self._lock = threading.RLock()

    # Register signal handlers to propagate shutdown to subprocess
    self._register_signal_handlers()

    logger.info(f"Q Chat adapter initialized - Command: {self.command}, "
                f"Default timeout: {self.default_timeout}s, "
                f"Trust tools: {self.trust_all_tools}")
def _register_signal_handlers(self):
    """Register signal handlers and store originals.

    signal.signal() raises ValueError when called outside the main thread.
    In that case no handlers are installed and the stored originals keep
    their current values, so construction on a worker thread still succeeds
    (without signal forwarding).
    """
    try:
        self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
        self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
    except ValueError as e:
        logger.debug(f"Signal handlers not registered: {e}")
        return
    logger.debug("Signal handlers registered for SIGINT and SIGTERM")
def _restore_signal_handlers(self):
    """Restore original signal handlers (best effort).

    Handlers that were never recorded (attribute missing or None) are left
    alone. signal.signal() raises ValueError outside the main thread; that
    is ignored so cleanup code can call this from any thread without
    failing.
    """
    try:
        if getattr(self, '_original_sigint', None) is not None:
            signal.signal(signal.SIGINT, self._original_sigint)
        if getattr(self, '_original_sigterm', None) is not None:
            signal.signal(signal.SIGTERM, self._original_sigterm)
    except ValueError:
        # Not on the main thread: nothing can be restored from here.
        pass
def _signal_handler(self, signum, frame):
"""Handle shutdown signals and terminate running subprocess."""
with self._lock:
self.shutdown_requested = True
process = self.current_process
if process and process.poll() is None:
logger.warning(f"Received signal {signum}, terminating q chat process...")
try:
process.terminate()
process.wait(timeout=3)
logger.debug("Process terminated gracefully")
except subprocess.TimeoutExpired:
logger.warning("Force killing q chat process...")
process.kill()
try:
process.wait(timeout=2)
except subprocess.TimeoutExpired:
logger.warning("Process may still be running after force kill")
def check_availability(self) -> bool:
    """Check if the q CLI is available.

    Resolves ``self.command`` with ``shutil.which`` rather than spawning the
    external ``which`` program, so the check also works on systems that do
    not ship that binary and does not start a subprocess.

    Returns:
        True when ``self.command`` resolves to an executable, else False.
    """
    # Imported locally so this method does not depend on a module-level import.
    import shutil

    try:
        available = shutil.which(self.command) is not None
    except (OSError, TypeError, ValueError) as e:
        logger.warning(f"Q command availability check failed: {e}")
        return False
    logger.debug(f"Q command '{self.command}' availability check: {available}")
    return available
def execute(self, prompt: str, **kwargs) -> ToolResponse:
    """Run ``q chat`` synchronously with the given prompt and collect its output.

    The prompt is passed through ``_enhance_prompt_with_instructions`` and
    wrapped in instructions telling the tool to edit the prompt file in place.
    The child process is polled in a loop so that shutdown requests and the
    timeout are honored while stdout/stderr are collected.

    Args:
        prompt: Task text to send to the tool.
        **kwargs: Optional settings:
            verbose (bool): Echo progress and child output to stderr.
                Defaults to True.
            prompt_file (str): File named in the instructions. Defaults to
                ``self.default_prompt_file``.
            timeout (int | float): Seconds before the child is terminated.
                Defaults to ``self.default_timeout``.

    Returns:
        ToolResponse: ``success=True`` with the collected stdout when the
        child exits with code 0; otherwise ``success=False`` with an error
        describing the failure, timeout, shutdown or exception.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error="q CLI is not available"
        )

    # Resolved before the try block so the exception handler can always read it.
    verbose = kwargs.get('verbose', True)

    try:
        # Get the prompt file path from kwargs if available
        prompt_file = kwargs.get('prompt_file', self.default_prompt_file)
        logger.info(f"Executing Q chat - Prompt file: {prompt_file}, Verbose: {verbose}")

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Construct a more effective prompt for q chat
        # Tell it explicitly to edit the prompt file
        effective_prompt = (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

        # Build command: optional flags first, the prompt as the final argument
        cmd = [self.command, "chat"]
        if self.no_interactive:
            cmd.append("--no-interactive")
        if self.trust_all_tools:
            cmd.append("--trust-all-tools")
        cmd.append(effective_prompt)
        logger.debug(f"Command constructed: {' '.join(cmd)}")

        timeout = kwargs.get("timeout", self.default_timeout)

        if verbose:
            logger.info("Starting q chat command...")
            logger.info(f"Command: {' '.join(cmd)}")
            logger.info(f"Working directory: {os.getcwd()}")
            logger.info(f"Timeout: {timeout} seconds")
            print("-" * 60, file=sys.stderr)

        # Use Popen for real-time output streaming
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=os.getcwd(),
            bufsize=0,  # Unbuffered to prevent deadlock
            universal_newlines=True
        )

        # Set process reference with lock
        with self._lock:
            self.current_process = process

        # Make pipes non-blocking to prevent deadlock
        self._make_non_blocking(process.stdout)
        self._make_non_blocking(process.stderr)

        # Collect output while streaming
        stdout_lines = []
        stderr_lines = []
        start_time = time.time()
        last_output_time = start_time
        # Index of the last 30-second interval that was already reported. The
        # loop spins roughly every 10 ms, so progress is logged once per
        # interval instead of on every pass during the matching second.
        last_progress_bucket = 0

        while True:
            # Check for shutdown signal first with lock
            with self._lock:
                shutdown = self.shutdown_requested

            if shutdown:
                if verbose:
                    print("Shutdown requested, terminating q chat process...", file=sys.stderr)
                process.terminate()
                try:
                    process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait(timeout=2)

                # Clean up process reference with lock
                with self._lock:
                    self.current_process = None

                return ToolResponse(
                    success=False,
                    output="".join(stdout_lines),
                    error="Process terminated due to shutdown signal"
                )

            # Check for timeout
            elapsed_time = time.time() - start_time

            # Log progress once per 30-second interval
            progress_bucket = int(elapsed_time) // 30
            if progress_bucket > last_progress_bucket:
                last_progress_bucket = progress_bucket
                logger.debug(f"Q chat still running... elapsed: {elapsed_time:.1f}s / {timeout}s")

                # Check if the process seems stuck (no output for a while)
                time_since_output = time.time() - last_output_time
                if time_since_output > 60:
                    logger.info(f"No output received for {time_since_output:.1f}s, Q might be stuck")

                if verbose:
                    print(f"Q chat still running... elapsed: {elapsed_time:.1f}s / {timeout}s", file=sys.stderr)

            if elapsed_time > timeout:
                logger.warning(f"Command timed out after {elapsed_time:.2f} seconds")
                if verbose:
                    print(f"Command timed out after {elapsed_time:.2f} seconds", file=sys.stderr)

                # Try to terminate gracefully first
                process.terminate()
                try:
                    # Wait a bit for graceful termination
                    process.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    logger.warning("Graceful termination failed, force killing process")
                    if verbose:
                        print("Graceful termination failed, force killing process", file=sys.stderr)
                    process.kill()
                    # Wait for force kill to complete
                    try:
                        process.wait(timeout=2)
                    except subprocess.TimeoutExpired:
                        logger.warning("Process may still be running after kill")
                        if verbose:
                            print("Warning: Process may still be running after kill", file=sys.stderr)

                # Try to capture any remaining output after termination
                try:
                    remaining_stdout = process.stdout.read()
                    remaining_stderr = process.stderr.read()
                    if remaining_stdout:
                        stdout_lines.append(remaining_stdout)
                    if remaining_stderr:
                        stderr_lines.append(remaining_stderr)
                except Exception as e:
                    logger.warning(f"Could not read remaining output after timeout: {e}")
                    if verbose:
                        print(f"Warning: Could not read remaining output after timeout: {e}", file=sys.stderr)

                # Clean up process reference with lock
                with self._lock:
                    self.current_process = None

                return ToolResponse(
                    success=False,
                    output="".join(stdout_lines),
                    error=f"q chat command timed out after {elapsed_time:.2f} seconds"
                )

            # Check if process is still running
            if process.poll() is not None:
                # Process finished, read remaining output
                remaining_stdout = process.stdout.read()
                remaining_stderr = process.stderr.read()

                if remaining_stdout:
                    stdout_lines.append(remaining_stdout)
                    if verbose:
                        print(f"{remaining_stdout}", end='', file=sys.stderr)

                if remaining_stderr:
                    stderr_lines.append(remaining_stderr)
                    if verbose:
                        print(f"{remaining_stderr}", end='', file=sys.stderr)

                break

            # Read available data without blocking
            try:
                # Read stdout
                stdout_data = self._read_available(process.stdout)
                if stdout_data:
                    stdout_lines.append(stdout_data)
                    last_output_time = time.time()
                    if verbose:
                        print(stdout_data, end='', file=sys.stderr)

                # Read stderr
                stderr_data = self._read_available(process.stderr)
                if stderr_data:
                    stderr_lines.append(stderr_data)
                    last_output_time = time.time()
                    if verbose:
                        print(stderr_data, end='', file=sys.stderr)

                # Small sleep to prevent busy waiting
                time.sleep(0.01)
            except BlockingIOError:
                # Non-blocking read returns BlockingIOError when no data available
                pass

        # Get final return code
        returncode = process.poll()
        execution_time = time.time() - start_time
        logger.info(f"Process completed - Return code: {returncode}, Execution time: {execution_time:.2f}s")

        if verbose:
            print("-" * 60, file=sys.stderr)
            print(f"Process completed with return code: {returncode}", file=sys.stderr)
            print(f"Total execution time: {execution_time:.2f} seconds", file=sys.stderr)

        # Clean up process reference with lock
        with self._lock:
            self.current_process = None

        # Combine output
        full_stdout = "".join(stdout_lines)
        full_stderr = "".join(stderr_lines)

        if returncode == 0:
            logger.debug(f"Q chat succeeded - Output length: {len(full_stdout)} chars")
            return ToolResponse(
                success=True,
                output=full_stdout,
                metadata={
                    "tool": "q chat",
                    "execution_time": execution_time,
                    "verbose": verbose,
                    "return_code": returncode
                }
            )
        else:
            logger.warning(f"Q chat failed - Return code: {returncode}, Error: {full_stderr[:200]}")
            return ToolResponse(
                success=False,
                output=full_stdout,
                error=full_stderr or f"q chat command failed with code {returncode}"
            )

    except Exception as e:
        logger.exception(f"Exception during Q chat execution: {str(e)}")
        if verbose:
            print(f"Exception occurred: {str(e)}", file=sys.stderr)

        # Clean up process reference on exception (matches async version's finally block)
        with self._lock:
            self.current_process = None

        return ToolResponse(
            success=False,
            output="",
            error=str(e)
        )
def _make_non_blocking(self, pipe):
"""Make a pipe non-blocking to prevent deadlock."""
if not pipe or fcntl is None:
return
try:
fd = pipe.fileno()
if isinstance(fd, int) and fd >= 0:
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
except (AttributeError, ValueError, OSError):
# In tests or when pipe doesn't support fileno()
pass
def _read_available(self, pipe):
"""Read available data from a non-blocking pipe."""
if not pipe:
return ""
try:
# Try to read up to 4KB at a time
data = pipe.read(4096)
# Ensure we always return a string, not None
if data is None:
return ""
return data if data else ""
except (IOError, OSError):
# Would block or no data available
return ""
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
    """Run ``q chat`` asynchronously with the given prompt using an asyncio subprocess.

    The command line is built the same way as in ``execute``: the
    ``--no-interactive`` and ``--trust-all-tools`` flags are added only when
    ``self.no_interactive`` / ``self.trust_all_tools`` are set.

    Args:
        prompt: Task text to send to the tool.
        **kwargs: Optional settings:
            verbose (bool): Echo progress and child output to stderr.
                Defaults to True.
            prompt_file (str): File named in the instructions. Defaults to
                ``self.default_prompt_file``.
            timeout (int | float): Seconds to wait for the child before it is
                terminated. Defaults to ``self.default_timeout``.

    Returns:
        ToolResponse: ``success=True`` with the decoded stdout when the child
        exits with code 0; otherwise ``success=False`` with an error
        describing the failure, timeout or exception.
    """
    if not self.available:
        return ToolResponse(
            success=False,
            output="",
            error="q CLI is not available"
        )

    # Resolved once, before the try block, so the exception handler uses the
    # same default as the rest of the method.
    verbose = kwargs.get('verbose', True)

    try:
        prompt_file = kwargs.get('prompt_file', self.default_prompt_file)
        timeout = kwargs.get('timeout', self.default_timeout)
        logger.info(f"Executing Q chat async - Prompt file: {prompt_file}, Timeout: {timeout}s")

        # Enhance prompt with orchestration instructions
        enhanced_prompt = self._enhance_prompt_with_instructions(prompt)

        # Construct effective prompt
        effective_prompt = (
            f"Please read and complete the task described in the file '{prompt_file}'. "
            f"The current content is:\n\n{enhanced_prompt}\n\n"
            f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
        )

        # Build command, honoring the configured flags like execute() does
        cmd = [self.command, "chat"]
        if self.no_interactive:
            cmd.append("--no-interactive")
        if self.trust_all_tools:
            cmd.append("--trust-all-tools")
        cmd.append(effective_prompt)
        logger.debug(f"Starting async Q chat command: {' '.join(cmd)}")

        if verbose:
            print("Starting q chat command (async)...", file=sys.stderr)
            print(f"Command: {' '.join(cmd)}", file=sys.stderr)
            print("-" * 60, file=sys.stderr)

        # Create async subprocess
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=os.getcwd()
        )

        # Set process reference with lock
        with self._lock:
            self.current_process = process

        try:
            # Wait for completion with timeout
            stdout_data, stderr_data = await asyncio.wait_for(
                process.communicate(),
                timeout=timeout
            )

            # Decode output; undecodable bytes are replaced instead of turning
            # a finished run into a decode error.
            stdout = stdout_data.decode('utf-8', errors='replace') if stdout_data else ""
            stderr = stderr_data.decode('utf-8', errors='replace') if stderr_data else ""

            if verbose and stdout:
                print(stdout, file=sys.stderr)
            if verbose and stderr:
                print(stderr, file=sys.stderr)

            # Check return code
            if process.returncode == 0:
                logger.debug(f"Async Q chat succeeded - Output length: {len(stdout)} chars")
                return ToolResponse(
                    success=True,
                    output=stdout,
                    metadata={
                        "tool": "q chat",
                        "verbose": verbose,
                        "async": True,
                        "return_code": process.returncode
                    }
                )
            else:
                logger.warning(f"Async Q chat failed - Return code: {process.returncode}")
                return ToolResponse(
                    success=False,
                    output=stdout,
                    error=stderr or f"q chat failed with code {process.returncode}"
                )

        except asyncio.TimeoutError:
            # Timeout occurred
            logger.warning(f"Async q chat timed out after {timeout} seconds")
            if verbose:
                print(f"Async q chat timed out after {timeout} seconds", file=sys.stderr)

            # Try to terminate process, escalating to kill if it does not exit
            try:
                process.terminate()
                await asyncio.wait_for(process.wait(), timeout=3)
            except (asyncio.TimeoutError, ProcessLookupError):
                try:
                    process.kill()
                    await process.wait()
                except ProcessLookupError:
                    pass

            return ToolResponse(
                success=False,
                output="",
                error=f"q chat command timed out after {timeout} seconds"
            )

        finally:
            # Clean up process reference
            with self._lock:
                self.current_process = None

    except Exception as e:
        logger.exception(f"Async execution error: {str(e)}")
        if verbose:
            print(f"Async execution error: {str(e)}", file=sys.stderr)
        return ToolResponse(
            success=False,
            output="",
            error=str(e)
        )
def estimate_cost(self, prompt: str) -> float:
    """Return the estimated cost of sending *prompt* to Q chat.

    No pricing is modeled for this tool, so the estimate is always ``0.0``
    regardless of the prompt; update this if actual pricing becomes known.
    """
    placeholder_cost = 0.0
    return placeholder_cost
def __del__(self):
"""Cleanup on deletion."""
# Restore original signal handlers
self._restore_signal_handlers()
# Ensure any running process is terminated
if hasattr(self, '_lock'):
with self._lock:
process = self.current_process if hasattr(self, 'current_process') else None
else:
process = getattr(self, 'current_process', None)
if process:
try:
if hasattr(process, 'poll'):
# Sync process
if process.poll() is None:
process.terminate()
process.wait(timeout=1)
else:
# Async process - can't do much in __del__
pass
except Exception as e:
# Best-effort cleanup during interpreter shutdown
# Log at debug level since __del__ is unreliable
logger.debug(f"Cleanup warning in __del__: {type(e).__name__}: {e}")

View File

@@ -0,0 +1,472 @@
# ABOUTME: Advanced async logging with rotation, thread safety, and security features
# ABOUTME: Provides dual interface (async + sync) with unicode sanitization
"""
Advanced async logging with rotation for Ralph Orchestrator.
Features:
- Automatic log rotation at 10MB with 3 backups
- Thread-safe rotation with threading.Lock
- Unicode sanitization for encoding errors
- Security-aware logging (masks sensitive data)
- Dual interface: async methods + sync wrappers
"""
import asyncio
import functools
import shutil
import sys
import threading
import warnings
from datetime import datetime
from pathlib import Path
from typing import Optional
from ralph_orchestrator.security import SecurityValidator
def async_method_warning(func):
    """Decorator to warn when async methods are called without await.

    The decorated method returns an awaitable wrapper around the coroutine
    produced by *func*. If that wrapper is garbage collected without ever
    being awaited, a ``RuntimeWarning`` naming the method is emitted.

    Args:
        func: The ``async def`` method to wrap.

    Returns:
        A synchronous wrapper with the metadata of *func* that returns the
        awaitable wrapper object.
    """

    class WarningCoroutine:
        """Wrapper that warns when garbage collected without being awaited."""

        def __init__(self, coro, method_name):
            self._coro = coro
            self._method_name = method_name
            self._warned = False
            self._awaited = False

        def __await__(self):
            # Record the await so __del__ stays silent, then delegate.
            self._awaited = True
            return self._coro.__await__()

        def __del__(self):
            if not self._warned and not self._awaited:
                # Every segment that interpolates the method name must be an
                # f-string, otherwise the placeholder is emitted literally.
                warnings.warn(
                    f"AsyncFileLogger.{self._method_name}() was called without await. "
                    f"The message was not logged. Use 'await logger.{self._method_name}(...)' "
                    f"or 'logger.{self._method_name}_sync(...)' instead.",
                    RuntimeWarning,
                    stacklevel=3,
                )
                self._warned = True

        def close(self):
            """Support close() method for compatibility."""
            pass

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        coro = func(self, *args, **kwargs)
        return WarningCoroutine(coro, func.__name__)

    return wrapper
class AsyncFileLogger:
"""Async file logger with timestamps, rotation, and security features."""
# Log rotation constants
MAX_LOG_SIZE_BYTES = 10 * 1024 * 1024 # 10MB in bytes
MAX_BACKUP_FILES = 3
# Default values for log parsing
DEFAULT_RECENT_LINES_COUNT = 3
def __init__(self, log_file: str, verbose: bool = False) -> None:
"""
Initialize async logger.
Args:
log_file: Path to log file
verbose: If True, also print to console
Raises:
ValueError: If log_file is None or empty
"""
if not log_file:
raise ValueError("log_file cannot be None or empty")
# Convert to string for validation if it's a Path object
log_file_str = str(log_file)
if not log_file_str or not log_file_str.strip():
raise ValueError("log_file cannot be empty")
self.log_file = Path(log_file)
self.verbose = verbose
self._lock = asyncio.Lock() # For async methods (single event loop)
self._thread_lock = threading.Lock() # For sync methods (multi-threaded)
self._rotation_lock = threading.Lock() # Thread safety for file rotation
# Emergency shutdown flag for graceful signal handling
self._emergency_shutdown = False
# Threading event for immediate signal-safe notification
self._emergency_event = threading.Event()
# Track logging failures when both file and stderr fail
self._logging_failures_count = 0
# Ensure log directory exists
self.log_file.parent.mkdir(parents=True, exist_ok=True)
# Rotate log if needed on startup (single-threaded, safe)
self._rotate_if_needed()
def emergency_shutdown(self) -> None:
"""
Signal emergency shutdown to make logging operations non-blocking.
This method is signal-safe and can be called from signal handlers.
After calling this, all logging operations will be skipped to allow
rapid shutdown without blocking on file I/O.
"""
self._emergency_shutdown = True
self._emergency_event.set()
def is_shutdown(self) -> bool:
"""Check if emergency shutdown has been triggered."""
return self._emergency_shutdown
async def log(self, level: str, message: str) -> None:
"""
Log a message with timestamp.
Args:
level: Log level (INFO, SUCCESS, ERROR, WARNING)
message: Message to log
"""
# Skip logging during emergency shutdown
if self._emergency_shutdown:
return
# Sanitize the message to handle problematic unicode
sanitized_message = self._sanitize_unicode(message)
# Mask sensitive data to prevent security vulnerabilities
secure_message = SecurityValidator.mask_sensitive_data(sanitized_message)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_line = f"{timestamp} [{level}] {secure_message}\n"
async with self._lock:
# Double-check shutdown flag after acquiring lock
if self._emergency_shutdown:
return
# Write to file
await asyncio.to_thread(self._write_to_file, log_line)
# Print to console if verbose
if self.verbose:
print(log_line.rstrip())
def _sanitize_unicode(self, message: str) -> str:
"""
Sanitize unicode message to prevent encoding errors.
Args:
message: Original message
Returns:
Sanitized message safe for UTF-8 encoding
"""
try:
# Test if the message can be encoded as UTF-8
message.encode("utf-8")
return message
except UnicodeEncodeError:
# If encoding fails, replace problematic characters
try:
# Try to encode with errors='replace' first
return message.encode("utf-8", errors="replace").decode("utf-8")
except Exception:
# If that still fails, use more aggressive sanitization
sanitized = []
for char in message:
try:
char.encode("utf-8")
sanitized.append(char)
except UnicodeEncodeError:
# Replace problematic character with a placeholder
sanitized.append("[?]")
return "".join(sanitized)
except Exception:
# For any other unexpected errors, return a safe fallback
return "[Unicode encoding error]"
def _write_to_file(self, line: str) -> None:
"""Synchronous file write (called via to_thread)."""
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(line)
# Check if rotation is needed (thread-safe)
self._rotate_if_needed_thread_safe()
def _rotate_if_needed_thread_safe(self) -> None:
"""Thread-safe version of _rotate_if_needed() to prevent race conditions."""
with self._rotation_lock:
self._rotate_if_needed()
def _rotate_if_needed(self) -> None:
"""
Rotate log file if it exceeds max size.
Note: This method should only be called:
- During __init__ (single-threaded, safe before logger is shared)
- From within _rotate_if_needed_thread_safe() when rotation lock is held
"""
if not self.log_file.exists():
return
# Double-check file size with lock held
try:
file_size = self.log_file.stat().st_size
except (OSError, IOError):
# File might have been moved or deleted by another thread
return
if file_size > self.MAX_LOG_SIZE_BYTES:
# Create a temporary file to ensure atomic rotation
temp_backup = self.log_file.with_suffix(".log.tmp")
try:
# Atomically move current log to temporary backup
shutil.move(str(self.log_file), str(temp_backup))
# Rotate backups in reverse order
for i in range(self.MAX_BACKUP_FILES - 1, 0, -1):
old_backup = self.log_file.with_suffix(f".log.{i}")
new_backup = self.log_file.with_suffix(f".log.{i + 1}")
if old_backup.exists():
if new_backup.exists():
new_backup.unlink()
shutil.move(str(old_backup), str(new_backup))
# Move temporary backup to .1
backup = self.log_file.with_suffix(".log.1")
if backup.exists():
backup.unlink()
shutil.move(str(temp_backup), str(backup))
# Clean up any backups beyond MAX_BACKUP_FILES
i = self.MAX_BACKUP_FILES + 1
while True:
old_backup = self.log_file.with_suffix(f".log.{i}")
if old_backup.exists():
old_backup.unlink()
i += 1
else:
break
except (OSError, IOError):
# If rotation fails, try to restore from temporary backup
if temp_backup.exists() and not self.log_file.exists():
try:
shutil.move(str(temp_backup), str(self.log_file))
except (OSError, IOError):
# If we can't restore, at least remove the temp file
if temp_backup.exists():
temp_backup.unlink()
async def log_info(self, message: str) -> None:
"""Log info message."""
await self.log("INFO", message)
async def log_success(self, message: str) -> None:
"""Log success message."""
await self.log("SUCCESS", message)
async def log_error(self, message: str) -> None:
"""Log error message."""
await self.log("ERROR", message)
async def log_warning(self, message: str) -> None:
"""Log warning message."""
await self.log("WARNING", message)
def __del__(self):
"""Destructor to warn about unretrieved coroutines."""
try:
# Check if Python is shutting down
import sys
if sys.meta_path is None:
# Python is shutting down, skip cleanup to avoid errors
return
except Exception:
# Silently ignore any errors during destructor to avoid crashes
pass
# Synchronous wrapper methods for compatibility
def _log_sync_direct(self, level: str, message: str) -> None:
"""
Log a message synchronously using threading.Lock (thread-safe).
This bypasses asyncio entirely for true multi-threaded safety.
Includes defensive error handling with stderr fallback.
"""
if self._emergency_shutdown:
return
# Sanitize and secure the message
sanitized_message = self._sanitize_unicode(message)
secure_message = SecurityValidator.mask_sensitive_data(sanitized_message)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_line = f"{timestamp} [{level}] {secure_message}\n"
try:
with self._thread_lock:
if self._emergency_shutdown:
return
self._write_to_file(log_line)
if self.verbose:
print(log_line.rstrip())
except (PermissionError, OSError, IOError, FileNotFoundError) as e:
# Fallback to stderr when file I/O fails
# Truncate message to 200 chars for stderr output
truncated_message = secure_message[:200]
try:
print(
f"[LOGGING ERROR] {type(e).__name__}: {e}\n"
f"Original message (truncated): {truncated_message}",
file=sys.stderr
)
except Exception:
# If stderr also fails, track the failure count for diagnostics
# Cannot log or print - just increment counter
self._logging_failures_count += 1
def log_info_sync(self, message: str) -> None:
"""Log info message synchronously (thread-safe)."""
self._log_sync_direct("INFO", message)
def log_success_sync(self, message: str) -> None:
"""Log success message synchronously (thread-safe)."""
self._log_sync_direct("SUCCESS", message)
def log_error_sync(self, message: str) -> None:
"""Log error message synchronously (thread-safe)."""
self._log_sync_direct("ERROR", message)
def log_warning_sync(self, message: str) -> None:
"""Log warning message synchronously (thread-safe)."""
self._log_sync_direct("WARNING", message)
# Standard logging interface methods for compatibility
def info(self, message: str) -> None:
"""Standard logging interface - log info message synchronously."""
self.log_info_sync(message)
def debug(self, message: str) -> None:
"""Standard logging interface - log debug message synchronously (maps to info)."""
self.log_info_sync(message)
def warning(self, message: str) -> None:
"""Standard logging interface - log warning message synchronously."""
self.log_warning_sync(message)
def error(self, message: str) -> None:
"""Standard logging interface - log error message synchronously."""
self.log_error_sync(message)
def critical(self, message: str) -> None:
"""Standard logging interface - log critical message synchronously (maps to error)."""
self.log_error_sync(message)
def get_stats(self) -> dict[str, int | str | None]:
"""
Get statistics from log file.
Returns:
Dict with success_count, error_count, start_time
"""
if not self.log_file.exists():
return {"success_count": 0, "error_count": 0, "start_time": None}
success_count = 0
error_count = 0
start_time = None
with open(self.log_file, encoding="utf-8") as f:
lines = f.readlines()
if lines:
# Extract start time from first line
first_line = lines[0]
start_time = first_line.split(" [")[0] if " [" in first_line else None
# Count successes and errors
for line in lines:
if "Iteration" in line and "completed successfully" in line:
success_count += 1
elif "Iteration" in line and "failed" in line:
error_count += 1
return {
"success_count": success_count,
"error_count": error_count,
"start_time": start_time,
}
def get_recent_lines(self, count: Optional[int] = None) -> list[str]:
"""
Get recent log lines.
Args:
count: Number of recent lines to return (default: DEFAULT_RECENT_LINES_COUNT)
Returns:
List of recent log lines
"""
if count is None:
count = self.DEFAULT_RECENT_LINES_COUNT
if not self.log_file.exists():
return []
with open(self.log_file, encoding="utf-8") as f:
lines = f.readlines()
return [line.rstrip() for line in lines[-count:]]
def count_pattern(self, pattern: str) -> int:
"""
Count occurrences of pattern in log file.
Args:
pattern: Pattern to search for
Returns:
Number of occurrences
"""
if not self.log_file.exists():
return 0
with open(self.log_file, encoding="utf-8") as f:
content = f.read()
return content.count(pattern)
def get_start_time(self) -> Optional[str]:
"""
Get start time from first log entry.
Returns:
Start time string or None if no logs
"""
if not self.log_file.exists():
return None
with open(self.log_file, encoding="utf-8") as f:
first_line = f.readline()
if not first_line:
return None
# Extract timestamp from first line (format: "YYYY-MM-DD HH:MM:SS [LEVEL] message")
parts = first_line.split(" ", 2)
if len(parts) >= 2:
return f"{parts[0]} {parts[1]}"
return None

View File

@@ -0,0 +1,225 @@
# ABOUTME: Context management and optimization for Ralph Orchestrator
# ABOUTME: Handles prompt caching, summarization, and context window management
"""Context management for Ralph Orchestrator."""
from pathlib import Path
from typing import List, Optional, Dict
import hashlib
import logging
logger = logging.getLogger('ralph-orchestrator.context')
class ContextManager:
"""Manage prompt context and optimization."""
def __init__(
self,
prompt_file: Path,
max_context_size: int = 8000,
cache_dir: Path = Path(".agent/cache"),
prompt_text: Optional[str] = None
):
"""Initialize context manager.
Args:
prompt_file: Path to the main prompt file
max_context_size: Maximum context size in characters
cache_dir: Directory for caching context
prompt_text: Direct prompt text (overrides prompt_file if provided)
"""
self.prompt_file = prompt_file
self.max_context_size = max_context_size
self.cache_dir = cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.prompt_text = prompt_text # Direct prompt text override
# Context components
self.stable_prefix: Optional[str] = None
self.dynamic_context: List[str] = []
self.error_history: List[str] = []
self.success_patterns: List[str] = []
# Load initial prompt
self._load_initial_prompt()
def _load_initial_prompt(self):
"""Load and analyze the initial prompt."""
# Use direct prompt text if provided
if self.prompt_text:
logger.info("Using direct prompt_text input")
content = self.prompt_text
elif self.prompt_file.exists():
try:
content = self.prompt_file.read_text()
except UnicodeDecodeError as e:
logger.warning(f"Encoding error reading {self.prompt_file}: {e}")
return
except PermissionError as e:
logger.warning(f"Permission denied reading {self.prompt_file}: {e}")
return
except OSError as e:
logger.warning(f"OS error reading {self.prompt_file}: {e}")
return
else:
logger.info(f"Prompt file {self.prompt_file} not found")
return
# Extract stable prefix (instructions that don't change)
lines = content.split('\n')
stable_lines = []
for line in lines:
if line.startswith('#') or line.startswith('##'):
stable_lines.append(line)
# No longer breaking on completion markers
elif len(stable_lines) > 0 and line.strip() == '':
stable_lines.append(line)
elif len(stable_lines) > 0:
break
self.stable_prefix = '\n'.join(stable_lines)
logger.info(f"Extracted stable prefix: {len(self.stable_prefix)} chars")
def get_prompt(self) -> str:
"""Get the current prompt with optimizations."""
# Use direct prompt text if provided
if self.prompt_text:
base_content = self.prompt_text
elif self.prompt_file.exists():
try:
base_content = self.prompt_file.read_text()
except UnicodeDecodeError as e:
logger.warning(f"Encoding error reading {self.prompt_file}: {e}")
return ""
except PermissionError as e:
logger.warning(f"Permission denied reading {self.prompt_file}: {e}")
return ""
except OSError as e:
logger.warning(f"OS error reading {self.prompt_file}: {e}")
return ""
else:
logger.warning(f"No prompt available: prompt_text={self.prompt_text is not None}, prompt_file={self.prompt_file}")
return ""
# Check if we need to optimize
if len(base_content) > self.max_context_size:
return self._optimize_prompt(base_content)
# Add dynamic context if there's room
if self.dynamic_context:
context_addition = "\n\n## Previous Context\n" + "\n".join(self.dynamic_context[-3:])
if len(base_content) + len(context_addition) < self.max_context_size:
base_content += context_addition
# Add error history if relevant
if self.error_history:
error_addition = "\n\n## Recent Errors to Avoid\n" + "\n".join(self.error_history[-2:])
if len(base_content) + len(error_addition) < self.max_context_size:
base_content += error_addition
return base_content
def _optimize_prompt(self, content: str) -> str:
"""Optimize a prompt that's too large."""
logger.info("Optimizing large prompt")
# Strategy 1: Use stable prefix caching
if self.stable_prefix:
# Cache the stable prefix
prefix_hash = hashlib.sha256(self.stable_prefix.encode()).hexdigest()[:8]
cache_file = self.cache_dir / f"prefix_{prefix_hash}.txt"
if not cache_file.exists():
cache_file.write_text(self.stable_prefix)
# Reference the cached prefix instead of including it
optimized = f"<!-- Using cached prefix {prefix_hash} -->\n"
# Add the dynamic part
dynamic_part = content[len(self.stable_prefix):]
# Truncate if still too large
if len(dynamic_part) > self.max_context_size - 100:
dynamic_part = self._summarize_content(dynamic_part)
optimized += dynamic_part
return optimized
# Strategy 2: Summarize the content
return self._summarize_content(content)
def _summarize_content(self, content: str) -> str:
"""Summarize content to fit within limits."""
lines = content.split('\n')
# Keep headers and key instructions
important_lines = []
for line in lines:
if any([
line.startswith('#'),
# 'TODO' in line,
'IMPORTANT' in line,
'ERROR' in line,
line.startswith('- [ ]'), # Unchecked tasks
]):
important_lines.append(line)
summary = '\n'.join(important_lines)
# If still too long, truncate
if len(summary) > self.max_context_size:
summary = summary[:self.max_context_size - 100] + "\n<!-- Content truncated -->"
return summary
def update_context(self, output: str):
"""Update dynamic context based on agent output."""
# Extract key information from output
if "error" in output.lower():
# Track errors for learning
error_lines = [line for line in output.split('\n') if 'error' in line.lower()]
self.error_history.extend(error_lines[:2])
# Keep only recent errors
self.error_history = self.error_history[-5:]
if "success" in output.lower() or "complete" in output.lower():
# Track successful patterns
success_lines = [line for line in output.split('\n')
if any(word in line.lower() for word in ['success', 'complete', 'done'])]
self.success_patterns.extend(success_lines[:1])
self.success_patterns = self.success_patterns[-3:]
# Add to dynamic context (summarized)
if len(output) > 500:
summary = output[:200] + "..." + output[-200:]
self.dynamic_context.append(summary)
else:
self.dynamic_context.append(output)
# Keep dynamic context limited
self.dynamic_context = self.dynamic_context[-5:]
def add_error_feedback(self, error: str):
"""Add error feedback to context."""
self.error_history.append(f"Error: {error}")
self.error_history = self.error_history[-5:]
def reset(self):
    """Clear the dynamic context, error history and success patterns.

    Each attribute is rebound to a fresh empty list, and an info-level
    "Context reset" message is logged.
    """
    for attribute in ("dynamic_context", "error_history", "success_patterns"):
        setattr(self, attribute, [])
    logger.info("Context reset")
def get_stats(self) -> Dict:
    """Report the current size of each piece of tracked context.

    Returns:
        Dict with the length of ``stable_prefix`` (0 when unset or empty),
        the item counts of ``dynamic_context``, ``error_history`` and
        ``success_patterns``, and the number of ``*.txt`` files directly
        inside ``cache_dir``.
    """
    stats = {}
    prefix = self.stable_prefix
    stats["stable_prefix_size"] = len(prefix) if prefix else 0
    stats["dynamic_context_items"] = len(self.dynamic_context)
    stats["error_history_items"] = len(self.error_history)
    stats["success_patterns"] = len(self.success_patterns)
    cached_files = list(self.cache_dir.glob("*.txt"))
    stats["cache_files"] = len(cached_files)
    return stats

View File

@@ -0,0 +1,254 @@
# ABOUTME: Error formatter for Ralph Orchestrator
# ABOUTME: Provides structured error messages with user-friendly suggestions
"""
Error formatter for Ralph Orchestrator.
This module provides structured error messages with user-friendly suggestions
and security-aware error sanitization for Claude SDK and adapter errors.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
@dataclass
class ErrorMessage:
    """A user-facing error description paired with a remediation hint.

    Attributes:
        message: The main error message shown to the user.
        suggestion: Advice on how to resolve the error.
    """

    message: str
    suggestion: str

    def __str__(self) -> str:
        """Render as ``"<message> | <suggestion>"``."""
        return "{} | {}".format(self.message, self.suggestion)
class ClaudeErrorFormatter:
    """Build :class:`ErrorMessage` objects for errors seen during Claude runs.

    Every method is static and returns an ``ErrorMessage`` whose ``message``
    describes the failure and whose ``suggestion`` tells the user what to try.

    Only :meth:`format_generic_error` embeds raw exception text; it masks that
    text with ``SecurityValidator.mask_sensitive_data`` and truncates it.
    :meth:`format_permission_error` only drops paths of 100+ characters. The
    other formatters emit fixed text plus the iteration number.
    """

    @staticmethod
    def format_timeout_error(iteration: int, timeout: int) -> ErrorMessage:
        """Format a timeout error.

        Args:
            iteration: Current iteration number.
            timeout: Timeout limit in seconds.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} exceeded timeout limit of {timeout}s",
            suggestion="Try: Increase iteration_timeout in config or simplify your prompt"
        )

    @staticmethod
    def format_process_terminated_error(iteration: int) -> ErrorMessage:
        """Format an unexpected-subprocess-termination error.

        Args:
            iteration: Current iteration number.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Claude subprocess terminated unexpectedly",
            suggestion="Try: Check if Claude Code CLI is properly installed and has correct permissions"
        )

    @staticmethod
    def format_interrupted_error(iteration: int) -> ErrorMessage:
        """Format an interruption error (SIGTERM received).

        Args:
            iteration: Current iteration number.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} was interrupted (SIGTERM)",
            suggestion="This usually happens when stopping Ralph - no action needed"
        )

    @staticmethod
    def format_command_failed_error(iteration: int) -> ErrorMessage:
        """Format a command-failed error (exit code 1).

        Args:
            iteration: Current iteration number.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Claude CLI command failed",
            suggestion="Try: Check Claude CLI installation with 'claude --version' or verify API key with 'claude login'"
        )

    @staticmethod
    def format_connection_error(iteration: int) -> ErrorMessage:
        """Format a connection error.

        Args:
            iteration: Current iteration number.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Cannot connect to Claude CLI",
            suggestion="Try: Verify Claude Code CLI is installed: claude --version"
        )

    @staticmethod
    def format_rate_limit_error(iteration: int, retry_after: int = 60) -> ErrorMessage:
        """Format a rate-limit error.

        Args:
            iteration: Current iteration number.
            retry_after: Seconds until a retry is recommended.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Rate limit exceeded",
            suggestion=f"Try: Wait {retry_after}s before retrying or reduce request frequency"
        )

    @staticmethod
    def format_authentication_error(iteration: int) -> ErrorMessage:
        """Format an authentication error.

        Args:
            iteration: Current iteration number.

        Returns:
            Formatted error message with suggestion.
        """
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Authentication error",
            suggestion="Try: Check your API credentials or re-authenticate with 'claude login'"
        )

    @staticmethod
    def format_permission_error(iteration: int, path: str = "") -> ErrorMessage:
        """Format a permission error.

        Args:
            iteration: Current iteration number.
            path: Optional path that caused the error. It is included in the
                message only when non-empty and shorter than 100 characters.

        Returns:
            Formatted error message with suggestion.
        """
        # Only a length limit is applied here; over-long paths are omitted.
        safe_path = path if path and len(path) < 100 else ""
        if safe_path:
            return ErrorMessage(
                message=f"Iteration {iteration} failed: Permission denied for '{safe_path}'",
                suggestion="Try: Check file permissions or run with appropriate privileges"
            )
        return ErrorMessage(
            message=f"Iteration {iteration} failed: Permission denied",
            suggestion="Try: Check file/directory permissions or run with appropriate privileges"
        )

    @staticmethod
    def format_generic_error(iteration: int, error_type: str, error_str: str) -> ErrorMessage:
        """Format any other error, masking and truncating the error text.

        Args:
            iteration: Current iteration number.
            error_type: Exception type name (e.g. ``'ValueError'``).
            error_str: String representation of the error.

        Returns:
            Formatted error message whose error text has been passed through
            ``SecurityValidator.mask_sensitive_data`` and capped at 200
            characters (197 plus ``"..."``).
        """
        # Import here to avoid circular imports
        from .security import SecurityValidator

        sanitized_error_str = SecurityValidator.mask_sensitive_data(error_str)
        if len(sanitized_error_str) > 200:
            sanitized_error_str = sanitized_error_str[:197] + "..."
        return ErrorMessage(
            message=f"Iteration {iteration} failed: {error_type}: {sanitized_error_str}",
            suggestion="Check logs for details or try reducing prompt complexity"
        )

    @staticmethod
    def format_error_from_exception(iteration: int, exception: Exception) -> ErrorMessage:
        """Pick the most specific formatter for an exception.

        The exception's type name and message are matched against known
        patterns in a fixed order; the first match wins, and anything
        unmatched falls back to :meth:`format_generic_error`.

        Args:
            iteration: Current iteration number.
            exception: The exception that occurred.

        Returns:
            Formatted error message with an appropriate suggestion.
        """
        import re  # local import, as format_generic_error does for its dependency

        error_type = type(exception).__name__
        error_str = str(exception)
        error_lower = error_str.lower()

        # Process transport issues (subprocess terminated)
        if "ProcessTransport is not ready" in error_str:
            return ClaudeErrorFormatter.format_process_terminated_error(iteration)

        # Extract the exact exit code(s). A plain substring test for
        # "exit code 1" would also match 10-19 and every 1xx code
        # (127, 130, 137, ...) and misreport them as a generic exit-1 failure.
        exit_codes = set(re.findall(r"Command failed with exit code (\d+)", error_str))
        # SIGTERM interruption (exit code 143 = 128 + 15)
        if "143" in exit_codes:
            return ClaudeErrorFormatter.format_interrupted_error(iteration)
        # General command failure (exit code 1)
        if "1" in exit_codes:
            return ClaudeErrorFormatter.format_command_failed_error(iteration)

        # Connection errors
        if error_type == "CLIConnectionError" or "connection" in error_lower:
            return ClaudeErrorFormatter.format_connection_error(iteration)

        # Timeout errors. The configured limit is not known here, so 0 is passed.
        if error_type in ("TimeoutError", "asyncio.TimeoutError") or "timeout" in error_lower:
            return ClaudeErrorFormatter.format_timeout_error(iteration, 0)

        # Rate limit errors
        if "rate limit" in error_lower or error_type == "RateLimitError":
            return ClaudeErrorFormatter.format_rate_limit_error(iteration)

        # Authentication errors ("auth" also covers "authentication")
        if "auth" in error_lower or error_type == "AuthenticationError":
            return ClaudeErrorFormatter.format_authentication_error(iteration)

        # Permission errors
        if error_type == "PermissionError" or "permission denied" in error_lower:
            return ClaudeErrorFormatter.format_permission_error(iteration)

        # Fall back to generic error format
        return ClaudeErrorFormatter.format_generic_error(iteration, error_type, error_str)

View File

@@ -0,0 +1,218 @@
# ABOUTME: Logging configuration module for Ralph Orchestrator
# ABOUTME: Provides centralized logging setup with proper formatters and handlers
"""Logging configuration for Ralph Orchestrator."""
import logging
import logging.handlers
import os
import sys
from pathlib import Path
from typing import Optional, Dict, Any
class RalphLogger:
    """One-time logging setup for the ``"ralph"`` logger hierarchy.

    Configuration comes from explicit arguments to :meth:`initialize` or,
    when those are omitted, from ``RALPH_LOG_*`` environment variables.
    """

    # Logger names for the individual components
    ORCHESTRATOR = "ralph.orchestrator"
    ADAPTER_BASE = "ralph.adapter"
    ADAPTER_QCHAT = "ralph.adapter.qchat"
    ADAPTER_KIRO = "ralph.adapter.kiro"
    ADAPTER_CLAUDE = "ralph.adapter.claude"
    ADAPTER_GEMINI = "ralph.adapter.gemini"
    SAFETY = "ralph.safety"
    METRICS = "ralph.metrics"

    # Record layouts: a short one, and one with file/line/function details
    DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    DETAILED_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(funcName)s() - %(message)s"

    _initialized = False
    _log_dir: Optional[Path] = None

    @classmethod
    def initialize(cls,
                   log_level: Optional[str] = None,
                   log_file: Optional[str] = None,
                   log_dir: Optional[str] = None,
                   console_output: Optional[bool] = None,
                   detailed_format: bool = False) -> None:
        """Configure the ``"ralph"`` logger; later calls are no-ops.

        Args:
            log_level: Level name (DEBUG, INFO, WARNING, ERROR, CRITICAL).
                Falls back to ``RALPH_LOG_LEVEL``, then ``"INFO"``.
            log_file: Log file path. Falls back to ``RALPH_LOG_FILE``.
            log_dir: Log directory. Falls back to ``RALPH_LOG_DIR``, then
                ``".logs"``.
            console_output: Whether to log to stderr. When ``None``,
                ``RALPH_LOG_CONSOLE`` decides (default ``"true"``).
            detailed_format: Use the detailed record layout. Also enabled by
                ``RALPH_LOG_DETAILED=true``.
        """
        if cls._initialized:
            return

        def env_flag(variable: str, fallback: str) -> bool:
            return os.getenv(variable, fallback).lower() == "true"

        # Explicit arguments win; environment variables fill the gaps.
        if not log_level:
            log_level = os.getenv("RALPH_LOG_LEVEL", "INFO")
        if not log_file:
            log_file = os.getenv("RALPH_LOG_FILE")
        if not log_dir:
            log_dir = os.getenv("RALPH_LOG_DIR", ".logs")
        if console_output is None:
            console_output = env_flag("RALPH_LOG_CONSOLE", "true")
        if not detailed_format:
            detailed_format = env_flag("RALPH_LOG_DETAILED", "false")

        # Unknown level names fall back to INFO.
        numeric_level = getattr(logging, log_level.upper(), logging.INFO)
        formatter = logging.Formatter(
            cls.DETAILED_FORMAT if detailed_format else cls.DEFAULT_FORMAT
        )

        ralph_root = logging.getLogger("ralph")
        ralph_root.setLevel(numeric_level)
        ralph_root.handlers = []  # Clear existing handlers

        if console_output:
            stream_handler = logging.StreamHandler(sys.stderr)
            stream_handler.setFormatter(formatter)
            stream_handler.setLevel(numeric_level)
            ralph_root.addHandler(stream_handler)

        if log_file or log_dir:
            cls._setup_file_handler(ralph_root, formatter, log_file, log_dir, numeric_level)

        cls._initialized = True

        # Suppress verbose INFO logs from claude-agent-sdk internals
        logging.getLogger('claude_agent_sdk').setLevel(logging.WARNING)

        logging.getLogger(cls.ORCHESTRATOR).debug(
            f"Logging initialized - Level: {log_level}, Console: {console_output}, "
            f"File: {log_file or 'None'}, Dir: {log_dir or 'None'}"
        )

    @classmethod
    def _setup_file_handler(cls,
                            logger: logging.Logger,
                            formatter: logging.Formatter,
                            log_file: Optional[str],
                            log_dir: Optional[str],
                            level: int) -> None:
        """Attach a rotating file handler to ``logger``.

        Args:
            logger: Logger to add the handler to.
            formatter: Formatter for the handler.
            log_file: Explicit log file path; takes precedence over ``log_dir``.
            log_dir: Directory that receives ``ralph_orchestrator.log`` when
                no ``log_file`` is given.
            level: Level for the handler.

        Rotation size and backup count come from ``RALPH_LOG_MAX_BYTES``
        (default 10485760) and ``RALPH_LOG_BACKUP_COUNT`` (default 5).
        """
        if not log_file:
            # No explicit file: use the default filename in the log directory.
            cls._log_dir = Path(log_dir)
            cls._log_dir.mkdir(parents=True, exist_ok=True)
            target = cls._log_dir / "ralph_orchestrator.log"
        else:
            target = Path(log_file)

        target.parent.mkdir(parents=True, exist_ok=True)

        rotating = logging.handlers.RotatingFileHandler(
            target,
            maxBytes=int(os.getenv("RALPH_LOG_MAX_BYTES", "10485760")),  # 10MB default
            backupCount=int(os.getenv("RALPH_LOG_BACKUP_COUNT", "5")),
        )
        rotating.setFormatter(formatter)
        rotating.setLevel(level)
        logger.addHandler(rotating)

    @classmethod
    def get_logger(cls, name: str) -> logging.Logger:
        """Return the named logger, initializing logging first if needed.

        Args:
            name: Logger name (the class constants are the intended values).

        Returns:
            The ``logging.Logger`` registered under ``name``.
        """
        if not cls._initialized:
            cls.initialize()
        return logging.getLogger(name)

    @classmethod
    def log_config(cls) -> Dict[str, Any]:
        """Describe the current state of the ``"ralph"`` logger.

        Returns:
            Dict with the logger's level name, one entry per handler (type,
            level and, for file-based handlers, the file path), the log
            directory if one was set, and the initialized flag.
        """
        ralph_root = logging.getLogger("ralph")

        handler_summaries = []
        for attached in ralph_root.handlers:
            summary = {
                "type": type(attached).__name__,
                "level": logging.getLevelName(attached.level),
            }
            if hasattr(attached, 'baseFilename'):
                summary["file"] = attached.baseFilename
            handler_summaries.append(summary)

        return {
            "level": logging.getLevelName(ralph_root.level),
            "handlers": handler_summaries,
            "log_dir": str(cls._log_dir) if cls._log_dir else None,
            "initialized": cls._initialized,
        }

    @classmethod
    def set_level(cls, level: str, logger_name: Optional[str] = None) -> None:
        """Change the level of a logger and of its directly attached handlers.

        Args:
            level: New level name; unknown names fall back to INFO.
            logger_name: Logger to update; ``None`` (or empty) means the
                ``"ralph"`` logger.
        """
        numeric_level = getattr(logging, level.upper(), logging.INFO)
        target = logging.getLogger(logger_name if logger_name else "ralph")
        target.setLevel(numeric_level)
        for attached in target.handlers:
            attached.setLevel(numeric_level)
# Module-level shortcut so callers need not reference RalphLogger directly
def get_logger(name: str) -> logging.Logger:
    """Return a logger obtained through ``RalphLogger.get_logger``.

    Args:
        name: Logger name.

    Returns:
        The logger instance for ``name``.
    """
    configured = RalphLogger.get_logger(name)
    return configured

View File

@@ -0,0 +1,626 @@
#!/usr/bin/env python3
# ABOUTME: Ralph orchestrator main loop implementation with multi-agent support
# ABOUTME: Implements the core Ralph Wiggum technique with continuous iteration
import sys
import logging
import argparse
import threading
import yaml
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from enum import Enum
from .orchestrator import RalphOrchestrator
# Default limits and settings used when no override is supplied
DEFAULT_MAX_ITERATIONS = 100
DEFAULT_MAX_RUNTIME = 4 * 60 * 60  # 4 hours, in seconds
DEFAULT_PROMPT_FILE = "PROMPT.md"
DEFAULT_CHECKPOINT_INTERVAL = 5
DEFAULT_RETRY_DELAY = 2
DEFAULT_MAX_TOKENS = 1_000_000  # 1M tokens total
DEFAULT_MAX_COST = 50.0  # $50 USD
DEFAULT_CONTEXT_WINDOW = 200_000  # 200K token context window
DEFAULT_CONTEXT_THRESHOLD = 0.8  # Trigger summarization at 80% of context
DEFAULT_METRICS_INTERVAL = 10  # Log metrics every 10 iterations
DEFAULT_MAX_PROMPT_SIZE = 10 * 1024 * 1024  # 10MB max prompt file size
DEFAULT_COMPLETION_PROMISE = "LOOP_COMPLETE"

# Approximate token prices per million tokens, keyed by agent name
TOKEN_COSTS = dict(
    claude=dict(input=3.0, output=15.0),  # Claude 3.5 Sonnet
    q=dict(input=0.5, output=1.5),        # Estimated
    kiro=dict(input=0.5, output=1.5),     # Estimated
    gemini=dict(input=0.5, output=1.5),   # Gemini Pro
)

# Root logging goes to a stream handler at INFO level
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('ralph-orchestrator')
class AgentType(Enum):
    """AI agent backends that can be selected; values are their string names.

    ``AUTO`` has the value ``"auto"``; the ``--agent`` option describes it as
    auto-detect.
    """

    CLAUDE = "claude"
    Q = "q"
    KIRO = "kiro"
    GEMINI = "gemini"
    ACP = "acp"
    AUTO = "auto"
class ConfigValidator:
    """Static checks for Ralph configuration values.

    ``validate_*`` methods return a list of error strings (empty when the
    value is acceptable). ``get_warning_*`` methods return a list of advisory
    strings for values that are allowed but look unusual.
    """

    # Thresholds used by the warning checks
    LARGE_DELAY_THRESHOLD_SECONDS = 3600  # 1 hour
    SHORT_TIMEOUT_THRESHOLD_SECONDS = 10  # Very short timeout
    TYPICAL_AI_ITERATION_MIN_SECONDS = 30  # Typical minimum time for AI iteration
    TYPICAL_AI_ITERATION_MAX_SECONDS = 300  # Typical maximum time for AI iteration

    # Upper bounds that guard against resource exhaustion
    MAX_ITERATIONS_LIMIT = 100000
    MAX_RUNTIME_LIMIT = 604800  # 1 week in seconds
    MAX_TOKENS_LIMIT = 100000000  # 100M tokens
    MAX_COST_LIMIT = 10000.0  # $10K USD

    @staticmethod
    def validate_max_iterations(max_iterations: int) -> List[str]:
        """Reject negative iteration counts and counts above the limit."""
        if max_iterations < 0:
            return ["Max iterations must be non-negative"]
        if max_iterations > ConfigValidator.MAX_ITERATIONS_LIMIT:
            return [f"Max iterations exceeds limit ({ConfigValidator.MAX_ITERATIONS_LIMIT})"]
        return []

    @staticmethod
    def validate_max_runtime(max_runtime: int) -> List[str]:
        """Reject negative runtimes and runtimes above the limit."""
        if max_runtime < 0:
            return ["Max runtime must be non-negative"]
        if max_runtime > ConfigValidator.MAX_RUNTIME_LIMIT:
            return [f"Max runtime exceeds limit ({ConfigValidator.MAX_RUNTIME_LIMIT}s)"]
        return []

    @staticmethod
    def validate_checkpoint_interval(checkpoint_interval: int) -> List[str]:
        """Reject negative checkpoint intervals."""
        if checkpoint_interval < 0:
            return ["Checkpoint interval must be non-negative"]
        return []

    @staticmethod
    def validate_retry_delay(retry_delay: int) -> List[str]:
        """Reject negative retry delays and delays above the large-delay threshold."""
        if retry_delay < 0:
            return ["Retry delay must be non-negative"]
        if retry_delay > ConfigValidator.LARGE_DELAY_THRESHOLD_SECONDS:
            return [f"Retry delay exceeds limit ({ConfigValidator.LARGE_DELAY_THRESHOLD_SECONDS}s)"]
        return []

    @staticmethod
    def validate_max_tokens(max_tokens: int) -> List[str]:
        """Reject negative token budgets and budgets above the limit."""
        if max_tokens < 0:
            return ["Max tokens must be non-negative"]
        if max_tokens > ConfigValidator.MAX_TOKENS_LIMIT:
            return [f"Max tokens exceeds limit ({ConfigValidator.MAX_TOKENS_LIMIT})"]
        return []

    @staticmethod
    def validate_max_cost(max_cost: float) -> List[str]:
        """Reject negative cost budgets and budgets above the limit."""
        if max_cost < 0:
            return ["Max cost must be non-negative"]
        if max_cost > ConfigValidator.MAX_COST_LIMIT:
            return [f"Max cost exceeds limit (${ConfigValidator.MAX_COST_LIMIT})"]
        return []

    @staticmethod
    def validate_context_threshold(context_threshold: float) -> List[str]:
        """Require the context threshold to lie within [0.0, 1.0]."""
        if 0.0 <= context_threshold <= 1.0:
            return []
        return ["Context threshold must be between 0.0 and 1.0"]

    @staticmethod
    def validate_prompt_file(prompt_file: str) -> List[str]:
        """Require the prompt path to exist and to be a regular file."""
        candidate = Path(prompt_file)
        if not candidate.exists():
            return [f"Prompt file not found: {prompt_file}"]
        if not candidate.is_file():
            return [f"Prompt file is not a regular file: {prompt_file}"]
        return []

    @staticmethod
    def get_warning_large_delay(retry_delay: int) -> List[str]:
        """Warn when the retry delay exceeds the large-delay threshold."""
        if retry_delay <= ConfigValidator.LARGE_DELAY_THRESHOLD_SECONDS:
            return []
        return [
            f"Warning: Retry delay is very large ({retry_delay}s = {retry_delay/60:.1f}m). "
            f"Did you mean to use minutes instead of seconds?"
        ]

    @staticmethod
    def get_warning_single_iteration(max_iterations: int) -> List[str]:
        """Warn when max_iterations is exactly 1."""
        if max_iterations != 1:
            return []
        return [
            "Warning: max_iterations is 1. "
            "Ralph is designed for continuous loops. Did you mean 0 (infinite)?"
        ]

    @staticmethod
    def get_warning_short_timeout(max_runtime: int) -> List[str]:
        """Warn when a positive runtime limit is below the short-timeout threshold."""
        is_short = 0 < max_runtime < ConfigValidator.SHORT_TIMEOUT_THRESHOLD_SECONDS
        if not is_short:
            return []
        return [
            f"Warning: Max runtime is very short ({max_runtime}s). "
            f"AI iterations typically take {ConfigValidator.TYPICAL_AI_ITERATION_MIN_SECONDS}-"
            f"{ConfigValidator.TYPICAL_AI_ITERATION_MAX_SECONDS} seconds."
        ]
@dataclass
class AdapterConfig:
    """Per-adapter settings; every field has a default."""

    # Whether the adapter is switched on
    enabled: bool = field(default=True)
    # Adapter arguments (fresh list per instance)
    args: List[str] = field(default_factory=list)
    # Environment entries for the adapter (fresh dict per instance)
    env: Dict[str, str] = field(default_factory=dict)
    # Timeout value; unit not stated here - NOTE(review): presumably seconds, confirm
    timeout: int = field(default=300)
    # Retry limit
    max_retries: int = field(default=3)
    # Tool permission settings (fresh dict per instance)
    tool_permissions: Dict[str, Any] = field(default_factory=dict)
@dataclass
class RalphConfig:
    """Configuration for Ralph orchestrator.

    Fields may be read and written directly (backwards compatible). The
    ``get_*``/``set_*`` methods and the validation helpers additionally hold
    an internal ``RLock`` while they run; direct attribute access does not
    take that lock.
    """

    # Core configuration fields
    agent: AgentType = AgentType.AUTO
    # Agent selection and fallback priority (used when agent=auto, and for fallback ordering)
    # Documented values: "acp", "claude", "gemini", "qchat" (also accepts aliases: "codex"->"acp", "q"->"qchat");
    # the default order below additionally contains "kiro".
    agent_priority: List[str] = field(default_factory=lambda: ["claude", "kiro", "qchat", "gemini", "acp"])
    prompt_file: str = DEFAULT_PROMPT_FILE
    prompt_text: Optional[str] = None  # Direct prompt text (overrides prompt_file)
    completion_promise: Optional[str] = DEFAULT_COMPLETION_PROMISE  # String to match in agent output to stop
    max_iterations: int = DEFAULT_MAX_ITERATIONS
    max_runtime: int = DEFAULT_MAX_RUNTIME
    checkpoint_interval: int = DEFAULT_CHECKPOINT_INTERVAL
    retry_delay: int = DEFAULT_RETRY_DELAY
    archive_prompts: bool = True
    git_checkpoint: bool = True
    verbose: bool = False
    dry_run: bool = False
    max_tokens: int = DEFAULT_MAX_TOKENS
    max_cost: float = DEFAULT_MAX_COST
    context_window: int = DEFAULT_CONTEXT_WINDOW
    context_threshold: float = DEFAULT_CONTEXT_THRESHOLD
    metrics_interval: int = DEFAULT_METRICS_INTERVAL
    enable_metrics: bool = True
    max_prompt_size: int = DEFAULT_MAX_PROMPT_SIZE
    allow_unsafe_paths: bool = False
    agent_args: List[str] = field(default_factory=list)
    adapters: Dict[str, AdapterConfig] = field(default_factory=dict)

    # Output formatting configuration
    output_format: str = "rich"  # "plain", "rich", or "json"
    output_verbosity: str = "normal"  # "quiet", "normal", "verbose", "debug"
    show_token_usage: bool = True  # Display token usage after iterations
    show_timestamps: bool = True  # Include timestamps in output

    # Thread safety lock - not included in initialization/equals
    _lock: threading.RLock = field(
        default_factory=threading.RLock, init=False, repr=False, compare=False
    )

    # Lock-guarded accessors for fields that may change at runtime

    def get_max_iterations(self) -> int:
        """Return max_iterations while holding the lock."""
        with self._lock:
            return self.max_iterations

    def set_max_iterations(self, value: int) -> None:
        """Set max_iterations while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'max_iterations', value)

    def get_max_runtime(self) -> int:
        """Return max_runtime while holding the lock."""
        with self._lock:
            return self.max_runtime

    def set_max_runtime(self, value: int) -> None:
        """Set max_runtime while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'max_runtime', value)

    def get_checkpoint_interval(self) -> int:
        """Return checkpoint_interval while holding the lock."""
        with self._lock:
            return self.checkpoint_interval

    def set_checkpoint_interval(self, value: int) -> None:
        """Set checkpoint_interval while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'checkpoint_interval', value)

    def get_retry_delay(self) -> int:
        """Return retry_delay while holding the lock."""
        with self._lock:
            return self.retry_delay

    def set_retry_delay(self, value: int) -> None:
        """Set retry_delay while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'retry_delay', value)

    def get_max_tokens(self) -> int:
        """Return max_tokens while holding the lock."""
        with self._lock:
            return self.max_tokens

    def set_max_tokens(self, value: int) -> None:
        """Set max_tokens while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'max_tokens', value)

    def get_max_cost(self) -> float:
        """Return max_cost while holding the lock."""
        with self._lock:
            return self.max_cost

    def set_max_cost(self, value: float) -> None:
        """Set max_cost while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'max_cost', value)

    def get_verbose(self) -> bool:
        """Return verbose while holding the lock."""
        with self._lock:
            return self.verbose

    def set_verbose(self, value: bool) -> None:
        """Set verbose while holding the lock."""
        with self._lock:
            object.__setattr__(self, 'verbose', value)

    @classmethod
    def from_yaml(cls, config_path: str) -> 'RalphConfig':
        """Load configuration from a YAML file.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            A ``RalphConfig`` built from the recognised keys in the file.
            Unknown keys are ignored. An empty file yields the defaults.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If ``agent`` is not a valid ``AgentType`` value.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_file, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat that as
            # "no overrides" instead of failing on the membership tests below.
            config_data = yaml.safe_load(f) or {}

        # Convert agent string to AgentType enum
        if 'agent' in config_data:
            config_data['agent'] = AgentType(config_data['agent'])

        # Process adapter configurations
        if 'adapters' in config_data:
            adapter_configs = {}
            # A bare "adapters:" key parses to None; treat it as no adapters.
            for name, adapter_data in (config_data['adapters'] or {}).items():
                if isinstance(adapter_data, dict):
                    adapter_configs[name] = AdapterConfig(**adapter_data)
                else:
                    # Simple boolean enable/disable
                    adapter_configs[name] = AdapterConfig(enabled=bool(adapter_data))
            config_data['adapters'] = adapter_configs

        # Keep only keys the constructor accepts. Fields declared with
        # init=False (the internal lock) are excluded, because passing them
        # to cls(...) would raise TypeError.
        valid_keys = {f.name for f in cls.__dataclass_fields__.values() if f.init}
        filtered_data = {k: v for k, v in config_data.items() if k in valid_keys}
        return cls(**filtered_data)

    def get_adapter_config(self, adapter_name: str) -> AdapterConfig:
        """Return the config stored for ``adapter_name``, or a default one."""
        with self._lock:
            return self.adapters.get(adapter_name, AdapterConfig())

    def validate(self) -> List[str]:
        """Validate configuration settings.

        Returns:
            List of validation errors (empty if valid).
        """
        errors = []
        with self._lock:
            errors.extend(ConfigValidator.validate_max_iterations(self.max_iterations))
            errors.extend(ConfigValidator.validate_max_runtime(self.max_runtime))
            errors.extend(ConfigValidator.validate_checkpoint_interval(self.checkpoint_interval))
            errors.extend(ConfigValidator.validate_retry_delay(self.retry_delay))
            errors.extend(ConfigValidator.validate_max_tokens(self.max_tokens))
            errors.extend(ConfigValidator.validate_max_cost(self.max_cost))
            errors.extend(ConfigValidator.validate_context_threshold(self.context_threshold))
        return errors

    def get_warnings(self) -> List[str]:
        """Get configuration warnings (non-blocking issues).

        Returns:
            List of warning messages.
        """
        warnings = []
        with self._lock:
            warnings.extend(ConfigValidator.get_warning_large_delay(self.retry_delay))
            warnings.extend(ConfigValidator.get_warning_single_iteration(self.max_iterations))
            warnings.extend(ConfigValidator.get_warning_short_timeout(self.max_runtime))
        return warnings

    def create_output_formatter(self):
        """Create an output formatter from ``output_format``/``output_verbosity``.

        Returns:
            The object returned by ``ralph_orchestrator.output.create_formatter``.
            An unrecognised verbosity string maps to ``VerbosityLevel.NORMAL``.
        """
        from ralph_orchestrator.output import VerbosityLevel, create_formatter

        # Map verbosity string to enum
        verbosity_map = {
            "quiet": VerbosityLevel.QUIET,
            "normal": VerbosityLevel.NORMAL,
            "verbose": VerbosityLevel.VERBOSE,
            "debug": VerbosityLevel.DEBUG,
        }
        with self._lock:
            verbosity = verbosity_map.get(self.output_verbosity.lower(), VerbosityLevel.NORMAL)
            return create_formatter(
                format_type=self.output_format,
                verbosity=verbosity,
            )
def main():
    """Command-line entry point.

    Parses the command-line options, builds a ``RalphConfig`` from them,
    and runs a ``RalphOrchestrator`` with that config.

    Returns:
        Whatever ``RalphOrchestrator.run()`` returns.
    """
    parser = argparse.ArgumentParser(
        description="Ralph Wiggum Orchestrator - Put AI in a loop until done"
    )
    add = parser.add_argument

    # Agent and prompt selection
    add("--agent", "-a", type=str, default="auto",
        choices=["claude", "q", "kiro", "gemini", "acp", "auto"],
        help="AI agent to use (default: auto-detect)")
    add("--prompt-file", "-P", type=str, default=DEFAULT_PROMPT_FILE, dest="prompt",
        help="Prompt file path (default: PROMPT.md)")
    add("--prompt-text", "-p", type=str, default=None,
        help="Direct prompt text (overrides --prompt-file)")
    add("--completion-promise", type=str, default=DEFAULT_COMPLETION_PROMISE,
        help=f"Stop when agent output contains this exact string (default: {DEFAULT_COMPLETION_PROMISE})")

    # Loop limits
    add("--max-iterations", "-i", type=int, default=DEFAULT_MAX_ITERATIONS,
        help=f"Maximum iterations (default: {DEFAULT_MAX_ITERATIONS})")
    add("--max-runtime", "-t", type=int, default=DEFAULT_MAX_RUNTIME,
        help=f"Maximum runtime in seconds (default: {DEFAULT_MAX_RUNTIME})")
    add("--checkpoint-interval", "-c", type=int, default=DEFAULT_CHECKPOINT_INTERVAL,
        help=f"Checkpoint interval (default: {DEFAULT_CHECKPOINT_INTERVAL})")
    add("--retry-delay", "-r", type=int, default=DEFAULT_RETRY_DELAY,
        help=f"Retry delay in seconds (default: {DEFAULT_RETRY_DELAY})")

    # Token, cost and context budgets
    add("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS,
        help=f"Maximum total tokens (default: {DEFAULT_MAX_TOKENS:,})")
    add("--max-cost", type=float, default=DEFAULT_MAX_COST,
        help=f"Maximum cost in USD (default: ${DEFAULT_MAX_COST:.2f})")
    add("--context-window", type=int, default=DEFAULT_CONTEXT_WINDOW,
        help=f"Context window size in tokens (default: {DEFAULT_CONTEXT_WINDOW:,})")
    add("--context-threshold", type=float, default=DEFAULT_CONTEXT_THRESHOLD,
        help=f"Context summarization threshold (default: {DEFAULT_CONTEXT_THRESHOLD:.1f} = {DEFAULT_CONTEXT_THRESHOLD*100:.0f}%%)")

    # Metrics
    add("--metrics-interval", type=int, default=DEFAULT_METRICS_INTERVAL,
        help=f"Metrics logging interval (default: {DEFAULT_METRICS_INTERVAL})")
    add("--no-metrics", action="store_true",
        help="Disable metrics collection")

    # Prompt safety
    add("--max-prompt-size", type=int, default=DEFAULT_MAX_PROMPT_SIZE,
        help=f"Maximum prompt file size in bytes (default: {DEFAULT_MAX_PROMPT_SIZE})")
    add("--allow-unsafe-paths", action="store_true",
        help="Allow potentially unsafe prompt paths (use with caution)")

    # Behaviour switches
    add("--no-git", action="store_true",
        help="Disable git checkpointing")
    add("--no-archive", action="store_true",
        help="Disable prompt archiving")
    add("--verbose", "-v", action="store_true",
        help="Enable verbose output")
    add("--dry-run", action="store_true",
        help="Dry run mode (don't execute agents)")

    # Output formatting options
    add("--output-format", type=str, default="rich",
        choices=["plain", "rich", "json"],
        help="Output format (default: rich)")
    add("--output-verbosity", type=str, default="normal",
        choices=["quiet", "normal", "verbose", "debug"],
        help="Output verbosity level (default: normal)")
    add("--no-token-usage", action="store_true",
        help="Disable token usage display")
    add("--no-timestamps", action="store_true",
        help="Disable timestamps in output")

    # Everything left over is passed through to the agent
    add("agent_args", nargs="*",
        help="Additional arguments to pass to the AI agent")

    args = parser.parse_args()

    # --verbose raises the root logger to DEBUG
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # "--no-*" switches are inverted into the positive config flags.
    settings = dict(
        agent=AgentType(args.agent),
        prompt_file=args.prompt,
        prompt_text=args.prompt_text,
        max_iterations=args.max_iterations,
        max_runtime=args.max_runtime,
        checkpoint_interval=args.checkpoint_interval,
        retry_delay=args.retry_delay,
        archive_prompts=not args.no_archive,
        git_checkpoint=not args.no_git,
        verbose=args.verbose,
        dry_run=args.dry_run,
        max_tokens=args.max_tokens,
        max_cost=args.max_cost,
        context_window=args.context_window,
        context_threshold=args.context_threshold,
        metrics_interval=args.metrics_interval,
        enable_metrics=not args.no_metrics,
        max_prompt_size=args.max_prompt_size,
        allow_unsafe_paths=args.allow_unsafe_paths,
        agent_args=args.agent_args,
        completion_promise=args.completion_promise,
        output_format=args.output_format,
        output_verbosity=args.output_verbosity,
        show_token_usage=not args.no_token_usage,
        show_timestamps=not args.no_timestamps,
    )
    config = RalphConfig(**settings)

    # Run orchestrator
    return RalphOrchestrator(config).run()
# Script entry: exit with the value returned by main()
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)

View File

@@ -0,0 +1,345 @@
# ABOUTME: Metrics tracking and cost calculation for Ralph Orchestrator
# ABOUTME: Monitors performance, usage, and costs across different AI tools
"""Metrics and cost tracking for Ralph Orchestrator."""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any
import time
import json
class TriggerReason(str, Enum):
    """Label recording why the orchestrator started an iteration.

    Attached to per-iteration telemetry so orchestration patterns can be
    analysed afterwards. Because the enum also derives from ``str``, every
    member compares equal to its plain string value.
    """

    # First iteration of a session
    INITIAL = "initial"
    # Previous iteration didn't complete the task
    TASK_INCOMPLETE = "task_incomplete"
    # Previous iteration succeeded, continuing
    PREVIOUS_SUCCESS = "previous_success"
    # Recovering from a previous failure
    RECOVERY = "recovery"
    # Loop detection triggered intervention
    LOOP_DETECTED = "loop_detected"
    # Safety limits triggered
    SAFETY_LIMIT = "safety_limit"
    # User requested stop
    USER_STOP = "user_stop"
@dataclass
class Metrics:
    """Plain counters describing one orchestration run.

    Every counter starts at zero. ``start_time`` defaults to the value of
    ``time.time()`` at construction and is the reference point for
    :meth:`elapsed_hours`.
    """

    iterations: int = 0
    successful_iterations: int = 0
    failed_iterations: int = 0
    errors: int = 0
    checkpoints: int = 0
    rollbacks: int = 0
    start_time: float = field(default_factory=time.time)

    def elapsed_hours(self) -> float:
        """Return the time since ``start_time``, expressed in hours."""
        elapsed_seconds = time.time() - self.start_time
        return elapsed_seconds / 3600

    def success_rate(self) -> float:
        """Return the fraction of finished iterations that succeeded.

        Returns:
            ``successful / (successful + failed)`` in the range 0.0-1.0, or
            ``0.0`` when no iteration has been counted either way.
        """
        attempts = self.successful_iterations + self.failed_iterations
        return self.successful_iterations / attempts if attempts else 0.0

    def to_dict(self) -> Dict:
        """Return the counters plus the two derived values as a dictionary."""
        snapshot = {
            name: getattr(self, name)
            for name in (
                "iterations",
                "successful_iterations",
                "failed_iterations",
                "errors",
                "checkpoints",
                "rollbacks",
            )
        }
        # Derived values are computed at call time, not stored.
        snapshot["elapsed_hours"] = self.elapsed_hours()
        snapshot["success_rate"] = self.success_rate()
        return snapshot

    def to_json(self) -> str:
        """Return :meth:`to_dict` serialised as JSON with two-space indent."""
        return json.dumps(self.to_dict(), indent=2)
class CostTracker:
    """Accumulate estimated spend per AI tool from reported token counts."""

    # Approximate USD price per 1,000 tokens. The trailing comments restate
    # each figure as a per-million-token price.
    COSTS = {
        "claude": {
            "input": 0.003,  # $3 per 1M input tokens
            "output": 0.015,  # $15 per 1M output tokens
        },
        "gemini": {
            "input": 0.00025,  # $0.25 per 1M input tokens
            "output": 0.001,  # $1 per 1M output tokens
        },
        "qchat": {
            "input": 0.0,  # Free/local
            "output": 0.0,
        },
        "acp": {
            "input": 0.0,  # ACP doesn't provide billing info
            "output": 0.0,  # Cost depends on underlying agent
        },
        "gpt-4": {
            "input": 0.03,  # $30 per 1M input tokens
            "output": 0.06,  # $60 per 1M output tokens
        },
    }

    def __init__(self):
        """Start with zero spend and an empty usage history."""
        self.total_cost = 0.0
        self.costs_by_tool: Dict[str, float] = {}
        self.usage_history: List[Dict] = []

    def add_usage(
        self,
        tool: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """Price one usage record and fold it into the running totals.

        Args:
            tool: Name of the AI tool; looked up in ``COSTS``
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Cost of this usage in the same unit as ``COSTS``
        """
        # A tool without a price entry is priced -- and recorded -- as
        # "qchat", the zero-cost entry.
        # NOTE(review): this means the original tool name is not kept in
        # costs_by_tool or usage_history; confirm that is intended.
        billed_tool = tool if tool in self.COSTS else "qchat"
        rates = self.COSTS[billed_tool]

        charge = (
            (input_tokens / 1000) * rates["input"]
            + (output_tokens / 1000) * rates["output"]
        )

        self.total_cost += charge
        self.costs_by_tool[billed_tool] = (
            self.costs_by_tool.get(billed_tool, 0.0) + charge
        )
        self.usage_history.append(
            {
                "timestamp": time.time(),
                "tool": billed_tool,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "cost": charge,
            }
        )
        return charge

    def get_summary(self) -> Dict:
        """Return total cost, per-tool costs, record count and mean cost.

        ``average_cost`` is ``0`` when nothing has been recorded yet. The
        ``costs_by_tool`` value is the tracker's own dictionary, not a copy.
        """
        records = len(self.usage_history)
        if self.usage_history:
            mean_cost = self.total_cost / records
        else:
            mean_cost = 0
        return {
            "total_cost": self.total_cost,
            "costs_by_tool": self.costs_by_tool,
            "usage_count": records,
            "average_cost": mean_cost,
        }

    def to_json(self) -> str:
        """Return :meth:`get_summary` serialised as JSON with two-space indent."""
        return json.dumps(self.get_summary(), indent=2)
@dataclass
class IterationStats:
"""Memory-efficient iteration statistics tracking.
Tracks per-iteration details (duration, success/failure, errors) while
limiting stored iterations to prevent memory leaks in long-running sessions.
"""
total: int = 0
successes: int = 0
failures: int = 0
start_time: datetime | None = None
current_iteration: int = 0
iterations: List[Dict[str, Any]] = field(default_factory=list)
max_iterations_stored: int = 1000 # Memory limit for stored iterations
max_preview_length: int = 500 # Max chars for output preview truncation
def __post_init__(self) -> None:
"""Initialize start time if not set."""
if self.start_time is None:
self.start_time = datetime.now()
def record_start(self, iteration: int) -> None:
"""Record iteration start.
Args:
iteration: Iteration number
"""
self.current_iteration = iteration
self.total = max(self.total, iteration)
def record_success(self, iteration: int) -> None:
"""Record successful iteration.
Args:
iteration: Iteration number
"""
self.total = iteration
self.successes += 1
def record_failure(self, iteration: int) -> None:
"""Record failed iteration.
Args:
iteration: Iteration number
"""
self.total = iteration
self.failures += 1
def record_iteration(
self,
iteration: int,
duration: float,
success: bool,
error: str,
trigger_reason: str = "",
output_preview: str = "",
tokens_used: int = 0,
cost: float = 0.0,
tools_used: List[str] | None = None,
) -> None:
"""Record iteration with full details.
Args:
iteration: Iteration number
duration: Duration in seconds
success: Whether iteration was successful
error: Error message if any
trigger_reason: Why this iteration was triggered (from TriggerReason)
output_preview: Preview of iteration output (truncated for privacy)
tokens_used: Total tokens consumed in this iteration
cost: Cost in dollars for this iteration
tools_used: List of tools/MCPs invoked during iteration
"""
# Update basic statistics
self.total = max(self.total, iteration)
self.current_iteration = iteration
if success:
self.successes += 1
else:
self.failures += 1
# Truncate output preview for privacy (configurable length)
if output_preview and len(output_preview) > self.max_preview_length:
output_preview = output_preview[:self.max_preview_length] + "..."
# Store detailed iteration information
iteration_data = {
"iteration": iteration,
"duration": duration,
"success": success,
"error": error,
"timestamp": datetime.now().isoformat(),
"trigger_reason": trigger_reason,
"output_preview": output_preview,
"tokens_used": tokens_used,
"cost": cost,
"tools_used": tools_used or [],
}
self.iterations.append(iteration_data)
# Enforce memory limit by evicting oldest entries
if len(self.iterations) > self.max_iterations_stored:
excess = len(self.iterations) - self.max_iterations_stored
self.iterations = self.iterations[excess:]
def get_success_rate(self) -> float:
"""Calculate success rate as percentage.
Returns:
Success rate (0-100)
"""
total_attempts = self.successes + self.failures
if total_attempts == 0:
return 0.0
return (self.successes / total_attempts) * 100
def get_runtime(self) -> str:
"""Get human-readable runtime duration.
Returns:
Runtime string (e.g., "2h 30m 15s")
"""
if self.start_time is None:
return "Unknown"
delta = datetime.now() - self.start_time
hours, remainder = divmod(int(delta.total_seconds()), 3600)
minutes, seconds = divmod(remainder, 60)
if hours > 0:
return f"{hours}h {minutes}m {seconds}s"
if minutes > 0:
return f"{minutes}m {seconds}s"
return f"{seconds}s"
def get_recent_iterations(self, count: int) -> List[Dict[str, Any]]:
"""Get most recent iterations.
Args:
count: Maximum number of iterations to return
Returns:
List of recent iteration data dictionaries
"""
if count >= len(self.iterations):
return self.iterations.copy()
return self.iterations[-count:]
def get_average_duration(self) -> float:
"""Calculate average iteration duration.
Returns:
Average duration in seconds, or 0.0 if no iterations
"""
if not self.iterations:
return 0.0
total_duration = sum(it["duration"] for it in self.iterations)
return total_duration / len(self.iterations)
def get_error_messages(self) -> List[str]:
"""Extract error messages from failed iterations.
Returns:
List of non-empty error messages
"""
return [
it["error"]
for it in self.iterations
if not it["success"] and it["error"]
]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Stats as dictionary (excludes iteration list for compatibility)
"""
return {
"total": self.total,
"current": self.current_iteration,
"successes": self.successes,
"failures": self.failures,
"success_rate": self.get_success_rate(),
"runtime": self.get_runtime(),
"start_time": self.start_time.isoformat() if self.start_time else None,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,121 @@
# ABOUTME: Output formatter module initialization
# ABOUTME: Exports base classes, formatter implementations, and legacy console classes
"""Output formatting module for Claude adapter responses.
This module provides:
1. Legacy output utilities (backward compatible):
- DiffStats, DiffFormatter, RalphConsole - Rich terminal utilities
- RICH_AVAILABLE - Rich library availability flag
2. New formatter classes for structured output:
- PlainTextFormatter: Basic text output without colors
- RichTerminalFormatter: Rich terminal output with colors and panels
- JsonFormatter: Structured JSON output for programmatic consumption
Example usage (new formatters):
from ralph_orchestrator.output import (
RichTerminalFormatter,
VerbosityLevel,
ToolCallInfo,
)
formatter = RichTerminalFormatter(verbosity=VerbosityLevel.VERBOSE)
tool_info = ToolCallInfo(tool_name="Read", tool_id="abc123", input_params={"path": "test.py"})
output = formatter.format_tool_call(tool_info, iteration=1)
formatter.print(output)
Example usage (legacy console):
from ralph_orchestrator.output import RalphConsole
console = RalphConsole()
console.print_status("Processing...")
console.print_success("Done!")
"""
# Import legacy classes from console module (backward compatibility)
from .console import (
RICH_AVAILABLE,
DiffFormatter,
DiffStats,
RalphConsole,
)
# Import new formatter base classes
from .base import (
FormatContext,
MessageType,
OutputFormatter,
TokenUsage,
ToolCallInfo,
VerbosityLevel,
)
# Import content detection
from .content_detector import ContentDetector, ContentType
# Import new formatter implementations
from .json_formatter import JsonFormatter
from .plain import PlainTextFormatter
from .rich_formatter import RichTerminalFormatter
__all__ = [
# Legacy exports (backward compatibility)
"RICH_AVAILABLE",
"DiffStats",
"DiffFormatter",
"RalphConsole",
# New base classes
"OutputFormatter",
"VerbosityLevel",
"MessageType",
"TokenUsage",
"ToolCallInfo",
"FormatContext",
# Content detection
"ContentDetector",
"ContentType",
# New formatters
"PlainTextFormatter",
"RichTerminalFormatter",
"JsonFormatter",
# Factory function
"create_formatter",
]
def create_formatter(
    format_type: str = "rich",
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL,
    **kwargs,
) -> OutputFormatter:
    """Build the formatter registered under ``format_type``.

    The lookup is case-insensitive. ``"text"`` is an alias of ``"plain"``
    and ``"terminal"`` is an alias of ``"rich"``.

    Args:
        format_type: Type of formatter ("plain", "rich", "json")
        verbosity: Verbosity level for output
        **kwargs: Additional arguments passed to formatter constructor

    Returns:
        Configured OutputFormatter instance

    Raises:
        ValueError: If format_type is not recognized
    """
    registry = {
        "plain": PlainTextFormatter,
        "text": PlainTextFormatter,
        "rich": RichTerminalFormatter,
        "terminal": RichTerminalFormatter,
        "json": JsonFormatter,
    }

    formatter_class = registry.get(format_type.lower())
    if formatter_class is None:
        valid_options = ", ".join(registry)
        raise ValueError(
            f"Unknown format type: {format_type}. "
            f"Valid options: {valid_options}"
        )
    return formatter_class(verbosity=verbosity, **kwargs)

View File

@@ -0,0 +1,404 @@
# ABOUTME: Base classes and interfaces for output formatting
# ABOUTME: Defines OutputFormatter ABC with verbosity levels, event types, and token tracking
"""Base classes for Claude adapter output formatting."""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
_logger = logging.getLogger(__name__)
class VerbosityLevel(Enum):
    """How much output a formatter should emit, from least (0) to most (3)."""

    # Only errors and final results
    QUIET = 0
    # Tool calls, assistant messages (no details)
    NORMAL = 1
    # Full tool inputs/outputs, detailed messages
    VERBOSE = 2
    # Everything including internal state
    DEBUG = 3
class MessageType(Enum):
    """Categories of message a formatter can be asked to render."""

    # Conversation roles
    SYSTEM = "system"
    ASSISTANT = "assistant"
    USER = "user"
    # Tool invocation and its outcome
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    # Diagnostics and status
    ERROR = "error"
    INFO = "info"
    PROGRESS = "progress"
@dataclass
class TokenUsage:
    """Token counts and cost for the latest update and for the whole session."""

    # Values from the most recent add() call
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cost: float = 0.0
    model: str = ""
    # Running totals across session
    session_input_tokens: int = 0
    session_output_tokens: int = 0
    session_total_tokens: int = 0
    session_cost: float = 0.0

    def add(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: float = 0.0,
        model: str = "",
    ) -> None:
        """Record one usage report.

        The current-value fields are overwritten with the given numbers
        (not accumulated), while the ``session_*`` fields are incremented
        by them. ``model`` is only replaced when a non-empty name is given.
        """
        combined = input_tokens + output_tokens

        # Session totals accumulate ...
        self.session_input_tokens += input_tokens
        self.session_output_tokens += output_tokens
        self.session_total_tokens += combined
        self.session_cost += cost

        # ... while the current values mirror this call only.
        self.input_tokens, self.output_tokens = input_tokens, output_tokens
        self.total_tokens = combined
        self.cost = cost
        if model:
            self.model = model

    def reset_current(self) -> None:
        """Zero the current counts and cost; session totals and model are kept."""
        self.input_tokens = self.output_tokens = self.total_tokens = 0
        self.cost = 0.0
@dataclass
class ToolCallInfo:
    """Record describing a single tool call.

    Only ``tool_name`` and ``tool_id`` are required; the remaining fields
    default to empty/unset values.
    """

    # Identification
    tool_name: str
    tool_id: str
    # Arguments supplied to the tool (a fresh dict per instance)
    input_params: Dict[str, Any] = field(default_factory=dict)
    # Timing
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    # Outcome
    result: Optional[Any] = None
    is_error: bool = False
    duration_ms: Optional[int] = None
@dataclass
class FormatContext:
    """Bundle of state handed to formatting operations and callbacks."""

    iteration: int = 0
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL
    timestamp: Optional[datetime] = None
    token_usage: Optional[TokenUsage] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in the optional fields that were left as ``None``.

        ``timestamp`` becomes the current time and ``token_usage`` becomes
        a fresh ``TokenUsage``; explicitly supplied values are kept.
        """
        self.timestamp = (
            datetime.now() if self.timestamp is None else self.timestamp
        )
        self.token_usage = (
            TokenUsage() if self.token_usage is None else self.token_usage
        )
class OutputFormatter(ABC):
    """Abstract base class for output formatters.

    Formatters handle rendering of Claude adapter events to different output
    formats (plain text, rich terminal, JSON, etc.). They support verbosity
    levels and consistent token usage tracking.

    Subclasses implement the ``format_*`` methods; this base class supplies
    verbosity filtering, token bookkeeping, callback dispatch and a couple
    of small helpers.
    """

    def __init__(self, verbosity: VerbosityLevel = VerbosityLevel.NORMAL) -> None:
        """Initialize formatter with verbosity level.

        Args:
            verbosity: Output verbosity level
        """
        self._verbosity = verbosity
        self._token_usage = TokenUsage()
        # Reference point for get_elapsed_time()
        self._start_time = datetime.now()
        self._callbacks: List[Callable[[MessageType, Any, FormatContext], None]] = []

    @property
    def verbosity(self) -> VerbosityLevel:
        """Get current verbosity level."""
        return self._verbosity

    @verbosity.setter
    def verbosity(self, level: VerbosityLevel) -> None:
        """Set verbosity level."""
        self._verbosity = level

    @property
    def token_usage(self) -> TokenUsage:
        """Get current token usage."""
        return self._token_usage

    def should_display(self, message_type: MessageType) -> bool:
        """Check if message type should be displayed at current verbosity.

        Args:
            message_type: Type of message to check

        Returns:
            True if message should be displayed
        """
        # Always show errors
        if message_type == MessageType.ERROR:
            return True

        if self._verbosity == VerbosityLevel.QUIET:
            return False

        if self._verbosity == VerbosityLevel.NORMAL:
            return message_type in (
                MessageType.ASSISTANT,
                MessageType.TOOL_CALL,
                MessageType.PROGRESS,
                MessageType.INFO,
            )

        # VERBOSE and DEBUG show everything
        return True

    def register_callback(
        self, callback: Callable[[MessageType, Any, FormatContext], None]
    ) -> None:
        """Register a callback for format events.

        Args:
            callback: Function to call with (message_type, content, context)
        """
        self._callbacks.append(callback)

    def _notify_callbacks(
        self, message_type: MessageType, content: Any, context: FormatContext
    ) -> None:
        """Notify all registered callbacks, in registration order.

        A callback that raises is logged at DEBUG level and skipped; the
        remaining callbacks still run.
        """
        for callback in self._callbacks:
            try:
                callback(message_type, content, context)
            except Exception as e:
                # Log but don't let callback errors break formatting
                callback_name = getattr(callback, "__name__", repr(callback))
                _logger.debug(
                    "Callback %s failed for %s: %s: %s",
                    callback_name,
                    message_type,
                    type(e).__name__,
                    e,
                )

    def _create_context(
        self, iteration: int = 0, metadata: Optional[Dict[str, Any]] = None
    ) -> FormatContext:
        """Create a format context with current state.

        Args:
            iteration: Current iteration number
            metadata: Additional metadata

        Returns:
            FormatContext carrying this formatter's verbosity and its
            token-usage object (shared, not copied)
        """
        return FormatContext(
            iteration=iteration,
            verbosity=self._verbosity,
            timestamp=datetime.now(),
            token_usage=self._token_usage,
            metadata=metadata or {},
        )

    def update_tokens(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: float = 0.0,
        model: str = "",
    ) -> None:
        """Update token usage tracking.

        Args:
            input_tokens: Number of input tokens used
            output_tokens: Number of output tokens used
            cost: Cost in USD
            model: Model name
        """
        self._token_usage.add(input_tokens, output_tokens, cost, model)

    @abstractmethod
    def format_tool_call(
        self,
        tool_info: ToolCallInfo,
        iteration: int = 0,
    ) -> str:
        """Format a tool call for display.

        Args:
            tool_info: Tool call information
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_tool_result(
        self,
        tool_info: ToolCallInfo,
        iteration: int = 0,
    ) -> str:
        """Format a tool result for display.

        Args:
            tool_info: Tool call info with result
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_assistant_message(
        self,
        message: str,
        iteration: int = 0,
    ) -> str:
        """Format an assistant message for display.

        Args:
            message: Assistant message text
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_system_message(
        self,
        message: str,
        iteration: int = 0,
    ) -> str:
        """Format a system message for display.

        Args:
            message: System message text
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_error(
        self,
        error: str,
        exception: Optional[Exception] = None,
        iteration: int = 0,
    ) -> str:
        """Format an error for display.

        Args:
            error: Error message
            exception: Optional exception object
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_progress(
        self,
        message: str,
        current: int = 0,
        total: int = 0,
        iteration: int = 0,
    ) -> str:
        """Format progress information for display.

        Args:
            message: Progress message
            current: Current progress value
            total: Total progress value
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_token_usage(self, show_session: bool = True) -> str:
        """Format token usage summary for display.

        Args:
            show_session: Include session totals

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_section_header(self, title: str, iteration: int = 0) -> str:
        """Format a section header for display.

        Args:
            title: Section title
            iteration: Current iteration number

        Returns:
            Formatted string representation
        """
        pass

    @abstractmethod
    def format_section_footer(self) -> str:
        """Format a section footer for display.

        Returns:
            Formatted string representation
        """
        pass

    def summarize_content(self, content: str, max_length: int = 500) -> str:
        """Summarize long content for display.

        Content no longer than ``max_length`` is returned unchanged.
        Longer content is reduced to its first and last
        ``(max_length - 20) // 2`` characters joined by a truncation
        marker. For small ``max_length`` values nothing of the content is
        kept and only the marker is returned.

        Args:
            content: Content to summarize
            max_length: Maximum length before truncation

        Returns:
            Summarized content
        """
        if len(content) <= max_length:
            return content

        # Clamp at zero: a negative value would make the head and tail
        # slices overlap and return more text than was passed in.
        half = max((max_length - 20) // 2, 0)
        head = content[:half]
        # Slice from an explicit start index; ``content[-0:]`` would be the
        # whole string when half == 0.
        tail = content[len(content) - half:]
        return f"{head}\n... [{len(content)} chars truncated] ...\n{tail}"

    def get_elapsed_time(self) -> float:
        """Get elapsed time since formatter creation.

        Returns:
            Elapsed time in seconds
        """
        return (datetime.now() - self._start_time).total_seconds()

View File

@@ -0,0 +1,915 @@
# ABOUTME: Colored terminal output utilities using Rich
# ABOUTME: Provides DiffFormatter, DiffStats, and RalphConsole for enhanced CLI output
"""Colored terminal output utilities using Rich."""
import logging
import re
from dataclasses import dataclass, field
from typing import Optional
_logger = logging.getLogger(__name__)
# Try to import Rich components with fallback
try:
from rich.console import Console
from rich.markdown import Markdown
from rich.markup import escape
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table
RICH_AVAILABLE = True
except ImportError:
RICH_AVAILABLE = False
Console = None # type: ignore
Markdown = None # type: ignore
Panel = None # type: ignore
Syntax = None # type: ignore
Table = None # type: ignore
def escape(x: str) -> str:
    """Fallback escape function.

    Defined only when the Rich imports above fail, as a stand-in for
    ``rich.markup.escape``. Converts *x* to ``str`` and replaces ``&``,
    ``<`` and ``>`` with ``&amp;``, ``&lt;`` and ``&gt;``; ``&`` is
    replaced first so the ampersands of the other two entities are not
    escaped again.
    """
    # NOTE(review): this is HTML-entity escaping, whereas Rich's escape()
    # backslash-escapes markup tags. If callers print the result as plain
    # text, "<" would show up as "&lt;" -- confirm this is intended.
    return str(x).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
@dataclass
class DiffStats:
    """Line and file counts gathered from a diff."""

    # Totals across the whole diff
    additions: int = 0
    deletions: int = 0
    files: int = 0
    # Per-file breakdown: filename -> (additions, deletions)
    files_changed: dict[str, tuple[int, int]] = field(default_factory=dict)
class DiffFormatter:
    """Formatter for enhanced diff visualization."""

    # Diff display constants
    MAX_CONTEXT_LINES = 3  # Maximum context lines to show before/after changes
    LARGE_DIFF_THRESHOLD = 100  # Lines count for "large diff" detection
    SEPARATOR_WIDTH = 60  # Width of visual separators
    LINE_NUM_WIDTH = 6  # Width for line number display

    # Glyphs used to draw the between-file separator and the per-file bar.
    _SEPARATOR_CHAR = "─"
    _BAR_CHAR = "█"

    # Binary file patterns
    BINARY_EXTENSIONS = {
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".pdf",
        ".zip",
        ".tar",
        ".gz",
        ".so",
        ".pyc",
        ".exe",
        ".dll",
    }

    def __init__(self, console: "Console") -> None:
        """
        Initialize diff formatter.

        Args:
            console: Rich console for output
        """
        self.console = console

    def format_and_print(self, text: str) -> None:
        """
        Print diff with enhanced visualization and file path highlighting.

        Features:
        - Color-coded diff lines (additions, deletions, context)
        - File path highlighting
        - Diff statistics summary (+X/-Y lines) with per-file breakdown
        - Visual separator between file changes
        - Hunk headers rewritten as readable line ranges
        - Context lines limited to MAX_CONTEXT_LINES per run once the diff
          exceeds LARGE_DIFF_THRESHOLD changed lines
        - Binary file detection and special handling
        - Empty diff detection with clear messaging

        When Rich is not available the text is printed unchanged.

        Args:
            text: Diff text to render
        """
        if not RICH_AVAILABLE:
            print(text)
            return

        lines = text.split("\n")
        stats = self._calculate_stats(lines)

        # Handle empty diffs
        if stats.additions == 0 and stats.deletions == 0 and stats.files == 0:
            self.console.print("[dim italic]No changes detected[/dim italic]")
            return

        self._print_summary(stats)

        separator = f"[dim]{self._SEPARATOR_CHAR * self.SEPARATOR_WIDTH}[/dim]"
        large_diff = stats.additions + stats.deletions > self.LARGE_DIFF_THRESHOLD
        seen_file = False
        current_file_name = None
        context_line_count = 0
        in_change_section = False

        for line in lines:
            # Diff content is arbitrary text. Escape it so bracketed
            # sequences such as "[/red]" are shown literally instead of
            # being parsed as console markup.
            safe = escape(line)

            if line.startswith("diff --git"):
                # Close out the previous file (stats + separator) before
                # starting the next one.
                if seen_file:
                    if current_file_name is not None:
                        self._print_file_stats(current_file_name, stats)
                    self.console.print()
                    self.console.print(separator)
                    self.console.print()

                seen_file = True
                current_file_name = self._extract_filename(line)

                if self._is_binary_file(line):
                    self.console.print(
                        f"[bold magenta]{safe} [dim](binary)[/dim][/bold magenta]"
                    )
                else:
                    self.console.print(f"[bold cyan]{safe}[/bold cyan]")
                context_line_count = 0
            elif line.startswith("Binary files"):
                self.console.print(f"[yellow]📦 {safe}[/yellow]")
            elif line.startswith(("---", "+++")):
                # Old path in red, new path in green
                colour = "red" if line.startswith("---") else "green"
                self.console.print(f"[bold {colour}]{safe}[/bold {colour}]")
            elif line.startswith("@@"):
                # Blank line before the hunk when context was just printed
                if context_line_count > 0:
                    self.console.print()
                # The header is escaped first because the formatter adds
                # its own markup around the trailing context text.
                hunk_info = self._format_hunk_header(safe)
                self.console.print(f"[bold magenta]{hunk_info}[/bold magenta]")
                context_line_count = 0
                in_change_section = False
            elif line.startswith("+"):
                self.console.print(f"[bold green]{safe}[/bold green]")
                in_change_section = True
                context_line_count = 0
            elif line.startswith("-"):
                self.console.print(f"[bold red]{safe}[/bold red]")
                in_change_section = True
                context_line_count = 0
            elif not large_diff:
                # Small diff - show all context
                self.console.print(f"[dim]{safe}[/dim]")
            elif context_line_count < self.MAX_CONTEXT_LINES:
                # Large diff: context is shown up to the limit
                self.console.print(f"[dim]{safe}[/dim]")
                context_line_count += 1
            elif in_change_section and context_line_count == self.MAX_CONTEXT_LINES:
                # Emit the omission marker once per run of trailing context
                self.console.print(
                    "[dim italic]  ⋮ (context lines omitted for readability)[/dim italic]"
                )
                context_line_count += 1

        # Print final file stats
        if current_file_name:
            self._print_file_stats(current_file_name, stats)

        # Add spacing after diff for better separation from next content
        self.console.print()

    def _calculate_stats(self, lines: list[str]) -> "DiffStats":
        """
        Calculate statistics from diff lines including per-file breakdown.

        Lines starting with ``+++`` / ``---`` are treated as file-path
        headers and not counted as changes.

        Args:
            lines: List of diff lines

        Returns:
            DiffStats with additions, deletions, files count, and per-file breakdown
        """
        stats = DiffStats()
        current_file = None

        for line in lines:
            if line.startswith("diff --git"):
                stats.files += 1
                current_file = self._extract_filename(line)
                if current_file and current_file not in stats.files_changed:
                    stats.files_changed[current_file] = (0, 0)
            elif line.startswith("+") and not line.startswith("+++"):
                stats.additions += 1
                if current_file and current_file in stats.files_changed:
                    adds, dels = stats.files_changed[current_file]
                    stats.files_changed[current_file] = (adds + 1, dels)
            elif line.startswith("-") and not line.startswith("---"):
                stats.deletions += 1
                if current_file and current_file in stats.files_changed:
                    adds, dels = stats.files_changed[current_file]
                    stats.files_changed[current_file] = (adds, dels + 1)

        return stats

    def _print_summary(self, stats: "DiffStats") -> None:
        """
        Print diff statistics summary.

        Nothing is printed when there are no additions and no deletions.

        Args:
            stats: Diff statistics
        """
        if stats.additions == 0 and stats.deletions == 0:
            return

        summary = "[bold cyan]📊 Changes:[/bold cyan] "
        if stats.additions > 0:
            summary += f"[green]+{stats.additions}[/green]"
        if stats.additions > 0 and stats.deletions > 0:
            summary += " "
        if stats.deletions > 0:
            summary += f"[red]-{stats.deletions}[/red]"
        if stats.files > 1:
            summary += f" [dim]({stats.files} files)[/dim]"

        self.console.print(summary)
        self.console.print()

    def _is_binary_file(self, diff_header: str) -> bool:
        """
        Check if diff is for a binary file based on extension.

        Args:
            diff_header: Diff header line (e.g., "diff --git a/file.png b/file.png")

        Returns:
            True if the extension (case-insensitive) is in BINARY_EXTENSIONS
        """
        from pathlib import Path

        # Third whitespace-separated token is the "a/..." path
        parts = diff_header.split()
        if len(parts) >= 3:
            ext = Path(parts[2]).suffix.lower()
            return ext in self.BINARY_EXTENSIONS
        return False

    def _extract_filename(self, diff_header: str) -> Optional[str]:
        """
        Extract filename from diff header line.

        Args:
            diff_header: Diff header line (e.g., "diff --git a/file.py b/file.py")

        Returns:
            Filename with any leading "a/" or "b/" removed, or None if the
            header has fewer than three tokens
        """
        parts = diff_header.split()
        if len(parts) < 3:
            return None

        file_path = parts[2]
        if file_path.startswith(("a/", "b/")):
            return file_path[2:]
        return file_path

    def _print_file_stats(self, filename: str, stats: "DiffStats") -> None:
        """
        Print per-file statistics with visual bar.

        The bar is at most 30 characters wide and is split between green
        (additions) and red (deletions) in proportion to the counts.
        Nothing is printed for unknown files or files without changes.

        Args:
            filename: Name of the file
            stats: DiffStats containing per-file breakdown
        """
        if not filename or filename not in stats.files_changed:
            return

        adds, dels = stats.files_changed[filename]
        total_changes = adds + dels
        if total_changes <= 0:
            return

        # Calculate visual bar proportions (max 30 chars)
        bar_width = min(30, total_changes)
        add_width = int((adds / total_changes) * bar_width)
        del_width = bar_width - add_width

        bar = ""
        if add_width > 0:
            bar += f"[bold green]{self._BAR_CHAR * add_width}[/bold green]"
        if del_width > 0:
            bar += f"[bold red]{self._BAR_CHAR * del_width}[/bold red]"

        summary = f" {bar} "
        if adds > 0:
            summary += f"[bold green]+{adds}[/bold green]"
        if adds > 0 and dels > 0:
            summary += " "
        if dels > 0:
            summary += f"[bold red]-{dels}[/bold red]"

        self.console.print(summary)

    def _format_hunk_header(self, hunk: str) -> str:
        """
        Format hunk header with enhanced readability and context highlighting.

        Transforms: @@ -140,7 +140,7 @@ class RalphConsole:
        Into: @@ Lines 140-146 → 140-146 @@ class RalphConsole:

        With context (function/class name) highlighted in cyan. A missing
        count means one line. A count of zero (e.g. ``-0,0`` for a newly
        added file) is shown as the start line alone rather than an
        inverted range.

        Args:
            hunk: Original hunk header line

        Returns:
            Formatted hunk header, or ``hunk`` unchanged when it does not
            look like a hunk header
        """
        pattern = r"@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@(.*)$"
        match = re.search(pattern, hunk)
        if not match:
            return hunk

        old_start = int(match.group(1))
        old_count = int(match.group(2)) if match.group(2) else 1
        new_start = int(match.group(3))
        new_count = int(match.group(4)) if match.group(4) else 1
        context = match.group(5).strip()

        # End line is inclusive; never let it fall below the start line.
        old_end = max(old_start, old_start + old_count - 1)
        new_end = max(new_start, new_start + new_count - 1)

        header = f"@@ Lines {old_start}-{old_end} → {new_start}-{new_end} @@"

        # Highlight context (function/class name) if present
        if context:
            header += f" [cyan]{context}[/cyan]"

        return header
class RalphConsole:
"""Rich console wrapper for Ralph output."""
# Display constants
CLEAR_LINE_WIDTH = 80 # Characters to clear when clearing a line
PROGRESS_BAR_WIDTH = 30 # Width of progress bar in characters
COUNTDOWN_COLOR_CHANGE_THRESHOLD_HIGH = 5 # Seconds remaining for yellow
COUNTDOWN_COLOR_CHANGE_THRESHOLD_LOW = 2 # Seconds remaining for red
MARKDOWN_INDICATOR_THRESHOLD = 2 # Minimum markdown patterns to consider as markdown
DIFF_SCAN_LINE_LIMIT = 5 # Number of lines to scan for diff indicators
DIFF_HUNK_SCAN_CHARS = 100 # Characters to scan for diff hunk markers
# Regex patterns for content detection and formatting
CODE_BLOCK_PATTERN = r"```(\w+)?\n(.*?)\n```"
FILE_REF_PATTERN = r"(\S+\.[a-zA-Z0-9]+):(\d+)"
INLINE_CODE_PATTERN = r"`([^`\n]+)`"
HUNK_HEADER_PATTERN = r"@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@(.*)$"
TABLE_SEPARATOR_PATTERN = r"^\s*\|[\s\-:|]+\|\s*$"
MARKDOWN_HEADING_PATTERN = r"^#{1,6}\s+.+"
MARKDOWN_UNORDERED_LIST_PATTERN = r"^[\*\-]\s+.+"
MARKDOWN_ORDERED_LIST_PATTERN = r"^\d+\.\s+.+"
MARKDOWN_BOLD_PATTERN = r"\*\*.+?\*\*"
MARKDOWN_ITALIC_PATTERN = r"\*.+?\*"
MARKDOWN_BLOCKQUOTE_PATTERN = r"^>\s+.+"
MARKDOWN_TASK_LIST_PATTERN = r"^[\*\-]\s+\[([ xX])\]\s+.+"
MARKDOWN_HORIZONTAL_RULE_PATTERN = r"^(\-{3,}|\*{3,}|_{3,})\s*$"
def __init__(self) -> None:
    """Create the Rich console and its diff formatter.

    Both attributes are set to ``None`` when Rich is not available.
    """
    self.console = Console() if RICH_AVAILABLE else None
    self.diff_formatter = DiffFormatter(self.console) if RICH_AVAILABLE else None
def print_status(self, message: str, style: str = "cyan") -> None:
"""Print status message."""
if self.console:
# Use markup escaping to prevent Rich from parsing brackets in the icon
self.console.print(f"[{style}][[*]] {message}[/{style}]")
else:
print(f"[*] {message}")
def print_success(self, message: str) -> None:
"""Print success message."""
if self.console:
self.console.print(f"[green]✓[/green] {message}")
else:
print(f"{message}")
def print_error(self, message: str, severity: str = "error") -> None:
"""
Print error message with severity-based formatting.
Args:
message: Error message to print
severity: Error severity level ("critical", "error", "warning")
"""
severity_styles = {
"critical": ("[red bold]⛔[/red bold]", "red bold"),
"error": ("[red]✗[/red]", "red"),
"warning": ("[yellow]⚠[/yellow]", "yellow"),
}
icon, style = severity_styles.get(severity, severity_styles["error"])
if self.console:
self.console.print(f"{icon} [{style}]{message}[/{style}]")
else:
print(f"{message}")
def print_warning(self, message: str) -> None:
"""Print warning message."""
if self.console:
self.console.print(f"[yellow]⚠[/yellow] {message}")
else:
print(f"{message}")
def print_info(self, message: str) -> None:
"""Print info message."""
if self.console:
self.console.print(f"[blue][/blue] {message}")
else:
print(f" {message}")
def print_header(self, title: str) -> None:
"""Print section header."""
if self.console and Panel:
self.console.print(
Panel(title, style="green bold", border_style="green"),
justify="left",
)
else:
print(f"\n=== {title} ===\n")
def print_iteration_header(self, iteration: int) -> None:
"""Print iteration header."""
if self.console:
self.console.print(
f"\n[cyan bold]=== RALPH ITERATION {iteration} ===[/cyan bold]\n"
)
else:
print(f"\n=== RALPH ITERATION {iteration} ===\n")
def print_stats(
    self,
    iteration: int,
    success_count: int,
    error_count: int,
    start_time: str,
    prompt_file: str,
    recent_lines: list[str],
) -> None:
    """
    Print statistics table.

    In Rich mode this renders a table followed by a colour-coded list of
    recent log lines. Without a console or without ``Table`` a short
    plain-text summary is printed instead; that fallback does not show
    ``recent_lines`` or the success rate.

    Args:
        iteration: Current iteration number
        success_count: Number of successful iterations
        error_count: Number of failed iterations
        start_time: Start time string ("Unknown" is shown when empty)
        prompt_file: Prompt file name
        recent_lines: Recent log entries
    """
    if not self.console or not Table:
        # Plain text fallback
        print("\nRALPH STATISTICS")
        print(f" Iteration: {iteration}")
        print(f" Successful: {success_count}")
        print(f" Failed: {error_count}")
        print(f" Started: {start_time}")
        print(f" Prompt: {prompt_file}")
        return
    # Create stats table with better formatting
    table = Table(
        title="🤖 RALPH STATISTICS",
        show_header=True,
        header_style="bold yellow",
        border_style="cyan",
    )
    table.add_column("Metric", style="cyan bold", no_wrap=True, width=20)
    table.add_column("Value", style="white", width=40)
    # Calculate success rate (0 when nothing has run yet, avoiding a zero division)
    total = success_count + error_count
    success_rate = (success_count / total * 100) if total > 0 else 0
    table.add_row("🔄 Current Iteration", str(iteration))
    table.add_row("✅ Successful", f"[green bold]{success_count}[/green bold]")
    table.add_row("❌ Failed", f"[red bold]{error_count}[/red bold]")
    # Determine success rate color based on percentage
    if success_rate > 80:
        rate_color = "green"
    elif success_rate > 50:
        rate_color = "yellow"
    else:
        rate_color = "red"
    table.add_row("📊 Success Rate", f"[{rate_color}]{success_rate:.1f}%[/]")
    table.add_row("🕐 Started", start_time or "Unknown")
    table.add_row("📝 Prompt", prompt_file)
    self.console.print(table)
    # Show recent activity with better formatting
    if recent_lines:
        self.console.print("\n[yellow bold]📋 RECENT ACTIVITY[/yellow bold]")
        for line in recent_lines:
            # Clean up log lines for display and escape Rich markup
            clean_line = escape(line.strip())
            # The bullet colour follows the log-level tag found in the line.
            if "[SUCCESS]" in clean_line:
                self.console.print(f" [green]▸[/green] {clean_line}")
            elif "[ERROR]" in clean_line:
                self.console.print(f" [red]▸[/red] {clean_line}")
            elif "[WARNING]" in clean_line:
                self.console.print(f" [yellow]▸[/yellow] {clean_line}")
            else:
                self.console.print(f" [blue]▸[/blue] {clean_line}")
        # NOTE(review): indentation was lost in this view; the trailing blank
        # line is assumed to belong to the recent-activity block — confirm.
        self.console.print()
def print_countdown(self, remaining: int, total: int) -> None:
"""
Print countdown timer with progress bar.
Args:
remaining: Seconds remaining
total: Total delay seconds
"""
# Guard against division by zero
if total <= 0:
return
# Calculate progress
progress = (total - remaining) / total
filled = int(self.PROGRESS_BAR_WIDTH * progress)
bar = "" * filled + "" * (self.PROGRESS_BAR_WIDTH - filled)
# Color based on time remaining (using constants)
if remaining > self.COUNTDOWN_COLOR_CHANGE_THRESHOLD_HIGH:
color = "green"
elif remaining > self.COUNTDOWN_COLOR_CHANGE_THRESHOLD_LOW:
color = "yellow"
else:
color = "red"
if self.console:
self.console.print(
f"\r[{color}]⏳ [{bar}] {remaining}s / {total}s remaining[/{color}]",
end="",
)
else:
print(f"\r⏳ [{bar}] {remaining}s / {total}s remaining", end="")
def clear_line(self) -> None:
"""Clear current line."""
if self.console:
self.console.print("\r" + " " * self.CLEAR_LINE_WIDTH + "\r", end="")
else:
print("\r" + " " * self.CLEAR_LINE_WIDTH + "\r", end="")
def print_separator(self) -> None:
"""Print visual separator."""
if self.console:
self.console.print("\n[cyan]---[/cyan]\n")
else:
print("\n---\n")
def clear_screen(self) -> None:
"""Clear screen."""
if self.console:
self.console.clear()
else:
print("\033[2J\033[H", end="")
def print_message(self, text: str) -> None:
    """
    Print message with intelligent formatting and improved visual hierarchy.

    Detects and formats:
    - Code blocks (```language) with syntax highlighting
    - Diffs (lines starting with +, -, @@) with enhanced visualization
    - Markdown tables with proper rendering
    - Markdown headings, lists, emphasis with spacing
    - Inline code (`code`) with highlighting
    - Plain text with file path detection

    The checks run in that order and the first match decides how the whole
    message is rendered. Without a console the text is printed unchanged.

    Args:
        text: Message text to print
    """
    if not self.console:
        print(text)
        return
    # Check if text contains code blocks
    if "```" in text:
        # Split text by code blocks and process each part.
        # The pattern has two capture groups, so re.split yields repeating
        # triples: [text, language, code, text, language, code, ..., text].
        parts = re.split(self.CODE_BLOCK_PATTERN, text, flags=re.DOTALL)
        # NOTE(review): indentation was lost in this view; the nesting of the
        # two spacing prints below is reconstructed — confirm against the file.
        for i, part in enumerate(parts):
            if i % 3 == 0:  # Regular text between code blocks
                if part.strip():
                    self._print_formatted_text(part)
                    # Add subtle spacing after text before code block
                    if i + 1 < len(parts):
                        self.console.print()
            elif i % 3 == 1:  # Language identifier
                language = part or "text"
                code = parts[i + 1] if i + 1 < len(parts) else ""
                if code.strip() and Syntax:
                    # Use syntax highlighting for code blocks with enhanced features
                    syntax = Syntax(
                        code,
                        language,
                        theme="monokai",
                        line_numbers=True,
                        word_wrap=True,
                        indent_guides=True,
                        padding=(1, 2),
                    )
                    self.console.print(syntax)
                    # Add spacing after code block if more content follows
                    if i + 2 < len(parts) and parts[i + 2].strip():
                        self.console.print()
            # i % 3 == 2 is the code body, already consumed by the branch above.
    elif self._is_diff_content(text):
        # Format as diff with enhanced visualization
        if self.diff_formatter:
            self.diff_formatter.format_and_print(text)
        else:
            print(text)
    elif self._is_markdown_table(text):
        # Render markdown tables nicely
        self._print_markdown_table(text)
        # Add spacing after table
        self.console.print()
    elif self._is_markdown_content(text):
        # Render rich markdown with headings, lists, emphasis
        self._print_markdown(text)
    else:
        # Regular text - check for inline code and format accordingly
        self._print_formatted_text(text)
def _is_diff_content(self, text: str) -> bool:
"""
Check if text appears to be diff content.
Args:
text: Text to check
Returns:
True if text looks like diff output
"""
diff_indicators = [
text.startswith("diff --git"),
text.startswith("--- "),
text.startswith("+++ "),
"@@" in text[: self.DIFF_HUNK_SCAN_CHARS], # Diff hunk markers
any(
line.startswith(("+", "-", "@@"))
for line in text.split("\n")[: self.DIFF_SCAN_LINE_LIMIT]
),
]
return any(diff_indicators)
def _is_markdown_table(self, text: str) -> bool:
"""
Check if text appears to be a markdown table.
Args:
text: Text to check
Returns:
True if text looks like a markdown table
"""
lines = text.strip().split("\n")
if len(lines) < 2:
return False
# Check for table separator line (e.g., |---|---|)
for line in lines[:3]:
if re.match(self.TABLE_SEPARATOR_PATTERN, line):
return True
return False
def _print_markdown_table(self, text: str) -> None:
    """
    Render a markdown table through Rich's Markdown renderer.

    Falls back to a plain ``print`` when ``Markdown`` is unavailable.

    Args:
        text: Markdown table text
    """
    if not Markdown:
        print(text)
        return
    self.console.print(Markdown(text))
def _is_markdown_content(self, text: str) -> bool:
"""
Check if text appears to contain rich markdown (headings, lists, etc.).
Args:
text: Text to check
Returns:
True if text looks like markdown with formatting
"""
markdown_indicators = [
re.search(self.MARKDOWN_HEADING_PATTERN, text, re.MULTILINE), # Headings
re.search(
self.MARKDOWN_UNORDERED_LIST_PATTERN, text, re.MULTILINE
), # Unordered lists
re.search(
self.MARKDOWN_ORDERED_LIST_PATTERN, text, re.MULTILINE
), # Ordered lists
re.search(self.MARKDOWN_BOLD_PATTERN, text), # Bold
re.search(self.MARKDOWN_ITALIC_PATTERN, text), # Italic
re.search(
self.MARKDOWN_BLOCKQUOTE_PATTERN, text, re.MULTILINE
), # Blockquotes
re.search(
self.MARKDOWN_TASK_LIST_PATTERN, text, re.MULTILINE
), # Task lists
re.search(
self.MARKDOWN_HORIZONTAL_RULE_PATTERN, text, re.MULTILINE
), # Horizontal rules
]
# Return true if at least MARKDOWN_INDICATOR_THRESHOLD markdown indicators present
threshold = self.MARKDOWN_INDICATOR_THRESHOLD
return sum(bool(indicator) for indicator in markdown_indicators) >= threshold
def _preprocess_markdown(self, text: str) -> str:
"""
Preprocess markdown text for better rendering.
Handles:
- Task lists with checkboxes (- [ ] and - [x])
- Horizontal rules with visual enhancement
- Code blocks with language hints
Args:
text: Raw markdown text
Returns:
Preprocessed markdown text
"""
lines = text.split("\n")
processed_lines = []
for line in lines:
# Enhanced task lists with visual indicators
if re.match(self.MARKDOWN_TASK_LIST_PATTERN, line):
# Replace [ ] with ☐ and [x]/[X] with ☑
line = re.sub(r"\[\s\]", "", line)
line = re.sub(r"\[[xX]\]", "", line)
# Enhanced horizontal rules - make them more visible
if re.match(self.MARKDOWN_HORIZONTAL_RULE_PATTERN, line):
line = f"\n{'' * 60}\n"
processed_lines.append(line)
return "\n".join(processed_lines)
def _print_markdown(self, text: str) -> None:
    """
    Render markdown through Rich with a blank line after it.

    A blank line is also printed first when the text contains a heading.
    Falls back to a plain ``print`` when ``Markdown`` is unavailable.

    Args:
        text: Markdown text to render
    """
    if not Markdown:
        print(text)
        return

    # Separate headed sections from whatever was printed before them.
    if re.search(self.MARKDOWN_HEADING_PATTERN, text, re.MULTILINE):
        self.console.print()

    self.console.print(Markdown(self._preprocess_markdown(text)))
    # Trailing blank line keeps the block apart from the next output.
    self.console.print()
def _print_formatted_text(self, text: str) -> None:
    """
    Print text with basic formatting, inline code, file path highlighting, and error detection.

    Text that looks like a traceback is handed to ``_print_error_traceback``.
    Otherwise ``file.ext:123`` references and single-backtick code spans are
    wrapped in Rich markup before printing.

    Args:
        text: Text to print
    """
    if not self.console:
        print(text)
        return
    # Check for error/exception patterns and apply special formatting
    if self._is_error_traceback(text):
        self._print_error_traceback(text)
        return
    # NOTE(review): ``text`` is not escaped before markup is inserted, so any
    # square-bracket sequence already in it is interpreted as Rich markup —
    # confirm this pass-through is intended.
    # First, highlight file paths with line numbers (e.g., "file.py:123")
    text = re.sub(
        self.FILE_REF_PATTERN,
        lambda m: (
            f"[bold yellow]{m.group(1)}[/bold yellow]:"
            f"[bold blue]{m.group(2)}[/bold blue]"
        ),
        text,
    )
    # Check for inline code (single backticks)
    if "`" in text and "```" not in text:
        # Replace inline code with Rich markup - improved visibility
        formatted_text = re.sub(
            self.INLINE_CODE_PATTERN,
            lambda m: f"[cyan on grey23]{m.group(1)}[/cyan on grey23]",
            text,
        )
        self.console.print(formatted_text, highlight=True)
    else:
        # Enable markup for file paths and highlighting for URLs
        self.console.print(text, markup=True, highlight=True)
def _is_error_traceback(self, text: str) -> bool:
"""
Check if text appears to be an error traceback.
Args:
text: Text to check
Returns:
True if text looks like an error traceback
"""
error_indicators = [
"Traceback (most recent call last):" in text,
re.search(r'^\s*File ".*", line \d+', text, re.MULTILINE),
re.search(
r"^(Error|Exception|ValueError|TypeError|RuntimeError):",
text,
re.MULTILINE,
),
]
return any(error_indicators)
def _print_error_traceback(self, text: str) -> None:
    """
    Print a traceback under a red heading with Python syntax highlighting.

    Falls back to a plain ``print`` when ``Syntax`` is unavailable, and to
    plain red text when highlighting raises.

    Args:
        text: Error traceback text
    """
    if not Syntax:
        print(text)
        return

    syntax_options = {
        "theme": "monokai",
        "line_numbers": False,
        "word_wrap": True,
        "background_color": "grey11",
    }
    try:
        highlighted = Syntax(text, "python", **syntax_options)
        self.console.print("\n[red bold]⚠ Error Traceback:[/red bold]")
        self.console.print(highlighted)
        self.console.print()
    except Exception as exc:
        # Best effort: a rendering failure must not hide the traceback itself.
        _logger.warning("Syntax highlighting failed for traceback: %s: %s", type(exc).__name__, exc)
        self.console.print(f"[red]{text}[/red]")

View File

@@ -0,0 +1,247 @@
# ABOUTME: Content type detection for smart output formatting
# ABOUTME: Detects diffs, code blocks, markdown, tables, and tracebacks
"""Content type detection for intelligent output formatting."""
import re
from enum import Enum
from typing import Optional
class ContentType(Enum):
    """Types of content that can be detected."""

    # Fallback when nothing more specific matches, and for empty input.
    PLAIN_TEXT = "plain_text"
    # Unified/git diff output.
    DIFF = "diff"
    # Text containing a fenced (triple-backtick) code block.
    CODE_BLOCK = "code_block"
    # Text with several markdown constructs (headings, lists, emphasis, ...).
    MARKDOWN = "markdown"
    # Pipe-delimited markdown table.
    MARKDOWN_TABLE = "markdown_table"
    # Python exception traceback or error line.
    ERROR_TRACEBACK = "error_traceback"
class ContentDetector:
"""Detects content types for smart formatting.
Analyzes text content to determine the most appropriate rendering method.
Detection priority: code_block > diff > traceback > table > markdown > plain
"""
# Detection constants
DIFF_HUNK_SCAN_CHARS = 100
DIFF_SCAN_LINE_LIMIT = 5
MARKDOWN_INDICATOR_THRESHOLD = 2
# Regex patterns
CODE_BLOCK_PATTERN = re.compile(r"```(\w+)?\n.*?\n```", re.DOTALL)
MARKDOWN_HEADING_PATTERN = re.compile(r"^#{1,6}\s+.+", re.MULTILINE)
MARKDOWN_UNORDERED_LIST_PATTERN = re.compile(r"^[\*\-]\s+.+", re.MULTILINE)
MARKDOWN_ORDERED_LIST_PATTERN = re.compile(r"^\d+\.\s+.+", re.MULTILINE)
MARKDOWN_BOLD_PATTERN = re.compile(r"\*\*.+?\*\*")
MARKDOWN_ITALIC_PATTERN = re.compile(r"(?<!\*)\*(?!\*)[^*\n]+\*(?!\*)")
MARKDOWN_BLOCKQUOTE_PATTERN = re.compile(r"^>\s+.+", re.MULTILINE)
MARKDOWN_TASK_LIST_PATTERN = re.compile(r"^[\*\-]\s+\[([ xX])\]\s+.+", re.MULTILINE)
MARKDOWN_HORIZONTAL_RULE_PATTERN = re.compile(r"^(\-{3,}|\*{3,}|_{3,})\s*$", re.MULTILINE)
TABLE_SEPARATOR_PATTERN = re.compile(r"^\s*\|[\s\-:|]+\|\s*$", re.MULTILINE)
TRACEBACK_FILE_LINE_PATTERN = re.compile(r'^\s*File ".*", line \d+', re.MULTILINE)
TRACEBACK_ERROR_PATTERN = re.compile(
r"^(Error|Exception|ValueError|TypeError|RuntimeError|KeyError|AttributeError|"
r"IndexError|ImportError|FileNotFoundError|NameError|ZeroDivisionError):",
re.MULTILINE,
)
def detect(self, text: str) -> ContentType:
    """Detect the primary content type of text.

    The checks run from most to least specific and the first hit wins:
    code block, diff, error traceback, markdown table, rich markdown.
    Empty or whitespace-only text, and text matching none of the checks,
    is plain text.

    Args:
        text: Text content to analyze

    Returns:
        The detected ContentType
    """
    if not text or not text.strip():
        return ContentType.PLAIN_TEXT

    # Order matters: earlier entries take priority over later ones.
    ordered_checks = (
        (self.is_code_block, ContentType.CODE_BLOCK),
        (self.is_diff, ContentType.DIFF),
        (self.is_error_traceback, ContentType.ERROR_TRACEBACK),
        (self.is_markdown_table, ContentType.MARKDOWN_TABLE),
        (self.is_markdown, ContentType.MARKDOWN),
    )
    for matches, content_type in ordered_checks:
        if matches(text):
            return content_type
    return ContentType.PLAIN_TEXT
def is_diff(self, text: str) -> bool:
"""Check if text is diff content.
Detects git diff format including:
- diff --git headers
- --- and +++ file markers (as diff markers, not markdown hr)
- @@ hunk markers
Note: We avoid matching lines that merely start with + or - as those
could be markdown list items. Diff detection requires more specific
markers like @@ hunks or diff --git headers.
Args:
text: Text to check
Returns:
True if text appears to be diff content
"""
if not text:
return False
# Check for definitive diff markers
diff_indicators = [
text.startswith("diff --git"),
"@@" in text[: self.DIFF_HUNK_SCAN_CHARS],
]
if any(diff_indicators):
return True
# Check for --- a/ and +++ b/ patterns (file markers in unified diff)
# These are more specific than just --- or +++ which could be markdown hr
lines = text.split("\n")[: self.DIFF_SCAN_LINE_LIMIT]
has_file_markers = (
any(line.startswith("--- a/") or line.startswith("--- /") for line in lines)
and any(line.startswith("+++ b/") or line.startswith("+++ /") for line in lines)
)
if has_file_markers:
return True
# Check for @@ hunk pattern specifically
return any(line.startswith("@@") for line in lines)
def is_code_block(self, text: str) -> bool:
"""Check if text contains fenced code blocks.
Detects markdown-style code blocks with triple backticks.
Args:
text: Text to check
Returns:
True if text contains code blocks
"""
if not text:
return False
return "```" in text and self.CODE_BLOCK_PATTERN.search(text) is not None
def is_markdown(self, text: str) -> bool:
"""Check if text contains rich markdown formatting.
Requires at least MARKDOWN_INDICATOR_THRESHOLD indicators to avoid
false positives on text that happens to contain a single markdown element.
Detected elements:
- Headings (# Title)
- Lists (- item, 1. item)
- Emphasis (**bold**, *italic*)
- Blockquotes (> quote)
- Task lists (- [ ] task)
- Horizontal rules (---)
Args:
text: Text to check
Returns:
True if text appears to be markdown content
"""
if not text:
return False
markdown_indicators = [
self.MARKDOWN_HEADING_PATTERN.search(text),
self.MARKDOWN_UNORDERED_LIST_PATTERN.search(text),
self.MARKDOWN_ORDERED_LIST_PATTERN.search(text),
self.MARKDOWN_BOLD_PATTERN.search(text),
self.MARKDOWN_ITALIC_PATTERN.search(text),
self.MARKDOWN_BLOCKQUOTE_PATTERN.search(text),
self.MARKDOWN_TASK_LIST_PATTERN.search(text),
self.MARKDOWN_HORIZONTAL_RULE_PATTERN.search(text),
]
return sum(bool(indicator) for indicator in markdown_indicators) >= self.MARKDOWN_INDICATOR_THRESHOLD
def is_markdown_table(self, text: str) -> bool:
"""Check if text is a markdown table.
Detects tables with pipe separators and header dividers.
Args:
text: Text to check
Returns:
True if text appears to be a markdown table
"""
if not text:
return False
lines = text.strip().split("\n")
if len(lines) < 2:
return False
# Check for table separator line (|---|---|) in first few lines
for line in lines[:3]:
if self.TABLE_SEPARATOR_PATTERN.match(line):
return True
return False
def is_error_traceback(self, text: str) -> bool:
"""Check if text is an error traceback.
Detects Python exception tracebacks.
Args:
text: Text to check
Returns:
True if text appears to be an error traceback
"""
if not text:
return False
error_indicators = [
"Traceback (most recent call last):" in text,
self.TRACEBACK_FILE_LINE_PATTERN.search(text),
self.TRACEBACK_ERROR_PATTERN.search(text),
]
return any(error_indicators)
def extract_code_blocks(self, text: str) -> list[tuple[Optional[str], str]]:
"""Extract code blocks from text.
Args:
text: Text containing code blocks
Returns:
List of (language, code) tuples. Language may be None.
"""
if not text:
return []
blocks = []
pattern = re.compile(r"```(\w+)?\n(.*?)\n```", re.DOTALL)
for match in pattern.finditer(text):
language = match.group(1)
code = match.group(2)
blocks.append((language, code))
return blocks

View File

@@ -0,0 +1,416 @@
# ABOUTME: JSON output formatter for programmatic consumption
# ABOUTME: Produces structured JSON output for parsing by other tools
"""JSON output formatter for Claude adapter."""
import json
from datetime import datetime
from typing import Any, Dict, List, Optional
from .base import (
MessageType,
OutputFormatter,
ToolCallInfo,
VerbosityLevel,
)
class JsonFormatter(OutputFormatter):
"""JSON formatter for programmatic output consumption.
Produces structured JSON output suitable for parsing by other tools,
logging systems, or downstream processing pipelines.
"""
def __init__(
    self,
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL,
    pretty: bool = True,
    include_timestamps: bool = True,
) -> None:
    """Set up the JSON formatter.

    Args:
        verbosity: Output verbosity level
        pretty: Indent the JSON output when True
        include_timestamps: Add a timestamp to every event when True
    """
    super().__init__(verbosity)
    # Events are kept in creation order for later retrieval or export.
    self._events: List[Dict[str, Any]] = []
    self._include_timestamps = include_timestamps
    self._pretty = pretty
def _to_json(self, obj: Dict[str, Any]) -> str:
    """Serialize a dictionary to a JSON string.

    Values json cannot encode are converted with ``str``; non-ASCII text is
    written as-is.

    Args:
        obj: Dictionary to serialize

    Returns:
        JSON string, indented by two spaces in pretty mode
    """
    indent = 2 if self._pretty else None
    return json.dumps(obj, indent=indent, default=str, ensure_ascii=False)
def _create_event(
    self,
    event_type: str,
    data: Dict[str, Any],
    iteration: int = 0,
) -> Dict[str, Any]:
    """Build a structured event dictionary.

    Args:
        event_type: Type of event
        data: Event data
        iteration: Current iteration number

    Returns:
        Dictionary with "type", "iteration" and "data" keys, plus an
        ISO-formatted "timestamp" when timestamps are enabled
    """
    stamp = {"timestamp": datetime.now().isoformat()} if self._include_timestamps else {}
    return {"type": event_type, "iteration": iteration, "data": data, **stamp}
def _record_event(self, event: Dict[str, Any]) -> None:
    """Record event for later retrieval.

    The event is appended to the internal list, so events stay in the order
    they were recorded.

    Args:
        event: Event dictionary to record
    """
    self._events.append(event)
def get_events(self) -> List[Dict[str, Any]]:
    """Return the recorded events.

    Returns:
        A new list holding the recorded event dictionaries (shallow copy:
        the dictionaries themselves are shared with the formatter)
    """
    return list(self._events)
def clear_events(self) -> None:
    """Remove all recorded events, emptying the internal list in place."""
    del self._events[:]
def format_tool_call(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool call as JSON.

    Notifies callbacks, records a "tool_call" event and returns it
    serialized. Returns an empty string, with no side effects, when tool
    calls are not displayed.

    Args:
        tool_info: Tool call information
        iteration: Current iteration number

    Returns:
        JSON string representation
    """
    if not self.should_display(MessageType.TOOL_CALL):
        return ""
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_CALL, tool_info, context)
    data: Dict[str, Any] = {
        "tool_name": tool_info.tool_name,
        "tool_id": tool_info.tool_id,
    }
    # Input parameters are only included at verbose level and above.
    if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        data["input_params"] = tool_info.input_params
        # NOTE(review): indentation was lost in this view; start_time is
        # assumed to be part of the verbose-only details — confirm.
        if tool_info.start_time:
            data["start_time"] = tool_info.start_time.isoformat()
    event = self._create_event("tool_call", data, iteration)
    self._record_event(event)
    return self._to_json(event)
def format_tool_result(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool result as JSON.

    Notifies callbacks, records a "tool_result" event and returns it
    serialized. Returns an empty string, with no side effects, when tool
    results are not displayed.

    Args:
        tool_info: Tool call info with result
        iteration: Current iteration number

    Returns:
        JSON string representation
    """
    if not self.should_display(MessageType.TOOL_RESULT):
        return ""
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_RESULT, tool_info, context)
    data: Dict[str, Any] = {
        "tool_name": tool_info.tool_name,
        "tool_id": tool_info.tool_id,
        "is_error": tool_info.is_error,
    }
    # "is not None" keeps a measured duration of 0 ms in the output.
    if tool_info.duration_ms is not None:
        data["duration_ms"] = tool_info.duration_ms
    # The result payload is only included at verbose level and above.
    if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        result = tool_info.result
        # Only string results longer than 1000 characters are shortened.
        if isinstance(result, str) and len(result) > 1000:
            data["result"] = self.summarize_content(result, 1000)
            data["result_truncated"] = True
            data["result_full_length"] = len(result)
        else:
            data["result"] = result
        # NOTE(review): indentation was lost in this view; end_time is
        # assumed to be part of the verbose-only details — confirm.
        if tool_info.end_time:
            data["end_time"] = tool_info.end_time.isoformat()
    event = self._create_event("tool_result", data, iteration)
    self._record_event(event)
    return self._to_json(event)
def format_assistant_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format an assistant message as JSON.

    Notifies callbacks, records an "assistant_message" event and returns it
    serialized. At NORMAL verbosity, messages longer than 1000 characters
    are summarized and flagged as truncated.

    Args:
        message: Assistant message text
        iteration: Current iteration number

    Returns:
        JSON string representation, or an empty string when assistant
        messages are not displayed
    """
    if not self.should_display(MessageType.ASSISTANT):
        return ""
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.ASSISTANT, message, context)
    data: Dict[str, Any] = {}
    if self._verbosity == VerbosityLevel.NORMAL and len(message) > 1000:
        data["message"] = self.summarize_content(message, 1000)
        data["message_truncated"] = True
        data["message_full_length"] = len(message)
    else:
        data["message"] = message
        data["message_truncated"] = False
        # NOTE(review): indentation was lost in this view; message_length is
        # assumed to be set only for untruncated messages — confirm.
        data["message_length"] = len(message)
    event = self._create_event("assistant_message", data, iteration)
    self._record_event(event)
    return self._to_json(event)
def format_system_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format a system message as JSON.

    Notifies callbacks, records a "system_message" event and returns it
    serialized.

    Args:
        message: System message text
        iteration: Current iteration number

    Returns:
        JSON string representation, or an empty string when system
        messages are not displayed
    """
    if not self.should_display(MessageType.SYSTEM):
        return ""

    self._notify_callbacks(MessageType.SYSTEM, message, self._create_context(iteration))

    event = self._create_event("system_message", {"message": message}, iteration)
    self._record_event(event)
    return self._to_json(event)
def format_error(
    self,
    error: str,
    exception: Optional[Exception] = None,
    iteration: int = 0,
) -> str:
    """Format an error as JSON.

    There is no display check here: the error event is always produced.
    Exception type and text are added when an exception is given, and the
    formatted traceback at verbose level and above.

    Args:
        error: Error message
        exception: Optional exception object
        iteration: Current iteration number

    Returns:
        JSON string representation
    """
    self._notify_callbacks(MessageType.ERROR, error, self._create_context(iteration))

    data: Dict[str, Any] = {"error": error}
    if exception:
        data.update(
            exception_type=type(exception).__name__,
            exception_str=str(exception),
        )
        if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
            import traceback

            data["traceback"] = traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )

    event = self._create_event("error", data, iteration)
    self._record_event(event)
    return self._to_json(event)
def format_progress(
    self,
    message: str,
    current: int = 0,
    total: int = 0,
    iteration: int = 0,
) -> str:
    """Format progress information as JSON.

    Args:
        message: Progress message
        current: Current progress value
        total: Total progress value
        iteration: Current iteration number

    Returns:
        JSON string representation, or an empty string when progress
        messages are not displayed. A "percentage" field (one decimal) is
        included only when ``total`` is positive.
    """
    if not self.should_display(MessageType.PROGRESS):
        return ""

    self._notify_callbacks(MessageType.PROGRESS, message, self._create_context(iteration))

    data: Dict[str, Any] = {"message": message, "current": current, "total": total}
    if total > 0:
        data["percentage"] = round((current / total) * 100, 1)

    event = self._create_event("progress", data, iteration)
    self._record_event(event)
    return self._to_json(event)
def format_token_usage(self, show_session: bool = True) -> str:
    """Format token usage summary as JSON.

    Builds a "token_usage" event with iteration 0. Unlike the other
    format_* methods, the event is returned without being recorded.

    Args:
        show_session: Include session totals

    Returns:
        JSON string representation
    """
    usage = self._token_usage
    data: Dict[str, Any] = {
        "current": {
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
            "total_tokens": usage.total_tokens,
            "cost": usage.cost,
        },
    }
    if show_session:
        data["session"] = {
            "input_tokens": usage.session_input_tokens,
            "output_tokens": usage.session_output_tokens,
            "total_tokens": usage.session_total_tokens,
            "cost": usage.session_cost,
        }
    # NOTE(review): indentation was lost in this view; the model name is
    # assumed to be reported independently of show_session — confirm.
    if usage.model:
        data["model"] = usage.model
    event = self._create_event("token_usage", data, 0)
    return self._to_json(event)
def format_section_header(self, title: str, iteration: int = 0) -> str:
    """Format a section header as JSON.

    Records a "section_start" event carrying the title and the elapsed time.

    Args:
        title: Section title
        iteration: Current iteration number

    Returns:
        JSON string representation
    """
    event = self._create_event(
        "section_start",
        {"title": title, "elapsed_seconds": self.get_elapsed_time()},
        iteration,
    )
    self._record_event(event)
    return self._to_json(event)
def format_section_footer(self) -> str:
    """Format a section footer as JSON.

    Records a "section_end" event (iteration 0) carrying the elapsed time.

    Returns:
        JSON string representation
    """
    event = self._create_event(
        "section_end",
        {"elapsed_seconds": self.get_elapsed_time()},
        0,
    )
    self._record_event(event)
    return self._to_json(event)
def get_summary(self) -> Dict[str, Any]:
    """Summarize the recorded events.

    Returns:
        Dictionary with the total event count, a per-type count (events
        without a "type" key are counted as "unknown"), session token
        totals and the elapsed time
    """
    tally: Dict[str, int] = {}
    for recorded in self._events:
        kind = recorded.get("type", "unknown")
        tally[kind] = tally.get(kind, 0) + 1

    session_usage = {
        "total_tokens": self._token_usage.session_total_tokens,
        "total_cost": self._token_usage.session_cost,
    }
    return {
        "total_events": len(self._events),
        "event_counts": tally,
        "token_usage": session_usage,
        "elapsed_seconds": self.get_elapsed_time(),
    }
def export_events(self) -> str:
    """Export all recorded events together with their summary.

    Returns:
        JSON string of an object with an "events" list and a "summary"
        object
    """
    payload = {"events": self._events, "summary": self.get_summary()}
    return self._to_json(payload)

View File

@@ -0,0 +1,298 @@
# ABOUTME: Plain text output formatter for non-terminal environments
# ABOUTME: Provides basic text formatting without colors or special characters
"""Plain text output formatter for Claude adapter."""
from datetime import datetime
from typing import Optional
from .base import (
MessageType,
OutputFormatter,
ToolCallInfo,
VerbosityLevel,
)
class PlainTextFormatter(OutputFormatter):
"""Plain text formatter for environments without rich terminal support.
Produces readable output without ANSI codes, colors, or special characters.
Suitable for logging to files or basic terminal output.
"""
# Formatting constants
SEPARATOR_WIDTH = 60
HEADER_CHAR = "="
SUBHEADER_CHAR = "-"
SECTION_CHAR = "#"
def __init__(self, verbosity: VerbosityLevel = VerbosityLevel.NORMAL) -> None:
    """Initialize plain text formatter.

    All state is handled by the base class; this formatter adds none.

    Args:
        verbosity: Output verbosity level
    """
    super().__init__(verbosity)
def _timestamp(self) -> str:
    """Return the current local time as ``YYYY-MM-DD HH:MM:SS``."""
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M:%S")
def _separator(self, char: str = "-", width: Optional[int] = None) -> str:
    """Create a separator line.

    Args:
        char: Character (or string) to repeat
        width: Number of repetitions; SEPARATOR_WIDTH is used when this is
            None or 0

    Returns:
        ``char`` repeated ``width`` times
    """
    return char * (width or self.SEPARATOR_WIDTH)
def format_tool_call(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool call for plain text display.

    Notifies callbacks and returns a block framed by separator lines. At
    verbose level and above, the input parameters are listed as well.

    Args:
        tool_info: Tool call information
        iteration: Current iteration number

    Returns:
        Formatted string representation, or an empty string when tool calls
        are not displayed
    """
    if not self.should_display(MessageType.TOOL_CALL):
        return ""
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_CALL, tool_info, context)
    lines = [
        self._separator(),
        f"[{self._timestamp()}] TOOL CALL: {tool_info.tool_name}",
        # Only the first 12 characters of the id are shown.
        f" ID: {tool_info.tool_id[:12]}...",
    ]
    if self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        if tool_info.input_params:
            lines.append(" Input Parameters:")
            for key, value in tool_info.input_params.items():
                value_str = str(value)
                # Cap each value at 100 characters including the ellipsis.
                if len(value_str) > 100:
                    value_str = value_str[:97] + "..."
                lines.append(f" {key}: {value_str}")
    lines.append(self._separator())
    return "\n".join(lines)
def format_tool_result(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool result for plain text display.

    Notifies callbacks and returns a block with the tool name, duration,
    id prefix and status, closed by a separator line. At verbose level and
    above, a non-empty result is appended, summarized to 500 characters
    when longer.

    Args:
        tool_info: Tool call info with result
        iteration: Current iteration number

    Returns:
        Formatted string representation, or an empty string when tool
        results are not displayed
    """
    if not self.should_display(MessageType.TOOL_RESULT):
        return ""
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_RESULT, tool_info, context)

    status = "ERROR" if tool_info.is_error else "Success"
    # Compare against None rather than truthiness so a measured duration of
    # 0 ms is still reported, matching the JSON formatter.
    duration = f" ({tool_info.duration_ms}ms)" if tool_info.duration_ms is not None else ""
    lines = [
        f"[TOOL RESULT] {tool_info.tool_name}{duration}",
        # Only the first 12 characters of the id are shown.
        f" ID: {tool_info.tool_id[:12]}...",
        f" Status: {status}",
    ]
    if self._verbosity.value >= VerbosityLevel.VERBOSE.value and tool_info.result:
        result_str = str(tool_info.result)
        if len(result_str) > 500:
            result_str = self.summarize_content(result_str, 500)
        lines.append(" Output:")
        lines.extend(f" {line}" for line in result_str.split("\n"))
    lines.append(self._separator())
    return "\n".join(lines)
def format_assistant_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format an assistant message for plain text display.

    Callbacks are notified before the QUIET check, so they still fire when
    nothing is returned for display. At NORMAL verbosity, messages longer
    than 1000 characters are summarized.

    Args:
        message: Assistant message text
        iteration: Current iteration number

    Returns:
        Timestamped "ASSISTANT:" block, or an empty string when assistant
        messages are not displayed or verbosity is QUIET
    """
    if not self.should_display(MessageType.ASSISTANT):
        return ""

    self._notify_callbacks(MessageType.ASSISTANT, message, self._create_context(iteration))

    if self._verbosity == VerbosityLevel.QUIET:
        return ""

    too_long = len(message) > 1000
    if self._verbosity == VerbosityLevel.NORMAL and too_long:
        message = self.summarize_content(message, 1000)
    return f"[{self._timestamp()}] ASSISTANT:\n{message}\n"
def format_system_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Render a system message as a single timestamped line.

    Args:
        message: System message text.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        The formatted line, or an empty string when system messages are
        filtered out by ``should_display``.
    """
    if self.should_display(MessageType.SYSTEM):
        self._notify_callbacks(
            MessageType.SYSTEM, message, self._create_context(iteration)
        )
        stamp = self._timestamp()
        return f"[{stamp}] SYSTEM: {message}\n"
    return ""
def format_error(
    self,
    error: str,
    exception: Optional[Exception] = None,
    iteration: int = 0,
) -> str:
    """Render an error as a plain-text block framed by header separators.

    Unlike the other formatters this one does not consult
    ``should_display``; the block is always produced.

    Args:
        error: Error message.
        exception: Optional exception; its type and traceback are added
            at VERBOSE verbosity or above.
        iteration: Current iteration number.

    Returns:
        The formatted error block.
    """
    self._notify_callbacks(
        MessageType.ERROR, error, self._create_context(iteration)
    )

    out = [
        self._separator(self.HEADER_CHAR),
        f"[{self._timestamp()}] ERROR (Iteration {iteration})",
        f" Message: {error}",
    ]

    if exception and self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        import traceback

        out.append(f" Type: {type(exception).__name__}")
        rendered = "".join(
            traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )
        )
        out.append(" Traceback:")
        # Every traceback line (including the trailing empty one) is indented.
        out.extend(f" {row}" for row in rendered.split("\n"))

    out.append(self._separator(self.HEADER_CHAR))
    return "\n".join(out)
def format_progress(
    self,
    message: str,
    current: int = 0,
    total: int = 0,
    iteration: int = 0,
) -> str:
    """Format progress information for plain text display.

    Args:
        message: Progress message.
        current: Current progress value.
        total: Total progress value; a value <= 0 means "unknown" and
            yields an indeterminate ``[...]`` marker.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        A one-line progress string, or an empty string when progress
        messages are filtered out by ``should_display``.
    """
    if not self.should_display(MessageType.PROGRESS):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.PROGRESS, message, context)

    if total <= 0:
        return f"[...] {message}"

    pct = (current / total) * 100
    bar_width = 30
    # Clamp the fill so the bar is always exactly bar_width cells, even when
    # current overshoots total or is negative. The percentage stays unclamped
    # so the reported number remains truthful.
    filled = max(0, min(bar_width, int(bar_width * current / total)))
    bar = "#" * filled + "-" * (bar_width - filled)
    return f"[{bar}] {pct:.1f}% - {message}"
def format_token_usage(self, show_session: bool = True) -> str:
    """Render the token usage summary as plain text.

    Args:
        show_session: When true, add the session totals after the current
            counters.

    Returns:
        A multi-line block framed by sub-header separators.
    """
    usage = self._token_usage

    out = [self._separator(self.SUBHEADER_CHAR), "TOKEN USAGE:"]
    out.append(f" Current: {usage.total_tokens:,} tokens (${usage.cost:.4f})")
    out.append(f" Input: {usage.input_tokens:,} | Output: {usage.output_tokens:,}")

    if show_session:
        out.append(
            f" Session: {usage.session_total_tokens:,} tokens (${usage.session_cost:.4f})"
        )
        out.append(
            f" Input: {usage.session_input_tokens:,} | Output: {usage.session_output_tokens:,}"
        )

    # The model line is optional and only shown when a model name is set.
    if usage.model:
        out.append(f" Model: {usage.model}")

    out.append(self._separator(self.SUBHEADER_CHAR))
    return "\n".join(out)
def format_section_header(self, title: str, iteration: int = 0) -> str:
    """Render a section header framed by header separators.

    Args:
        title: Section title.
        iteration: Current iteration number; appended to the title only
            when non-zero.

    Returns:
        A string starting with a newline, followed by separator, title
        and separator on their own lines.
    """
    top = self._separator(self.HEADER_CHAR)
    heading = title
    if iteration:
        heading = f"{title} (Iteration {iteration})"
    bottom = self._separator(self.HEADER_CHAR)
    return f"\n{top}\n{heading}\n{bottom}"
def format_section_footer(self) -> str:
    """Render a section footer showing the elapsed time.

    Returns:
        A sub-header separator followed by ``Elapsed: <seconds>s`` with
        one decimal place, wrapped in leading and trailing newlines.
    """
    seconds = self.get_elapsed_time()
    rule = self._separator(self.SUBHEADER_CHAR)
    return "".join(["\n", rule, "\n", f"Elapsed: {seconds:.1f}s", "\n"])

View File

@@ -0,0 +1,737 @@
# ABOUTME: Rich terminal formatter with colors, panels, and progress indicators
# ABOUTME: Provides visually enhanced output using the Rich library with smart content detection
"""Rich terminal output formatter for Claude adapter.
This formatter provides intelligent content detection and rendering:
- Diffs are rendered with color-coded additions/deletions
- Code blocks get syntax highlighting
- Markdown is rendered with proper formatting
- Error tracebacks are highlighted for readability
"""
import logging
import re
from datetime import datetime
from io import StringIO
from typing import Optional
from .base import (
MessageType,
OutputFormatter,
ToolCallInfo,
VerbosityLevel,
)
from .content_detector import ContentDetector, ContentType
_logger = logging.getLogger(__name__)
# Try to import Rich components with fallback
try:
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.syntax import Syntax
from rich.markup import escape
RICH_AVAILABLE = True
except ImportError:
RICH_AVAILABLE = False
Console = None # type: ignore
Panel = None # type: ignore
class RichTerminalFormatter(OutputFormatter):
"""Rich terminal formatter with colors, panels, and progress indicators.
Provides visually enhanced output using the Rich library for terminal
display. Falls back to plain text if Rich is not available.
"""
# Color scheme: maps a semantic role to the Rich style string that the
# format_* methods place inside ``[style]...[/]`` markup.
COLORS = {
"tool_name": "bold cyan",
"tool_id": "dim",
"success": "bold green",
"error": "bold red",
"warning": "yellow",
"info": "blue",
"timestamp": "dim white",
"header": "bold magenta",
"assistant": "white",
"system": "dim cyan",
"token_input": "green",
"token_output": "yellow",
"cost": "bold yellow",
}
# Icons: prefix glyph per role, prepended to the formatted lines.
# NOTE(review): every value below is an empty string in this revision, so
# the icon prefixes render as nothing - confirm whether glyphs were lost.
ICONS = {
"tool": "",
"success": "",
"error": "",
"warning": "",
"info": "",
"assistant": "",
"system": "",
"token": "",
"clock": "",
"progress": "",
}
def __init__(
    self,
    verbosity: VerbosityLevel = VerbosityLevel.NORMAL,
    console: Optional["Console"] = None,
    smart_detection: bool = True,
) -> None:
    """Set up the formatter and, when Rich is importable, its console.

    Args:
        verbosity: Output verbosity level.
        console: Rich console to write to; a new one is created when
            omitted.
        smart_detection: Enable content detection (diff, code, markdown)
            for assistant messages.
    """
    super().__init__(verbosity)
    self._rich_available = RICH_AVAILABLE
    self._smart_detection = smart_detection
    self._content_detector = ContentDetector() if smart_detection else None

    if not RICH_AVAILABLE:
        # Without Rich there is neither a console nor a diff renderer.
        self._console = None
        self._diff_formatter = None
        return

    self._console = console or Console()
    # Imported here rather than at module level, as in the original code.
    from .console import DiffFormatter

    self._diff_formatter = DiffFormatter(self._console)
@property
def console(self) -> Optional["Console"]:
    """Rich console used for output, or ``None`` when Rich is unavailable."""
    rich_console = self._console
    return rich_console
def _timestamp(self) -> str:
"""Get formatted timestamp string with Rich markup."""
ts = datetime.now().strftime("%H:%M:%S")
if self._rich_available:
return f"[{self.COLORS['timestamp']}]{ts}[/]"
return ts
def _full_timestamp(self) -> str:
"""Get full timestamp with date."""
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self._rich_available:
return f"[{self.COLORS['timestamp']}]{ts}[/]"
return ts
def format_tool_call(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Render a tool invocation with Rich markup.

    Args:
        tool_info: Tool call information.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        The formatted block, the plain fallback when Rich is missing, or
        an empty string when tool calls are filtered out.
    """
    if not self.should_display(MessageType.TOOL_CALL):
        return ""

    self._notify_callbacks(
        MessageType.TOOL_CALL, tool_info, self._create_context(iteration)
    )

    if not self._rich_available:
        return self._format_tool_call_plain(tool_info)

    palette = self.COLORS
    out = [
        f"{self.ICONS['tool']} [{palette['tool_name']}]TOOL CALL: {tool_info.tool_name}[/]",
        f" [{palette['tool_id']}]ID: {tool_info.tool_id[:12]}...[/]",
    ]

    wants_params = self._verbosity.value >= VerbosityLevel.VERBOSE.value
    if wants_params and tool_info.input_params:
        out.append(f" [{palette['info']}]Input Parameters:[/]")
        for name, raw in tool_info.input_params.items():
            shown = str(raw)
            # Long values are cut to 100 characters including the ellipsis.
            if len(shown) > 100:
                shown = shown[:97] + "..."
            # Values are arbitrary text, so neutralise any Rich markup.
            if self._rich_available:
                shown = escape(shown)
            out.append(f" - {name}: {shown}")

    return "\n".join(out)
def _format_tool_call_plain(self, tool_info: ToolCallInfo) -> str:
    """Render a tool call without Rich markup.

    Every input parameter is listed (regardless of verbosity), with its
    string value cut to at most 100 characters.
    """
    out = [
        f"TOOL CALL: {tool_info.tool_name}",
        f" ID: {tool_info.tool_id[:12]}...",
    ]
    params = tool_info.input_params
    if params:
        out.extend(f" {name}: {str(raw)[:100]}" for name, raw in params.items())
    return "\n".join(out)
def format_tool_result(
    self,
    tool_info: ToolCallInfo,
    iteration: int = 0,
) -> str:
    """Format a tool result for rich terminal display.

    At VERBOSE verbosity or above the result text is appended, summarized
    to 500 characters when longer and limited to the first 20 lines. When
    lines are dropped, a trailing note states exactly how many were hidden.

    Args:
        tool_info: Tool call info with result.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        The formatted block, the plain fallback when Rich is missing, or
        an empty string when tool results are filtered out.
    """
    if not self.should_display(MessageType.TOOL_RESULT):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.TOOL_RESULT, tool_info, context)

    if not self._rich_available:
        return self._format_tool_result_plain(tool_info)

    # Determine status styling
    if tool_info.is_error:
        status_icon = self.ICONS["error"]
        status_color = self.COLORS["error"]
        status_text = "ERROR"
    else:
        status_icon = self.ICONS["success"]
        status_color = self.COLORS["success"]
        status_text = "Success"

    duration = f" ({tool_info.duration_ms}ms)" if tool_info.duration_ms else ""

    lines = [
        f"{status_icon} [{status_color}]TOOL RESULT: {tool_info.tool_name}{duration}[/]",
        f" [{self.COLORS['tool_id']}]ID: {tool_info.tool_id[:12]}...[/]",
        f" Status: [{status_color}]{status_text}[/]",
    ]

    if self._verbosity.value >= VerbosityLevel.VERBOSE.value and tool_info.result:
        max_lines = 20
        result_str = str(tool_info.result)
        if len(result_str) > 500:
            result_str = self.summarize_content(result_str, 500)
        # Escape Rich markup so tool output cannot inject or break styling.
        result_lines = escape(result_str).split("\n")

        lines.append(f" [{self.COLORS['info']}]Output:[/]")
        lines.extend(f" {line}" for line in result_lines[:max_lines])

        # Count actual lines (newlines + 1) so the notice appears as soon as
        # anything is hidden and reports the true number of hidden lines.
        hidden = len(result_lines) - max_lines
        if hidden > 0:
            lines.append(
                f" [{self.COLORS['timestamp']}]... ({hidden} more lines)[/]"
            )

    return "\n".join(lines)
def _format_tool_result_plain(self, tool_info: ToolCallInfo) -> str:
    """Render a tool result without Rich markup.

    The result, when truthy, is shown on one line cut to 200 characters.
    """
    outcome = "Success"
    if tool_info.is_error:
        outcome = "ERROR"

    out = [f"TOOL RESULT: {tool_info.tool_name}", f" Status: {outcome}"]
    if tool_info.result:
        preview = str(tool_info.result)[:200]
        out.append(f" Output: {preview}")
    return "\n".join(out)
def format_assistant_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Render an assistant message for the Rich terminal.

    With smart detection enabled the (possibly summarized) message is
    classified by the content detector and handed to
    ``_render_smart_content``; otherwise it is escaped and wrapped in the
    assistant colour.

    Args:
        message: Assistant message text.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        The formatted message, or an empty string when assistant messages
        are filtered out or verbosity is QUIET.
    """
    if not self.should_display(MessageType.ASSISTANT):
        return ""

    # Callbacks always receive the original, unsummarized message.
    self._notify_callbacks(
        MessageType.ASSISTANT, message, self._create_context(iteration)
    )

    if self._verbosity == VerbosityLevel.QUIET:
        return ""

    shown = message
    if self._verbosity == VerbosityLevel.NORMAL and len(message) > 1000:
        shown = self.summarize_content(message, 1000)

    if not self._rich_available:
        return f"ASSISTANT: {shown}"

    detector = self._content_detector
    if self._smart_detection and detector:
        return self._render_smart_content(shown, detector.detect(shown))

    # Simple formatting when smart detection is off.
    return f"{self.ICONS['assistant']} [{self.COLORS['assistant']}]{escape(shown)}[/]"
def _render_smart_content(self, text: str, content_type: ContentType) -> str:
    """Dispatch *text* to the renderer matching its detected type.

    Args:
        text: Text content to render.
        content_type: Type reported by the content detector.

    Returns:
        The renderer's output; unrecognised types are escaped and wrapped
        in the assistant colour markup.
    """
    if content_type == ContentType.DIFF:
        return self._render_diff(text)
    if content_type == ContentType.CODE_BLOCK:
        return self._render_code_blocks(text)
    # Tables share the markdown renderer.
    if content_type == ContentType.MARKDOWN or content_type == ContentType.MARKDOWN_TABLE:
        return self._render_markdown(text)
    if content_type == ContentType.ERROR_TRACEBACK:
        return self._render_traceback(text)

    # Plain text: escape so the content cannot be read as markup.
    prefix = self.ICONS["assistant"]
    style = self.COLORS["assistant"]
    return f"{prefix} [{style}]{escape(text)}[/]"
def _render_diff(self, text: str) -> str:
"""Render diff content with colors.
Uses the DiffFormatter for enhanced diff visualization with
color-coded additions/deletions and file statistics.
Args:
text: Diff text to render
Returns:
Empty string (diff is printed directly to console)
"""
if self._diff_formatter and self._console:
# DiffFormatter prints directly, so capture would require buffer
# For now, print directly and return marker
self._diff_formatter.format_and_print(text)
return "" # Already printed
return f"[dim]{text}[/dim]"
def _render_code_blocks(self, text: str) -> str:
"""Render text with code blocks using syntax highlighting.
Extracts code blocks, renders them with Rich Syntax, and
formats the surrounding text.
Args:
text: Text containing code blocks
Returns:
Formatted string with code block markers for print_smart()
"""
if not self._console or not self._content_detector:
return f"[dim]{text}[/dim]"
# Split by code blocks and render each part
parts = []
pattern = r"```(\w+)?\n(.*?)\n```"
last_end = 0
for match in re.finditer(pattern, text, re.DOTALL):
# Text before code block
before = text[last_end:match.start()].strip()
if before:
parts.append(("text", before))
# Code block
language = match.group(1) or "text"
code = match.group(2)
parts.append(("code", language, code))
last_end = match.end()
# Text after last code block
after = text[last_end:].strip()
if after:
parts.append(("text", after))
# Render to string buffer using console
buffer = StringIO()
temp_console = Console(file=buffer, force_terminal=True, width=100)
for part in parts:
if part[0] == "text":
temp_console.print(part[1], markup=True, highlight=True)
else: # code
_, language, code = part
syntax = Syntax(
code,
language,
theme="monokai",
line_numbers=True,
word_wrap=True,
)
temp_console.print(syntax)
return buffer.getvalue()
def _render_markdown(self, text: str) -> str:
"""Render markdown content with Rich formatting.
Uses Rich's Markdown renderer for headings, lists, emphasis, etc.
Args:
text: Markdown text to render
Returns:
Formatted markdown string
"""
if not self._console:
return text
try:
from rich.markdown import Markdown
# Preprocess for task lists
processed = self._preprocess_markdown(text)
buffer = StringIO()
temp_console = Console(file=buffer, force_terminal=True, width=100)
temp_console.print(Markdown(processed))
return buffer.getvalue()
except ImportError:
return text
def _preprocess_markdown(self, text: str) -> str:
"""Preprocess markdown for enhanced rendering.
Converts task list checkboxes to visual indicators.
Args:
text: Raw markdown
Returns:
Preprocessed markdown
"""
# Convert task lists: [ ] -> ☐, [x] -> ☑
text = re.sub(r"\[\s\]", "", text)
text = re.sub(r"\[[xX]\]", "", text)
return text
def _render_traceback(self, text: str) -> str:
"""Render error traceback with syntax highlighting.
Uses Python syntax highlighting for better readability.
Args:
text: Traceback text to render
Returns:
Formatted traceback string
"""
if not self._console:
return f"[red]{text}[/red]"
try:
buffer = StringIO()
temp_console = Console(file=buffer, force_terminal=True, width=100)
temp_console.print("[red bold]⚠ Error Traceback:[/red bold]")
syntax = Syntax(
text,
"python",
theme="monokai",
line_numbers=False,
word_wrap=True,
background_color="grey11",
)
temp_console.print(syntax)
return buffer.getvalue()
except Exception as e:
_logger.warning("Rich traceback rendering failed: %s: %s", type(e).__name__, e)
return f"[red]{escape(text)}[/red]"
def print_smart(self, message: str, iteration: int = 0) -> None:
"""Print message with smart content detection directly to console.
This is the preferred method for displaying assistant messages
as it handles all content types appropriately and prints directly.
Nothing is printed when assistant messages are filtered out by
``should_display`` or when verbosity is QUIET.
Args:
message: Message text to print
iteration: Current iteration number
"""
if not self.should_display(MessageType.ASSISTANT):
return
if self._verbosity == VerbosityLevel.QUIET:
return
# No Rich console: fall back to the built-in print with a plain prefix.
if not self._console:
print(f"ASSISTANT: {message}")
return
# Use smart detection
if self._smart_detection and self._content_detector:
content_type = self._content_detector.detect(message)
if content_type == ContentType.DIFF:
# DiffFormatter prints directly
# NOTE(review): this diff path bypasses format_assistant_message, which is
# where callbacks are notified and long messages are summarized - confirm
# that skipping both is intended for diffs.
if self._diff_formatter:
self._diff_formatter.format_and_print(message)
return
# For other content types, use format and print
formatted = self.format_assistant_message(message, iteration)
# An empty result means nothing should be shown.
if formatted:
self._console.print(formatted, markup=True)
def format_system_message(
    self,
    message: str,
    iteration: int = 0,
) -> str:
    """Format a system message for rich terminal display.

    Args:
        message: System message text. It is treated as literal text:
            any Rich markup it contains is escaped.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        The formatted line, the plain ``SYSTEM: ...`` fallback when Rich
        is missing, or an empty string when system messages are filtered
        out by ``should_display``.
    """
    if not self.should_display(MessageType.SYSTEM):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.SYSTEM, message, context)

    if not self._rich_available:
        return f"SYSTEM: {message}"

    icon = self.ICONS["system"]
    # Escape the message, as the other formatters do for arbitrary text, so
    # brackets in it (e.g. "[/path]") are not parsed as markup tags.
    safe_message = escape(str(message))
    return f"{icon} [{self.COLORS['system']}]SYSTEM: {safe_message}[/]"
def format_error(
    self,
    error: str,
    exception: Optional[Exception] = None,
    iteration: int = 0,
) -> str:
    """Format an error for rich terminal display.

    This formatter does not consult ``should_display``; the block is
    always produced.

    Args:
        error: Error message. Treated as literal text: any Rich markup
            it contains is escaped.
        exception: Optional exception; its type and up to the first 15
            traceback lines are added at VERBOSE verbosity or above.
        iteration: Current iteration number.

    Returns:
        The formatted error block, or ``ERROR: <error>`` when Rich is
        missing.
    """
    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.ERROR, error, context)

    if not self._rich_available:
        return f"ERROR: {error}"

    icon = self.ICONS["error"]
    color = self.COLORS["error"]

    # Error text often contains brackets (paths, reprs); escape it so it
    # cannot be parsed as markup tags.
    lines = [
        f"\n{icon} [{color}]ERROR (Iteration {iteration})[/]",
        f" [{color}]{escape(str(error))}[/]",
    ]

    if exception and self._verbosity.value >= VerbosityLevel.VERBOSE.value:
        import traceback

        lines.append(f" [{self.COLORS['warning']}]Type: {type(exception).__name__}[/]")
        tb = "".join(
            traceback.format_exception(type(exception), exception, exception.__traceback__)
        )
        lines.append(f" [{self.COLORS['timestamp']}]Traceback:[/]")
        # Limit traceback lines and skip blank ones. Rich is known to be
        # available on this path, so lines are always escaped.
        lines.extend(
            f" {escape(line)}" for line in tb.split("\n")[:15] if line.strip()
        )

    return "\n".join(lines)
def format_progress(
    self,
    message: str,
    current: int = 0,
    total: int = 0,
    iteration: int = 0,
) -> str:
    """Format progress information for rich terminal display.

    Args:
        message: Progress message.
        current: Current progress value.
        total: Total progress value; a value <= 0 means "unknown" and
            yields an indeterminate ``[...]`` marker.
        iteration: Current iteration number, forwarded to callbacks.

    Returns:
        A one-line progress string, or an empty string when progress
        messages are filtered out by ``should_display``.
    """
    if not self.should_display(MessageType.PROGRESS):
        return ""

    context = self._create_context(iteration)
    self._notify_callbacks(MessageType.PROGRESS, message, context)

    if not self._rich_available:
        if total > 0:
            pct = (current / total) * 100
            return f"[{pct:.0f}%] {message}"
        return f"[...] {message}"

    icon = self.ICONS["progress"]
    if total > 0:
        pct = (current / total) * 100
        bar_width = 20
        # Clamp so the bar is always exactly bar_width cells, even when
        # current overshoots total or is negative.
        filled = max(0, min(bar_width, int(bar_width * current / total)))
        # Block glyphs draw the filled and empty cells. They do not start
        # with a character Rich accepts as a tag name, so the bracketed bar
        # is printed literally rather than parsed as markup.
        bar = "█" * filled + "░" * (bar_width - filled)
        return f"{icon} [{self.COLORS['info']}][{bar}] {pct:.0f}%[/] {message}"
    return f"{icon} [{self.COLORS['info']}][...][/] {message}"
def format_token_usage(self, show_session: bool = True) -> str:
    """Render the token usage summary for the Rich terminal.

    Args:
        show_session: When true, add the session totals.

    Returns:
        A marked-up multi-line summary, or a compact plain summary when
        Rich is missing.
    """
    usage = self._token_usage

    if not self._rich_available:
        plain = [f"TOKEN USAGE: {usage.total_tokens:,} (${usage.cost:.4f})"]
        if show_session:
            plain.append(
                f" Session: {usage.session_total_tokens:,} (${usage.session_cost:.4f})"
            )
        return "\n".join(plain)

    palette = self.COLORS
    in_style = palette["token_input"]
    out_style = palette["token_output"]
    cost_style = palette["cost"]

    rows = [
        f"\n{self.ICONS['token']} [{palette['header']}]TOKEN USAGE[/]",
        f" Current: [{in_style}]{usage.input_tokens:,} in[/] | [{out_style}]{usage.output_tokens:,} out[/] | [{cost_style}]${usage.cost:.4f}[/]",
    ]
    if show_session:
        rows.append(
            f" Session: [{in_style}]{usage.session_input_tokens:,} in[/] | [{out_style}]{usage.session_output_tokens:,} out[/] | [{cost_style}]${usage.session_cost:.4f}[/]"
        )
    # The model line is optional and only shown when a model name is set.
    if usage.model:
        rows.append(f" [{palette['timestamp']}]Model: {usage.model}[/]")

    return "\n".join(rows)
def format_section_header(self, title: str, iteration: int = 0) -> str:
    """Format a section header for rich terminal display.

    Args:
        title: Section title.
        iteration: Current iteration number; appended to the title only
            when non-zero.

    Returns:
        The title framed by two horizontal rules: 60 ``=`` characters
        without Rich, or 50 box-drawing dashes in the header colour.
    """
    header_title = f"{title} (Iteration {iteration})" if iteration else title

    if not self._rich_available:
        sep = "=" * 60
        return f"\n{sep}\n{header_title}\n{sep}"

    # The rule needs a visible glyph; an empty string multiplied by the
    # width yields no separator at all.
    sep = "─" * 50
    return f"\n[{self.COLORS['header']}]{sep}\n{header_title}\n{sep}[/]"
def format_section_footer(self) -> str:
    """Render a section footer showing the elapsed time.

    Returns:
        ``Elapsed: <seconds>s`` with one decimal place; preceded by a
        50-character ``=`` rule without Rich, or wrapped in the timestamp
        colour with the clock icon when Rich is available.
    """
    seconds = self.get_elapsed_time()
    if self._rich_available:
        style = self.COLORS["timestamp"]
        clock = self.ICONS["clock"]
        return f"\n[{style}]{clock} Elapsed: {seconds:.1f}s[/]\n"
    rule = "=" * 50
    return f"\n{rule}\nElapsed: {seconds:.1f}s\n"
def print(self, text: str) -> None:
    """Write formatted text to the console.

    Without a Rich console, anything that looks like a ``[tag]`` is
    removed and the remainder goes to standard output.

    Args:
        text: Rich-formatted text to print.
    """
    if not self._console:
        import re

        # Strip markup for plain output
        stripped = re.sub(r"\[/?[^\]]+\]", "", text)
        print(stripped)
        return
    self._console.print(text, markup=True)
def print_panel(self, content: str, title: str = "", border_style: str = "blue") -> None:
    """Show content inside a Rich panel, or as plain text without Rich.

    Args:
        content: Content to display.
        title: Panel title.
        border_style: Panel border color.
    """
    can_use_panel = self._console and self._rich_available and Panel
    if not can_use_panel:
        # Plain fallback: optional title banner, content, blank line.
        if title:
            print(f"\n=== {title} ===")
        print(content)
        print()
        return

    self._console.print(Panel(content, title=title, border_style=border_style))
def create_progress_bar(self) -> Optional["Progress"]:
    """Build a Rich progress bar bound to this formatter's console.

    Returns:
        A ``Progress`` with spinner, description, bar and percentage
        columns, or ``None`` when Rich is not available.
    """
    if self._rich_available and Progress:
        columns = (
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        )
        return Progress(*columns, console=self._console)
    return None

View File

@@ -0,0 +1,156 @@
# ABOUTME: Safety guardrails and circuit breakers for Ralph Orchestrator
# ABOUTME: Prevents runaway loops and excessive costs
"""Safety mechanisms for Ralph Orchestrator."""
from collections import deque
from dataclasses import dataclass
from typing import Optional
import logging
logger = logging.getLogger('ralph-orchestrator.safety')
@dataclass
class SafetyCheckResult:
    """Outcome of a safety evaluation.

    Attributes:
        passed: Whether the check passed.
        reason: Explanation attached to the result; defaults to ``None``.
    """

    passed: bool
    reason: Optional[str] = None
class SafetyGuard:
    """Circuit breaker deciding whether orchestration may continue.

    Enforces iteration, runtime, cost and consecutive-failure limits, and
    keeps a short window of recent outputs for loop detection.
    """

    def __init__(
        self,
        max_iterations: int = 100,
        max_runtime: int = 14400,  # 4 hours
        max_cost: float = 10.0,
        consecutive_failure_limit: int = 5
    ):
        """Store the limits and reset the counters.

        Args:
            max_iterations: Maximum allowed iterations.
            max_runtime: Maximum runtime in seconds.
            max_cost: Maximum allowed cost in dollars.
            consecutive_failure_limit: Consecutive failures tolerated
                before ``check`` fails.
        """
        self.max_iterations = max_iterations
        self.max_runtime = max_runtime
        self.max_cost = max_cost
        self.consecutive_failure_limit = consecutive_failure_limit
        self.consecutive_failures = 0

        # Loop detection state: last five outputs and the similarity
        # ratio at or above which two outputs count as a loop.
        self.recent_outputs: deque = deque(maxlen=5)
        self.loop_threshold: float = 0.9

    def check(
        self,
        iterations: int,
        elapsed_time: float,
        total_cost: float
    ) -> SafetyCheckResult:
        """Evaluate every safety condition in a fixed order.

        Args:
            iterations: Current iteration count.
            elapsed_time: Elapsed time in seconds.
            total_cost: Total cost so far.

        Returns:
            A failing result naming the first limit that was hit, or a
            passing result when none was.
        """
        # Hard limits: the first one reached determines the reason.
        failure = None
        if iterations >= self.max_iterations:
            failure = f"Reached maximum iterations ({self.max_iterations})"
        elif elapsed_time >= self.max_runtime:
            hours = elapsed_time / 3600
            failure = f"Reached maximum runtime ({hours:.1f} hours)"
        elif total_cost >= self.max_cost:
            failure = f"Reached maximum cost (${total_cost:.2f})"
        elif self.consecutive_failures >= self.consecutive_failure_limit:
            failure = f"Too many consecutive failures ({self.consecutive_failures})"

        if failure is not None:
            return SafetyCheckResult(passed=False, reason=failure)

        # Past 50 iterations: warn but keep going.
        if iterations > 50:
            logger.warning(f"High iteration count: {iterations}")

        # Past 75 iterations: stop when the average exceeds 5 minutes each.
        if iterations > 75 and elapsed_time / iterations > 300:
            return SafetyCheckResult(
                passed=False,
                reason="Iterations taking too long on average"
            )

        return SafetyCheckResult(passed=True)

    def record_success(self):
        """Record a successful iteration, clearing the failure streak."""
        self.consecutive_failures = 0

    def record_failure(self):
        """Record a failed iteration and log the current streak."""
        self.consecutive_failures += 1
        logger.warning(f"Consecutive failures: {self.consecutive_failures}")

    def reset(self):
        """Clear the failure streak and the loop-detection window."""
        self.consecutive_failures = 0
        self.recent_outputs.clear()

    def detect_loop(self, current_output: str) -> bool:
        """Report whether the output closely repeats a recent one.

        Compares the output against each remembered output using
        rapidfuzz's ``fuzz.ratio``. A new output is remembered only when
        no loop was found. If rapidfuzz is not installed, or the
        comparison raises, detection is skipped and False is returned.

        Args:
            current_output: The current agent output to check.

        Returns:
            True if a similar recent output was found, False otherwise.
        """
        if not current_output:
            return False

        try:
            from rapidfuzz import fuzz

            for earlier in self.recent_outputs:
                similarity = fuzz.ratio(current_output, earlier) / 100.0
                if similarity >= self.loop_threshold:
                    logger.warning(
                        f"Loop detected: {similarity:.1%} similarity to previous output"
                    )
                    return True

            self.recent_outputs.append(current_output)
            return False
        except ImportError:
            # rapidfuzz not installed, skip loop detection
            logger.debug("rapidfuzz not installed, skipping loop detection")
            return False
        except Exception as e:
            logger.warning(f"Error in loop detection: {e}")
            return False

View File

@@ -0,0 +1,488 @@
# ABOUTME: Security utilities for Ralph Orchestrator
# ABOUTME: Provides input validation, path sanitization, and sensitive data protection
"""
Security utilities for Ralph Orchestrator.
This module provides security hardening functions including input validation,
path sanitization, and sensitive data protection.
"""
import re
import logging
from pathlib import Path
from typing import Any, Optional
logger = logging.getLogger("ralph-orchestrator.security")
class SecurityValidator:
"""Security validation utilities for Ralph Orchestrator."""
# Patterns for dangerous path components.
# sanitize_path runs each one with re.search (case-insensitively) against
# the string form of the input path and rejects the path on any match.
DANGEROUS_PATH_PATTERNS = [
r"\.\.\/.*", # Directory traversal (Unix)
r"\.\.\\.*", # Windows directory traversal
r"^\.\.[\/\\]", # Starts with parent directory
r"[\/\\]\.\.[\/\\]", # Contains parent directory
r"[<>:\"|?*]", # Invalid filename characters (Windows)
r"[\x00-\x1f]", # Control characters
r"[\/\\]\.\.[\/\\]\.\.[\/\\]", # Double traversal
]
# Sensitive data patterns that should be masked (16+ patterns).
# Each entry is a (regex, replacement) pair; the replacements use
# backreferences such as \1 and \3 to keep the text around the secret.
# NOTE(review): the code that applies these pairs is outside this view -
# presumably re.sub in list order; confirm against the masking helper.
SENSITIVE_PATTERNS = [
# API Keys
(r"(sk-[a-zA-Z0-9]{10,})", r"sk-***********"), # OpenAI API keys
(r"(xai-[a-zA-Z0-9]{10,})", r"xai-***********"), # xAI API keys
(r"(AIza[a-zA-Z0-9_-]{35})", r"AIza***********"), # Google API keys
# Bearer tokens
(r"(Bearer [a-zA-Z0-9\-_\.]{20,})", r"Bearer ***********"),
# Passwords in various formats
(
r'(["\']?password["\']?\s*[:=]\s*["\']?)([^"\'\s]{3,})(["\']?)',
r"\1*********\3",
),
(r"(password\s*=\s*)([^\"'\s]{3,})", r"\1*********"),
# Tokens in various formats
(
r'(token["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
r"\1*********\3",
),
(r"(token\s*=\s*)([a-zA-Z0-9\-_\.]{10,})", r"\1*********"),
# Secrets
(
r'(secret["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
r"\1*********\3",
),
(r"(secret\s*=\s*)([a-zA-Z0-9\-_\.]{10,})", r"\1*********"),
# Generic keys
(
r'(key["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
r"\1*********\3",
),
# API keys in various formats
(
r'(api[_-]?key["\']?\s*[:=]\s*["\']?)([a-zA-Z0-9\-_\.]{10,})(["\']?)',
r"\1*********\3",
),
(r"(api[_-]?key\s*=\s*)([a-zA-Z0-9\-_\.]{10,})", r"\1*********"),
# Sensitive file paths
# NOTE(review): if the pairs are applied in order, the general .ssh path
# entry below would already replace what the later id_* / private-key
# entries target - confirm whether the ordering is intentional.
(
r"(/[a-zA-Z0-9_\-\./]*\.ssh/[a-zA-Z0-9_\-\./]*)",
r"[REDACTED_SSH_PATH]",
), # SSH paths
(
r"(/[a-zA-Z0-9_\-\./]*\.ssh/id_[a-zA-Z0-9]*)",
r"[REDACTED_SSH_KEY]",
), # SSH private keys
(
r"(/[a-zA-Z0-9_\-\./]*\.config/[a-zA-Z0-9_\-\./]*)",
r"[REDACTED_CONFIG_PATH]",
), # Config files
(
r"(/[a-zA-Z0-9_\-\./]*\.aws/[a-zA-Z0-9_\-\./]*)",
r"[REDACTED_AWS_PATH]",
), # AWS credentials
(
r"(/[a-zA-Z0-9_\-\./]*(passwd|shadow|group|hosts))",
r"[REDACTED_SYSTEM_FILE]",
), # System files
(
r"(C:\\\\[a-zA-Z0-9_\-\./]*\\\\System32\\\\[a-zA-Z0-9_\-\./]*)",
r"[REDACTED_SYSTEM_PATH]",
), # Windows system files
(
r"(/[a-zA-Z0-9_\-\./]*(id_rsa|id_dsa|id_ecdsa|id_ed25519))",
r"[REDACTED_PRIVATE_KEY]",
), # Private key files
]
# Dangerous absolute path prefixes.
# sanitize_path rejects absolute input paths whose string form starts with
# one of these, and re-checks the resolved path against the same list
# (plus "/home") when it falls outside the base directory.
DANGEROUS_ABS_PATHS = [
"/etc",
"/usr/bin",
"/bin",
"/sbin",
"/root",
"/var",
"/opt",
"/sys",
"/proc",
"/dev",
]
@classmethod
def sanitize_path(cls, path: str, base_dir: Optional[Path] = None) -> Path:
"""
Sanitize a file path to prevent directory traversal attacks.
Args:
path: Input path to sanitize
base_dir: Base directory to resolve relative paths against
Returns:
Sanitized absolute Path
Raises:
ValueError: If path contains dangerous patterns
"""
if base_dir is None:
base_dir = Path.cwd()
# Convert to Path object
try:
input_path = Path(path)
except (ValueError, OSError) as e:
raise ValueError(f"Invalid path: {path}") from e
# Check for dangerous patterns
path_str = str(input_path)
for pattern in cls.DANGEROUS_PATH_PATTERNS:
if re.search(pattern, path_str, re.IGNORECASE):
raise ValueError(f"Path contains dangerous pattern: {path}")
# Check for dangerous absolute paths
if input_path.is_absolute():
for dangerous in cls.DANGEROUS_ABS_PATHS:
if path_str.startswith(dangerous):
raise ValueError(
f"Path resolves to dangerous system location: {path_str}"
)
# Resolve the path
if input_path.is_absolute():
resolved_path = input_path.resolve()
else:
resolved_path = (base_dir / input_path).resolve()
# Ensure resolved path is within base directory or a safe location
try:
resolved_path.relative_to(base_dir.resolve())
except ValueError:
# Check if this is an absolute path that might be dangerous
if input_path.is_absolute():
# Check dangerous absolute paths
dangerous_paths = cls.DANGEROUS_ABS_PATHS + ["/home"]
for dangerous in dangerous_paths:
try:
resolved_path.relative_to(dangerous)
raise ValueError(
f"Path resolves to dangerous system location: {resolved_path}"
)
except ValueError:
continue
else:
# Relative path that goes outside base directory
raise ValueError(
f"Path traversal detected: {path} -> {resolved_path}"
) from None
return resolved_path
@classmethod
def validate_config_value(cls, key: str, value: Any) -> Any:
    """
    Validate and sanitize a single configuration value.

    Numeric settings are coerced from strings and range-checked, file path
    settings (except prompt files) are passed through ``cls.sanitize_path``
    purely for validation, boolean settings are parsed from strings, and the
    ``focus`` text has shell metacharacters removed and is capped at 200
    characters. Keys that match none of these groups are returned untouched.

    Args:
        key: Configuration key
        value: Configuration value

    Returns:
        Sanitized value (``None`` is returned unchanged)

    Raises:
        ValueError: If value is invalid or dangerous
    """
    if value is None:
        return value

    numeric_keys = ("delay", "stats_interval", "max_iterations", "iteration_timeout")
    path_keys = ("log_file", "pid_file", "prompt_file", "system_prompt_file")
    flag_keys = (
        "verbose",
        "dry_run",
        "clear_screen",
        "show_countdown",
        "inject_best_practices",
    )

    if key in numeric_keys:
        if isinstance(value, str):
            try:
                value = int(value)
            except ValueError as e:
                raise ValueError(f"Invalid integer value for {key}: {value}") from e
        if value < 0:
            raise ValueError(f"{key} must be non-negative, got: {value}")
        # Upper bound and the human-readable label used in the error text.
        ceilings = {
            "delay": (86400, "24 hours"),
            "max_iterations": (10000, "10000"),
            "stats_interval": (3600, "1 hour"),
            "iteration_timeout": (7200, "2 hours"),
        }
        limit, label = ceilings[key]
        if value > limit:
            raise ValueError(f"{key} too large (>{label}): {value}")
    elif key in path_keys:
        # Prompt files are deliberately exempt from path sanitization.
        needs_check = key not in ("prompt_file", "system_prompt_file")
        if needs_check and isinstance(value, str):
            # Called for its ValueError side effect; the result is discarded.
            cls.sanitize_path(value)
    elif key in flag_keys:
        if isinstance(value, str):
            value = cls._parse_bool_safe(value)
        elif not isinstance(value, bool):
            raise ValueError(f"Invalid boolean value for {key}: {value}")
    elif key == "focus" and isinstance(value, str):
        # Strip characters that could enable command injection, then cap length.
        value = re.sub(r"[;&|`$()]", "", value)[:200]
    return value
@classmethod
def _parse_bool_safe(cls, value: str) -> bool:
    """
    Safely parse a boolean from a string.

    The text is lower-cased and stripped, shell metacharacters are removed,
    and the result is compared against the accepted "true" spellings.
    Everything else -- explicit false spellings ("false", "0", "no", "off"),
    empty/blank input and unrecognised text -- yields ``False``.

    Args:
        value: String value to parse

    Returns:
        Boolean value
    """
    if not value or not value.strip():
        return False
    # Remove any dangerous characters before comparing
    cleaned = re.sub(r"[;&|`$()]", "", value.lower().strip())
    return cleaned in ("true", "1", "yes", "on")
@classmethod
def mask_sensitive_data(cls, text: str) -> str:
    """
    Mask sensitive data in text for logging.

    Every ``(pattern, replacement)`` pair in ``cls.SENSITIVE_PATTERNS`` is
    applied in order, case-insensitively, each one operating on the output
    of the previous substitution.

    Args:
        text: Text to mask sensitive data in

    Returns:
        Text with sensitive data masked
    """
    scrubbed = text
    for rule in cls.SENSITIVE_PATTERNS:
        regex, substitute = rule
        scrubbed = re.sub(regex, substitute, scrubbed, flags=re.IGNORECASE)
    return scrubbed
@classmethod
def validate_filename(cls, filename: str) -> str:
    """
    Validate a bare filename (no directory part) for security.

    Checks run in this order: empty input, path-traversal markers,
    emptiness after stripping forbidden characters, Windows reserved device
    names, control characters in the original input. The cleaned name is
    finally truncated to 255 characters.

    Args:
        filename: Filename to validate

    Returns:
        Sanitized filename

    Raises:
        ValueError: If filename is invalid or dangerous
    """
    if not filename or not filename.strip():
        raise ValueError("Filename cannot be empty")

    # Check for path traversal attempts in filename
    if any(marker in filename for marker in ("..", "/", "\\")):
        raise ValueError(f"Filename contains path traversal: {filename}")

    # Remove dangerous characters
    sanitized = re.sub(r'[<>:"|?*\x00-\x1f]', "", filename.strip())
    if not sanitized:
        raise ValueError("Filename contains only invalid characters")

    # Prevent reserved names (Windows): CON, PRN, AUX, NUL, COM1-9, LPT1-9
    reserved_names = {"CON", "PRN", "AUX", "NUL"}
    for port in range(1, 10):
        reserved_names.add(f"COM{port}")
        reserved_names.add(f"LPT{port}")
    stem = sanitized.partition(".")[0].upper()
    if stem in reserved_names:
        raise ValueError(f"Filename uses reserved name: {filename}")

    # Control characters are checked on the ORIGINAL input, so a name that
    # contained any is rejected even though the regex above removed them.
    for char in filename:
        if ord(char) < 32:
            raise ValueError(f"Filename contains control characters: {filename}")

    # Limit length
    return sanitized[:255]
@classmethod
def create_secure_logger(
    cls, name: str, log_file: Optional[str] = None
) -> logging.Logger:
    """
    Create a logger whose output is passed through ``cls.mask_sensitive_data``.

    A new handler (file handler when ``log_file`` is given, stream handler
    otherwise) is attached on every call and the logger level is set to INFO.

    NOTE(review): because the handler is added unconditionally, calling this
    twice with the same ``name`` attaches two handlers to the same logger.

    Args:
        name: Logger name
        log_file: Optional log file path

    Returns:
        Secure logger instance
    """
    secure_logger = logging.getLogger(name)

    class SecureFormatter(logging.Formatter):
        """Formatter that masks sensitive data in the rendered record."""

        def format(self, record):
            rendered = super().format(record)
            return cls.mask_sensitive_data(rendered)

    handler = logging.FileHandler(log_file) if log_file else logging.StreamHandler()
    layout = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    handler.setFormatter(SecureFormatter(layout))

    secure_logger.addHandler(handler)
    secure_logger.setLevel(logging.INFO)
    return secure_logger
class PathTraversalProtection:
    """File helpers that run every path through SecurityValidator.sanitize_path."""

    @staticmethod
    def safe_file_read(file_path: str, base_dir: Optional[Path] = None) -> str:
        """
        Read a UTF-8 text file after validating its path.

        Args:
            file_path: Path to file to read
            base_dir: Base directory for relative paths

        Returns:
            File content

        Raises:
            ValueError: If path is dangerous or is not a regular file
            FileNotFoundError: If file doesn't exist
            PermissionError: If file cannot be read
        """
        target = SecurityValidator.sanitize_path(file_path, base_dir)
        if not target.exists():
            raise FileNotFoundError(f"File not found: {target}")
        if not target.is_file():
            raise ValueError(f"Path is not a file: {target}")
        try:
            content = target.read_text(encoding="utf-8")
        except PermissionError as e:
            raise PermissionError(f"Cannot read file: {target}") from e
        return content

    @staticmethod
    def safe_file_write(
        file_path: str, content: str, base_dir: Optional[Path] = None
    ) -> None:
        """
        Write UTF-8 text to a file after validating its path.

        Missing parent directories are created first.

        Args:
            file_path: Path to file to write
            content: Content to write
            base_dir: Base directory for relative paths

        Raises:
            ValueError: If path is dangerous
            PermissionError: If file cannot be written
        """
        target = SecurityValidator.sanitize_path(file_path, base_dir)
        # Create parent directories if needed
        target.parent.mkdir(parents=True, exist_ok=True)
        try:
            target.write_text(content, encoding="utf-8")
        except PermissionError as e:
            raise PermissionError(f"Cannot write file: {target}") from e
# Security decorator for functions that handle file paths
def secure_file_operation(base_dir: Optional[Path] = None):
    """
    Decorator to secure file operations against path traversal.

    Every positional or keyword argument that is a string containing ``/``
    or ``\\`` is treated as a path and replaced by the string form of
    ``SecurityValidator.sanitize_path(arg, base_dir)`` before the wrapped
    function runs. Other arguments pass through unchanged. The wrapper
    keeps the wrapped function's name, docstring and other metadata.

    Args:
        base_dir: Base directory for relative paths

    Returns:
        A decorator that wraps the target function

    Raises:
        ValueError: At call time, if a path-like argument is rejected by
            ``SecurityValidator.sanitize_path``
    """
    # Imported locally so this block does not depend on the module's import list.
    import functools

    def sanitize_if_path(candidate):
        # Heuristic: only strings that contain a path separator are paths.
        if isinstance(candidate, str) and ("/" in candidate or "\\" in candidate):
            return str(SecurityValidator.sanitize_path(candidate, base_dir))
        return candidate

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            new_args = [sanitize_if_path(arg) for arg in args]
            new_kwargs = {key: sanitize_if_path(value) for key, value in kwargs.items()}
            return func(*new_args, **new_kwargs)

        return wrapper

    return decorator

View File

@@ -0,0 +1,951 @@
# ABOUTME: Enhanced verbose logging utilities for Ralph Orchestrator
# ABOUTME: Provides session metrics, emergency shutdown, re-entrancy protection, Rich output
"""Enhanced verbose logging utilities for Ralph."""
import asyncio
import json
import sys
import threading
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, TextIO, cast
try:
from rich.console import Console
from rich.markdown import Markdown
from rich.syntax import Syntax
from rich.table import Table
from rich.panel import Panel
# Unused imports commented out but kept for future reference:
# from rich.progress import Progress, SpinnerColumn, TextColumn
RICH_AVAILABLE = True
except ImportError:
RICH_AVAILABLE = False
Console = None # type: ignore
Markdown = None # type: ignore
Syntax = None # type: ignore
# Import DiffFormatter for enhanced diff output
try:
from ralph_orchestrator.output import DiffFormatter
except ImportError:
DiffFormatter = None # type: ignore
class TextIOProxy:
    """
    Minimal file-like object that appends text to a file.

    The file is opened lazily on the first write. All operations are
    serialized with a lock, I/O errors are swallowed (writes report 0
    characters), and once closed the proxy never opens the file again.
    """

    def __init__(self, file_path: Path) -> None:
        """
        Initialize TextIO proxy.

        Args:
            file_path: Path to output file
        """
        self.file_path = file_path
        self._file: Optional[TextIO] = None
        self._closed = False
        self._lock = threading.Lock()

    def _ensure_open(self) -> Optional[TextIO]:
        """Return the open file, opening it in append mode if needed.

        Returns ``None`` when the proxy is closed or the file cannot be
        opened; an open failure permanently closes the proxy.
        """
        if self._closed:
            return None
        if self._file is None:
            try:
                self._file = open(self.file_path, "a", encoding="utf-8")
            except (OSError, IOError):
                self._closed = True
                return None
        return self._file

    def write(self, text: str) -> int:
        """
        Write text to file.

        Args:
            text: Text to write

        Returns:
            Number of characters written (0 if closed or on error)
        """
        with self._lock:
            if self._closed:
                return 0
            try:
                f = self._ensure_open()
                if f is None:
                    return 0
                return f.write(text)
            except (ValueError, OSError, AttributeError):
                return 0

    def flush(self) -> None:
        """Flush file buffer; a no-op when closed or never opened."""
        with self._lock:
            if not self._closed and self._file:
                try:
                    self._file.flush()
                except (ValueError, OSError):
                    pass

    def close(self) -> None:
        """Close the proxy; later writes are ignored.

        The closed flag is set even when the file was never opened, so a
        write that arrives after ``close()`` cannot lazily open the file
        and leave a handle behind.
        """
        with self._lock:
            if self._closed:
                return
            handle, self._file = self._file, None
            self._closed = True
            if handle is not None:
                try:
                    handle.close()
                except (ValueError, OSError):
                    pass

    def __del__(self) -> None:
        """Cleanup on deletion."""
        self.close()
class VerboseLogger:
"""
Enhanced verbose logger that captures detailed output to log files.
Features:
- Session metrics tracking in JSON format
- Emergency shutdown capability
- Re-entrancy protection (prevent logging loops)
- Console output with Rich library integration
- Thread-safe operations
This logger captures all verbose output including:
- Claude SDK messages with full content
- Tool calls and results
- Console output with formatting preserved
- System events and status updates
- Error details and tracebacks
"""
_metrics: Dict[str, Any]
def __init__(self, log_dir: Optional[str] = None) -> None:
    """
    Initialize verbose logger with thread safety.

    Creates the log directory and derives three timestamped file paths
    (verbose log, raw JSON-lines log, metrics JSON).

    Args:
        log_dir: Directory to store verbose log files. When omitted, a
            ``.agent`` directory is used inside the nearest directory
            (starting at the current working directory and walking up)
            that contains ``.git``. If no such directory exists, the
            current working directory itself is used.
    """
    if log_dir is None:
        # Find repository root by looking for .git directory. Fall back to
        # the working directory rather than the filesystem root, where
        # creating ".agent" would normally fail or pollute "/".
        current_dir = Path.cwd()
        repo_root = next(
            (
                candidate
                for candidate in (current_dir, *current_dir.parents)
                if (candidate / ".git").exists()
            ),
            current_dir,
        )
        log_dir = str(repo_root / ".agent")

    self.log_dir = Path(log_dir)
    self.log_dir.mkdir(parents=True, exist_ok=True)

    # Create timestamped log file for current session
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    self.verbose_log_file = self.log_dir / f"ralph_verbose_{timestamp}.log"
    self.raw_output_file = self.log_dir / f"ralph_raw_{timestamp}.log"
    self.metrics_file = self.log_dir / f"ralph_metrics_{timestamp}.json"

    # Thread safety: Use both asyncio and threading locks
    self._lock = asyncio.Lock()
    self._thread_lock = threading.RLock()  # Re-entrant lock for thread safety

    # Initialize Rich console or fallback
    self._text_io_proxy = TextIOProxy(self.verbose_log_file)
    if RICH_AVAILABLE:
        self._console = Console(file=cast(TextIO, self._text_io_proxy), width=120)
        self._live_console = Console()  # For live terminal output
    else:
        self._console = None
        self._live_console = None

    # Initialize DiffFormatter for enhanced diff output
    if RICH_AVAILABLE and DiffFormatter and self._console:
        self._diff_formatter: Optional[DiffFormatter] = DiffFormatter(self._console)
    else:
        self._diff_formatter = None

    self._raw_file_handle: Optional[TextIO] = None

    # Emergency shutdown state
    self._emergency_shutdown = False
    self._emergency_event = threading.Event()

    # Re-entrancy protection: Track if we're already logging
    self._logging_depth = 0
    self._logging_thread_ids: set = set()
    self._max_logging_depth = 3  # Prevent deep nesting

    # Session metrics tracking
    self._metrics = {
        "session_start": datetime.now().isoformat(),
        "session_end": None,
        "messages": [],
        "tool_calls": [],
        "errors": [],
        "iterations": [],
        "total_tokens": 0,
        "total_cost": 0.0,
    }
def _can_log_safely(self) -> bool:
    """
    Decide whether a logging call may proceed right now.

    Returns:
        False after emergency shutdown; for a thread that is already
        registered as logging, True only while the nesting depth is below
        the maximum; otherwise True only when the asyncio lock is free.
    """
    # Check emergency shutdown first
    if self._emergency_event.is_set():
        return False

    thread_id = threading.current_thread().ident
    # Use thread lock for safe access to shared state
    with self._thread_lock:
        if thread_id in self._logging_thread_ids:
            # Nested call from a thread already logging: allow limited depth
            return self._logging_depth < self._max_logging_depth
        # Non-blocking check: skip logging while the async lock is held
        return not self._lock.locked()
def _enter_logging_context(self) -> bool:
    """
    Register the current thread as logging and bump the nesting depth.

    Returns:
        True if the context was entered, False when the maximum nesting
        depth has already been reached
    """
    thread_id = threading.current_thread().ident
    with self._thread_lock:
        if self._logging_depth >= self._max_logging_depth:
            return False
        if thread_id in self._logging_thread_ids:
            # Re-entrancy in same thread is okay with depth tracking
            self._logging_depth += 1
        else:
            # First entry for this thread: register it and reset the depth
            self._logging_thread_ids.add(thread_id)
            self._logging_depth = 1
        return True
def _exit_logging_context(self) -> None:
    """Leave a logging context: decrement the depth (never below zero) and
    unregister the current thread once the depth reaches zero."""
    thread_id = threading.current_thread().ident
    with self._thread_lock:
        remaining = max(0, self._logging_depth - 1)
        self._logging_depth = remaining
        if remaining == 0:
            self._logging_thread_ids.discard(thread_id)
def emergency_shutdown(self) -> None:
    """Raise the emergency-shutdown flag and event.

    Methods of this logger that check the event return early once it is
    set, so they stop doing file I/O.
    """
    self._emergency_event.set()
    self._emergency_shutdown = True
def is_shutdown(self) -> bool:
    """Return True once the emergency-shutdown event has been set."""
    shutdown_requested = self._emergency_event.is_set()
    return shutdown_requested
def _print_to_file(self, text: str) -> None:
    """
    Print text to the log file (Rich or plain).

    Log text is written literally. Rich console markup is disabled because
    the text comes from messages and tool output, where bracketed fragments
    such as ``[/INST]`` would otherwise be parsed as markup tags and could
    raise or be dropped from the log.

    Args:
        text: Text to append to the verbose log
    """
    if self._console and RICH_AVAILABLE:
        self._console.print(text, markup=False)
    else:
        self._text_io_proxy.write(text + "\n")
        self._text_io_proxy.flush()
def _print_to_terminal(self, text: str) -> None:
    """Print text to the live terminal, via the Rich live console when one
    is available and via the built-in ``print`` otherwise."""
    live = self._live_console
    if not (live and RICH_AVAILABLE):
        print(text)
        return
    live.print(text)
async def log_message(
    self,
    message_type: str,
    content: Any,
    iteration: int = 0,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Log a detailed message to the verbose log, raw log and metrics.

    Logging is best-effort: the call returns without logging when the
    safety checks fail or the asyncio lock is already held, and any
    exception raised while logging is reported on stderr instead of
    propagating.

    Args:
        message_type: Type of message (system, assistant, user, tool, etc.)
        content: Message content (text, dict, object)
        iteration: Current iteration number
        metadata: Additional metadata about the message
    """
    # Check if logging is safe (thread safety + re-entrancy)
    if not self._can_log_safely():
        return
    # Enter logging context safely
    if not self._enter_logging_context():
        return
    try:
        # Non-blocking: drop the message rather than wait for the lock
        if self._lock.locked():
            return
        async with self._lock:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            log_entry = {
                "timestamp": timestamp,
                "iteration": iteration,
                "type": message_type,
                "content": self._serialize_content(content),
                "metadata": metadata or {},
            }

            # Header
            rule = "=" * 80
            self._print_to_file(f"\n{rule}")
            self._print_to_file(
                f"[{timestamp}] Iteration {iteration} - {message_type}"
            )
            if metadata:
                # default=str keeps values JSON cannot encode (datetimes,
                # paths, sets, ...) from aborting the whole entry; the raw
                # log already serializes the same way.
                self._print_to_file(
                    f"Metadata: {json.dumps(metadata, indent=2, default=str)}"
                )
            self._print_to_file(f"{rule}\n")

            # Body: long strings are previewed, dicts pretty-printed
            if isinstance(content, str):
                if len(content) > 2000:
                    self._print_to_file(content[:1000])
                    self._print_to_file(
                        f"\n[Content truncated ({len(content)} chars total)]"
                    )
                else:
                    self._print_to_file(content)
            elif isinstance(content, dict):
                self._print_to_file(json.dumps(content, indent=2, default=str))
            else:
                self._print_to_file(str(content))

            # Write to raw log (complete content)
            await self._write_raw_log(log_entry)
            # Update metrics
            await self._update_metrics("message", log_entry)
    except Exception as e:
        # Logging must never break the caller; report and carry on.
        try:
            print(f"Logging error in log_message: {e}", file=sys.stderr)
        except Exception:
            pass
    finally:
        self._exit_logging_context()
def log_message_sync(
    self,
    message_type: str,
    content: Any,
    iteration: int = 0,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Synchronous wrapper for log_message.

    Without a running event loop the coroutine is run to completion with
    ``asyncio.run``. Inside a running loop it is scheduled as a background
    task and this method returns immediately; the task is kept in
    ``self._background_tasks`` until it finishes, because the event loop
    holds only weak references to tasks.

    Args:
        message_type: Type of message (system, assistant, user, tool, etc.)
        content: Message content (text, dict, object)
        iteration: Current iteration number
        metadata: Additional metadata about the message
    """
    if self._emergency_event.is_set():
        return
    try:
        asyncio.get_running_loop()  # raises RuntimeError if no loop is running
    except RuntimeError:
        # No running loop, run directly
        asyncio.run(self.log_message(message_type, content, iteration, metadata))
        return

    task = asyncio.create_task(
        self.log_message(message_type, content, iteration, metadata)
    )
    # Created lazily so this method does not depend on __init__ changes.
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = self._background_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
async def log_tool_call(
    self,
    tool_name: str,
    input_data: Any,
    result: Any,
    iteration: int,
    duration_ms: Optional[int] = None,
) -> None:
    """
    Log a tool call with its input and output.

    Logging is best-effort: the call returns without logging when the
    safety checks fail or the asyncio lock is already held, and any
    exception raised while logging is reported on stderr instead of
    propagating. The recorded ``success`` flag is simply ``result is not
    None``.

    Args:
        tool_name: Name of the tool that was called
        input_data: Tool input parameters
        result: Tool execution result
        iteration: Current iteration number
        duration_ms: Tool execution duration in milliseconds (``None``
            when unknown; ``0`` is a valid measured duration)
    """
    if not self._can_log_safely():
        return
    if not self._enter_logging_context():
        return
    try:
        # Non-blocking: drop the entry rather than wait for the lock
        if self._lock.locked():
            return
        async with self._lock:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            tool_entry = {
                "timestamp": timestamp,
                "iteration": iteration,
                "tool_name": tool_name,
                "input": self._serialize_content(input_data),
                "result": self._serialize_content(result),
                "duration_ms": duration_ms,
                "success": result is not None,
            }

            # Compare against None so a measured 0 ms is not shown as unknown
            duration_text = (
                f"{duration_ms}ms" if duration_ms is not None else "unknown"
            )
            rule = "-" * 60
            self._print_to_file(f"\n{rule}")
            self._print_to_file(f"TOOL CALL: {tool_name} ({duration_text})")

            # Format input
            if input_data:
                self._print_to_file("\nInput:")
                if isinstance(input_data, (dict, list)):
                    # default=str keeps values JSON cannot encode from
                    # aborting the whole entry
                    input_json = json.dumps(input_data, indent=2, default=str)
                    if len(input_json) > 1000:
                        input_json = (
                            input_json[:500]
                            + "\n  ... [truncated] ...\n"
                            + input_json[-400:]
                        )
                    self._print_to_file(input_json)
                else:
                    self._print_to_file(str(input_data)[:500])

            # Format result
            if result:
                self._print_to_file("\nResult:")
                result_str = self._serialize_content(result)
                # Check if result is diff content and format with DiffFormatter
                if (
                    isinstance(result_str, str)
                    and self._is_diff_content(result_str)
                    and self._diff_formatter
                ):
                    self._print_to_file(
                        "[Detected diff content - formatting with enhanced visualization]"
                    )
                    self._diff_formatter.format_and_print(result_str)
                elif isinstance(result_str, str) and len(result_str) > 1500:
                    preview = (
                        result_str[:750]
                        + "\n  ... [truncated] ...\n"
                        + result_str[-500:]
                    )
                    self._print_to_file(preview)
                else:
                    self._print_to_file(str(result_str))

            self._print_to_file(f"{rule}\n")
            await self._write_raw_log(tool_entry)
            await self._update_metrics("tool_call", tool_entry)
    except Exception as e:
        # Logging must never break the caller; report and carry on.
        try:
            print(f"Logging error in log_tool_call: {e}", file=sys.stderr)
        except Exception:
            pass
    finally:
        self._exit_logging_context()
async def log_error(
    self, error: Exception, iteration: int, context: Optional[str] = None
) -> None:
    """
    Log an exception, including its formatted traceback.

    Logging is best-effort: the call returns without logging when the
    safety checks fail or the asyncio lock is already held, and any
    exception raised while logging is reported on stderr instead of
    propagating.

    Args:
        error: Exception that occurred
        iteration: Current iteration number
        context: Additional context about when the error occurred
    """
    if not self._can_log_safely():
        return
    if not self._enter_logging_context():
        return
    try:
        if self._lock.locked():
            return
        async with self._lock:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            error_name = type(error).__name__
            error_entry = {
                "timestamp": timestamp,
                "iteration": iteration,
                "error_type": error_name,
                "error_message": str(error),
                "context": context,
                "traceback": self._get_traceback(error),
            }

            bangs = "!" * 20
            report = [
                f"\n{bangs} ERROR DETAILS {bangs}",
                f"[{timestamp}] Iteration {iteration}",
                f"Error Type: {error_name}",
            ]
            if context:
                report.append(f"Context: {context}")
            report.append(f"Message: {str(error)}")
            traceback_str = error_entry["traceback"]
            if traceback_str:
                report.extend(["\nTraceback:", traceback_str])
            report.append(f"{bangs} END ERROR {bangs}\n")
            for line in report:
                self._print_to_file(line)

            await self._write_raw_log(error_entry)
            await self._update_metrics("error", error_entry)
    except Exception as e:
        try:
            print(f"Logging error in log_error: {e}", file=sys.stderr)
        except Exception:
            pass
    finally:
        self._exit_logging_context()
async def log_iteration_summary(
self,
iteration: int,
duration: int,
success: bool,
message_count: int,
stats: Dict[str, int],
tokens_used: int = 0,
cost: float = 0.0,
) -> None:
"""
Log a detailed iteration summary.
Args:
iteration: Iteration number
duration: Duration in seconds
success: Whether iteration was successful
message_count: Number of messages exchanged
stats: Message type statistics
tokens_used: Number of tokens used
cost: Cost of this iteration
"""
if self._emergency_event.is_set():
return
if self._lock.locked():
return
async with self._lock:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
summary_entry = {
"timestamp": timestamp,
"iteration": iteration,
"duration_seconds": duration,
"success": success,
"message_count": message_count,
"stats": stats,
"tokens_used": tokens_used,
"cost": cost,
}
status_icon = "SUCCESS" if success else "FAILED"
self._print_to_file(f"\n{'#'*15} ITERATION SUMMARY {'#'*15}")
self._print_to_file(f"{status_icon} - Iteration {iteration} - {duration}s")
self._print_to_file(f"Timestamp: {timestamp}")
self._print_to_file(f"Messages: {message_count}")
self._print_to_file(f"Tokens: {tokens_used}")
self._print_to_file(f"Cost: ${cost:.4f}")
if stats:
self._print_to_file("\nMessage Statistics:")
for msg_type, count in stats.items():
if count > 0:
self._print_to_file(f" {msg_type}: {count}")
self._print_to_file(f"{'#'*42}\n")
await self._write_raw_log(summary_entry)
await self._update_metrics("iteration", summary_entry)
# Update total metrics
self._metrics["total_tokens"] += tokens_used
self._metrics["total_cost"] += cost
def _serialize_content(
self, content: Any
) -> str | Dict[Any, Any] | List[Any] | int | float:
"""
Serialize content to JSON-serializable format.
Args:
content: Content to serialize
Returns:
Serialized content
"""
try:
if isinstance(content, str):
return content
elif hasattr(content, "__dict__"):
if hasattr(content, "text"):
return {"text": content.text, "type": type(content).__name__}
elif hasattr(content, "content"):
return {"content": content.content, "type": type(content).__name__}
else:
return {"repr": str(content), "type": type(content).__name__}
elif isinstance(content, (dict, list)):
return content
elif isinstance(content, (int, float, bool)):
return content
else:
return str(content)
except Exception:
return f"<unserializable: {type(content).__name__}>"
def _is_diff_content(self, text: str) -> bool:
    """
    Heuristically check whether text appears to be diff output.

    A single indicator is enough: a ``diff --git``, ``--- `` or ``+++ ``
    line among the first five lines, or an ``@@`` line or a ``+``/``-``
    line (other than ``+++``/``---``) among the first ten.

    NOTE(review): the last rule also matches ordinary text such as a
    Markdown bullet list ("- item"); confirm that is acceptable.

    Args:
        text: Text to check

    Returns:
        True if text looks like diff output
    """
    if not text or not isinstance(text, str):
        return False

    lines = text.split("\n")
    first_five, first_ten = lines[:5], lines[:10]

    # File headers only count near the very top
    if any(line.startswith(("diff --git", "--- ", "+++ ")) for line in first_five):
        return True
    # Hunk headers
    if any(line.startswith("@@") for line in first_ten):
        return True
    # Added/removed lines, excluding the "+++"/"---" header prefixes
    return any(
        line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
        for line in first_ten
    )
async def _write_raw_log(self, entry: Dict[str, Any]) -> None:
    """
    Append one entry to the raw log file as a single JSON line.

    The file is opened on first use. Open, write and flush run in worker
    threads; write and flush are each abandoned after 0.1 s. On an I/O
    error the handle is closed and dropped so a later call reopens the
    file. Nothing is written after emergency shutdown, and no exception
    leaves this method.

    Args:
        entry: Log entry to write
    """
    if self._emergency_event.is_set():
        return
    try:
        if self._raw_file_handle is None:
            try:
                # Use asyncio.to_thread to avoid blocking the event loop
                self._raw_file_handle = await asyncio.to_thread(
                    open, self.raw_output_file, "a", encoding="utf-8"
                )
            except (OSError, IOError):
                return

        # default=str stringifies values JSON cannot encode
        serialized = json.dumps(entry, default=str, ensure_ascii=False) + "\n"
        io_timeout = 0.1
        try:
            await asyncio.wait_for(
                asyncio.to_thread(self._raw_file_handle.write, serialized),
                timeout=io_timeout,
            )
            await asyncio.wait_for(
                asyncio.to_thread(self._raw_file_handle.flush),
                timeout=io_timeout,
            )
        except asyncio.TimeoutError:
            # Must stay ahead of the OSError clause below
            return
        except (OSError, IOError):
            stale = self._raw_file_handle
            if stale:
                try:
                    stale.close()
                except Exception:
                    pass
                self._raw_file_handle = None
            return
    except Exception:
        pass
async def _update_metrics(self, entry_type: str, entry: Dict[str, Any]) -> None:
    """
    Record an entry in the in-memory session metrics.

    After recording, a metrics save is requested whenever the number of
    stored messages is a multiple of ten (which includes zero). Unknown
    entry types are not recorded. Exceptions are swallowed.

    Args:
        entry_type: Type of entry (message, tool_call, error, iteration)
        entry: The entry data
    """
    buckets = {
        "message": "messages",
        "tool_call": "tool_calls",
        "error": "errors",
        "iteration": "iterations",
    }
    try:
        bucket = buckets.get(entry_type)
        if bucket is not None:
            self._metrics[bucket].append(entry)
        # Periodically save metrics (every 10 messages)
        if len(self._metrics["messages"]) % 10 == 0:
            await self._save_metrics()
    except Exception:
        pass
async def _save_metrics(self) -> None:
    """
    Write a snapshot of the session metrics to the metrics file as JSON.

    The save is skipped after emergency shutdown and whenever the asyncio
    lock is currently held; the write itself is abandoned after 0.5 s.
    Exceptions are swallowed.

    NOTE(review): a caller that invokes this while holding ``self._lock``
    hits the "lock is held" early return, so nothing is written in that
    case.
    """
    if self._emergency_event.is_set():
        return
    try:
        snapshot = dict(self._metrics)
        snapshot["session_last_update"] = datetime.now().isoformat()
        for total_key, bucket in (
            ("total_messages", "messages"),
            ("total_tool_calls", "tool_calls"),
            ("total_errors", "errors"),
            ("total_iterations", "iterations"),
        ):
            snapshot[total_key] = len(self._metrics[bucket])

        if self._lock.locked():
            return

        def write_snapshot():
            # default=str stringifies values JSON cannot encode
            return self.metrics_file.write_text(
                json.dumps(snapshot, indent=2, default=str)
            )

        async with self._lock:
            try:
                await asyncio.wait_for(asyncio.to_thread(write_snapshot), timeout=0.5)
            except asyncio.TimeoutError:
                pass
    except Exception:
        pass
def _get_traceback(self, error: Exception) -> str:
    """
    Format an exception and its traceback as a single string.

    Args:
        error: Exception to get traceback from

    Returns:
        Formatted traceback string, or a "Could not extract traceback"
        message if formatting fails
    """
    import traceback

    try:
        pieces = traceback.format_exception(type(error), error, error.__traceback__)
        return "".join(pieces)
    except Exception:
        return f"Could not extract traceback: {str(error)}"
def get_session_metrics(self) -> Dict[str, Any]:
    """
    Summarize the current session.

    Returns:
        Dictionary with session start/end times, entry counts, token and
        cost totals, and the paths of the three log files
    """
    metrics = self._metrics
    summary: Dict[str, Any] = {
        "session_start": metrics["session_start"],
        "session_end": metrics.get("session_end"),
    }
    for total_key, bucket in (
        ("total_messages", "messages"),
        ("total_tool_calls", "tool_calls"),
        ("total_errors", "errors"),
        ("total_iterations", "iterations"),
    ):
        summary[total_key] = len(metrics[bucket])
    summary["total_tokens"] = metrics["total_tokens"]
    summary["total_cost"] = metrics["total_cost"]
    summary["log_files"] = {
        "verbose": str(self.verbose_log_file),
        "raw": str(self.raw_output_file),
        "metrics": str(self.metrics_file),
    }
    return summary
def print_to_console(
    self, message: str, style: Optional[str] = None, panel: bool = False
) -> None:
    """
    Print a message to the live console.

    With a Rich live console the message is shown inside a panel when
    ``panel`` is true (``style`` is then ignored), wrapped in ``[style]``
    markup when ``style`` is given, or printed as-is. Without one the plain
    message goes to the built-in ``print``. Nothing is printed after
    emergency shutdown.

    Args:
        message: Message to print
        style: Rich style string (e.g., "bold red", "green")
        panel: Whether to wrap in a Rich panel
    """
    if self._emergency_event.is_set():
        return
    if not (self._live_console and RICH_AVAILABLE):
        print(message)
        return

    if panel:
        renderable = Panel(message)
    elif style:
        renderable = f"[{style}]{message}[/{style}]"
    else:
        renderable = message
    self._live_console.print(renderable)
def print_table(
    self, title: str, columns: List[str], rows: List[List[str]]
) -> None:
    """
    Print a table to the live console.

    Uses a Rich table when a live console is available; otherwise prints a
    plain-text table with `` | ``-joined cells and 40-dash rules. Nothing
    is printed after emergency shutdown.

    Args:
        title: Table title
        columns: Column headers
        rows: Table data rows
    """
    if self._emergency_event.is_set():
        return

    if not (self._live_console and RICH_AVAILABLE):
        # Plain text fallback
        rule = "-" * 40
        print(f"\n{title}")
        print(rule)
        print(" | ".join(columns))
        print(rule)
        for cells in rows:
            print(" | ".join(cells))
        print(rule)
        return

    table = Table(title=title)
    for heading in columns:
        table.add_column(heading)
    for cells in rows:
        table.add_row(*cells)
    self._live_console.print(table)
async def close(self) -> None:
    """
    Close log files and save final metrics.

    Order matters: the final metrics are saved BEFORE the emergency flag
    is raised, because ``_save_metrics`` returns immediately once that
    flag is set. After the flag is raised the raw log handle is closed,
    a session summary is appended to the verbose log (skipped if the
    asyncio lock is held), and the verbose log proxy is closed. Any
    unexpected exception is reported on stderr instead of propagating.
    """
    try:
        # Update session end time
        self._metrics["session_end"] = datetime.now().isoformat()

        # Save final metrics with timeout (must precede the shutdown flag)
        try:
            await asyncio.wait_for(self._save_metrics(), timeout=1.0)
        except asyncio.TimeoutError:
            pass

        # From here on other logging calls become no-ops
        self._emergency_shutdown = True
        self._emergency_event.set()

        # Close raw file handle
        if self._raw_file_handle:
            try:
                await asyncio.wait_for(
                    asyncio.to_thread(self._raw_file_handle.close), timeout=0.5
                )
            except asyncio.TimeoutError:
                pass
            finally:
                self._raw_file_handle = None

        # Write session summary
        try:
            if not self._lock.locked():
                async with self._lock:
                    session_start = self._metrics["session_start"]
                    total_duration = (
                        datetime.now() - datetime.fromisoformat(session_start)
                    ).total_seconds()
                    rule = "=" * 80
                    summary_lines = [
                        f"\n{rule}",
                        "SESSION SUMMARY",
                        f"Duration: {total_duration:.1f} seconds",
                        f"Messages: {len(self._metrics['messages'])}",
                        f"Tool Calls: {len(self._metrics['tool_calls'])}",
                        f"Errors: {len(self._metrics['errors'])}",
                        f"Iterations: {len(self._metrics['iterations'])}",
                        f"Total Tokens: {self._metrics['total_tokens']}",
                        f"Total Cost: ${self._metrics['total_cost']:.4f}",
                        f"Verbose log: {self.verbose_log_file}",
                        f"Raw log: {self.raw_output_file}",
                        f"Metrics: {self.metrics_file}",
                        f"{rule}\n",
                    ]
                    for line in summary_lines:
                        self._print_to_file(line)
        except (RuntimeError, asyncio.TimeoutError):
            pass

        # Close text IO proxy
        self._text_io_proxy.close()
    except Exception as e:
        print(f"Error closing verbose logger: {e}", file=sys.stderr)
def close_sync(self) -> None:
    """
    Synchronous close method.

    Without a running event loop, ``close()`` is run to completion with
    ``asyncio.run``. Inside a running loop it is scheduled as a background
    task and this method returns before the close has finished; the task
    is kept in ``self._background_tasks`` until it completes, because the
    event loop holds only weak references to tasks.
    """
    try:
        asyncio.get_running_loop()  # raises RuntimeError if no loop is running
    except RuntimeError:
        asyncio.run(self.close())
        return

    task = asyncio.create_task(self.close())
    # Created lazily so this method does not depend on __init__ changes.
    pending = getattr(self, "_background_tasks", None)
    if pending is None:
        pending = self._background_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)

View File

@@ -0,0 +1,8 @@
# ABOUTME: Web UI module for Ralph Orchestrator monitoring and control
# ABOUTME: Provides real-time dashboard for agent execution and system metrics
"""Web UI module for Ralph Orchestrator monitoring."""
from .server import WebMonitor
__all__ = ['WebMonitor']

View File

@@ -0,0 +1,74 @@
# ABOUTME: Entry point for running the Ralph Orchestrator web monitoring server
# ABOUTME: Enables execution with `python -m ralph_orchestrator.web`
"""Entry point for the Ralph Orchestrator web monitoring server."""
import argparse
import asyncio
import logging
from .server import WebMonitor
logger = logging.getLogger(__name__)
def main():
    """
    Main entry point for the web monitoring server.

    Parses the command line (``--port``, ``--host``, ``--no-auth``,
    ``--log-level``), configures logging, builds a ``WebMonitor`` and runs
    it until it stops or the user interrupts it. Errors raised by the
    monitor are logged with a traceback rather than propagated.
    """
    parser = argparse.ArgumentParser(
        description="Ralph Orchestrator Web Monitoring Dashboard"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port to run the web server on (default: 8080)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind the server to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--no-auth",
        action="store_true",
        help="Disable authentication (not recommended for production)"
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set logging level (default: INFO)"
    )
    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Create and run the web monitor
    monitor = WebMonitor(
        port=args.port,
        host=args.host,
        enable_auth=not args.no_auth
    )
    logger.info(
        "Starting Ralph Orchestrator Web Monitor on %s:%s", args.host, args.port
    )
    if args.no_auth:
        logger.warning("Authentication is disabled - not recommended for production")
    else:
        # Never write a password to the log. The value printed here before
        # also disagreed with the default configured in the auth module.
        logger.info(
            "Authentication enabled - set RALPH_WEB_USERNAME and "
            "RALPH_WEB_PASSWORD (or RALPH_WEB_PASSWORD_HASH) to override "
            "the built-in default credentials"
        )

    try:
        asyncio.run(monitor.run())
    except KeyboardInterrupt:
        logger.info("Web monitor stopped by user")
    except Exception as e:
        logger.error(f"Web monitor error: {e}", exc_info=True)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,186 @@
# ABOUTME: Authentication module for Ralph Orchestrator web monitoring dashboard
# ABOUTME: Provides JWT-based authentication with username/password login
"""Authentication module for the web monitoring server."""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
import jwt
from passlib.context import CryptContext
from fastapi import HTTPException, Security, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
# Configuration
# NOTE: when RALPH_WEB_SECRET_KEY is unset, a new random key is generated each
# time this module is imported, so tokens signed by one process cannot be
# verified by another (or by the same service after a restart).
SECRET_KEY = os.getenv("RALPH_WEB_SECRET_KEY", secrets.token_urlsafe(32))
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("RALPH_TOKEN_EXPIRE_MINUTES", "1440"))  # 24 hours default

# Password hashing (built once, before the default hash below needs it)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Default admin credentials (should be changed in production)
DEFAULT_USERNAME = os.getenv("RALPH_WEB_USERNAME", "admin")
DEFAULT_PASSWORD_HASH = os.getenv("RALPH_WEB_PASSWORD_HASH", None)

# If no password hash is provided, generate one for the default password
if not DEFAULT_PASSWORD_HASH:
    # SECURITY: without RALPH_WEB_PASSWORD this falls back to a fixed,
    # publicly known password; deployments should always set
    # RALPH_WEB_PASSWORD or RALPH_WEB_PASSWORD_HASH.
    default_password = os.getenv("RALPH_WEB_PASSWORD", "admin123")
    DEFAULT_PASSWORD_HASH = pwd_context.hash(default_password)

# HTTP Bearer token authentication
security = HTTPBearer()
class LoginRequest(BaseModel):
    """Login request model."""
    # Account name the client wants to authenticate as.
    username: str
    # Password as submitted by the client.
    # NOTE(review): presumably checked via AuthManager.authenticate_user - confirm in the login route.
    password: str
class TokenResponse(BaseModel):
    """Token response model."""
    # Encoded token string handed back to the client.
    access_token: str
    # Token scheme; defaults to "bearer".
    token_type: str = "bearer"
    # Token lifetime. NOTE(review): unit is not visible here (presumably seconds) - confirm at the call site.
    expires_in: int
class AuthManager:
    """Manages authentication for the web server.

    Users live in an in-memory dict keyed by username, seeded with the default
    admin account. Tokens are JWTs signed with the module-level secret key.
    """

    def __init__(self):
        self.pwd_context = pwd_context
        self.secret_key = SECRET_KEY
        self.algorithm = ALGORITHM
        self.access_token_expire_minutes = ACCESS_TOKEN_EXPIRE_MINUTES
        # Simple in-memory user store (can be extended to use a database)
        self.users = {
            DEFAULT_USERNAME: {
                "username": DEFAULT_USERNAME,
                "hashed_password": DEFAULT_PASSWORD_HASH,
                "is_active": True,
                "is_admin": True
            }
        }

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build the 401 error shared by every token-rejection path."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Verify a password against its hash."""
        return self.pwd_context.verify(plain_password, hashed_password)

    def get_password_hash(self, password: str) -> str:
        """Hash a password."""
        return self.pwd_context.hash(password)

    def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
        """Authenticate a user by username and password.

        Returns:
            The stored user record, or None when the user is unknown, the
            password does not match, or the account is inactive.
        """
        user = self.users.get(username)
        if not user:
            return None
        if not self.verify_password(password, user["hashed_password"]):
            return None
        if not user.get("is_active", True):
            return None
        return user

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT access token.

        Args:
            data: Claims to embed; copied, so the caller's dict is not modified.
            expires_delta: Token lifetime; defaults to the configured number
                of minutes when omitted (or zero).

        Returns:
            The encoded token with ``exp`` and ``iat`` claims added.
        """
        to_encode = data.copy()
        # One timestamp for both claims so iat and the exp base are identical.
        issued_at = datetime.now(timezone.utc)
        lifetime = expires_delta or timedelta(minutes=self.access_token_expire_minutes)
        to_encode.update({"exp": issued_at + lifetime, "iat": issued_at})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Verify and decode a JWT token.

        Returns:
            ``{"username": ..., "user": ...}`` for a valid token.

        Raises:
            HTTPException: 401 when the token is expired or invalid, carries
                no ``sub`` claim, or names a missing or inactive user.
        """
        # Only the decode call can raise the jwt errors handled here.
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError: it is the more specific case.
            raise self._unauthorized("Token has expired") from None
        except jwt.InvalidTokenError:
            raise self._unauthorized("Invalid token") from None

        username = payload.get("sub")
        if username is None:
            raise self._unauthorized("Invalid authentication credentials")

        # Check if user still exists and is active
        user = self.users.get(username)
        if not user or not user.get("is_active", True):
            raise self._unauthorized("User not found or inactive")
        return {"username": username, "user": user}

    def add_user(self, username: str, password: str, is_admin: bool = False) -> bool:
        """Add a new user to the system.

        Returns:
            False when the username is already taken, True otherwise.
        """
        if username in self.users:
            return False
        self.users[username] = {
            "username": username,
            "hashed_password": self.get_password_hash(password),
            "is_active": True,
            "is_admin": is_admin
        }
        return True

    def remove_user(self, username: str) -> bool:
        """Remove a user from the system.

        The default account can never be removed.

        Returns:
            True when a user was deleted, False otherwise.
        """
        if username in self.users and username != DEFAULT_USERNAME:
            del self.users[username]
            return True
        return False

    def update_password(self, username: str, new_password: str) -> bool:
        """Update a user's password.

        Returns:
            False when the user does not exist, True otherwise.
        """
        if username not in self.users:
            return False
        self.users[username]["hashed_password"] = self.get_password_hash(new_password)
        return True
# Global auth manager instance
# Created once at import time; get_current_user verifies tokens through it.
auth_manager = AuthManager()
async def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> Dict[str, Any]:
    """Resolve the request's bearer token to the authenticated user.

    Returns:
        Whatever ``auth_manager.verify_token`` yields for the token; that
        call raises on a token it rejects.
    """
    return auth_manager.verify_token(credentials.credentials)
async def require_admin(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """Allow the request through only for an admin user.

    Returns:
        The unchanged ``current_user`` mapping.

    Raises:
        HTTPException: 403 when the user record has no truthy ``is_admin``.
    """
    has_admin_flag = current_user["user"].get("is_admin", False)
    if has_admin_flag:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin privileges required"
    )

View File

@@ -0,0 +1,467 @@
# ABOUTME: Database module for Ralph Orchestrator web monitoring
# ABOUTME: Provides SQLite storage for execution history and metrics
import sqlite3
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from contextlib import contextmanager
import threading
logger = logging.getLogger(__name__)


class DatabaseManager:
    """Manages SQLite database for Ralph Orchestrator execution history.

    Every operation opens its own short-lived connection. Writes are
    serialized through a process-local lock; reads do not take the lock.
    Timestamps are stored as naive local-time ISO-8601 strings.
    """

    # Statuses after which a run is finished and receives an end_time.
    # "stopped" is what the web monitor records for a run halted on request.
    _TERMINAL_RUN_STATUSES = ("completed", "failed", "stopped")

    # Tables and indices, created idempotently on start-up.
    _SCHEMA = (
        """
        CREATE TABLE IF NOT EXISTS orchestrator_runs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            orchestrator_id TEXT NOT NULL,
            prompt_path TEXT NOT NULL,
            start_time TIMESTAMP NOT NULL,
            end_time TIMESTAMP,
            status TEXT NOT NULL,
            total_iterations INTEGER DEFAULT 0,
            max_iterations INTEGER,
            error_message TEXT,
            metadata TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS iteration_history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id INTEGER NOT NULL,
            iteration_number INTEGER NOT NULL,
            start_time TIMESTAMP NOT NULL,
            end_time TIMESTAMP,
            status TEXT NOT NULL,
            current_task TEXT,
            agent_output TEXT,
            error_message TEXT,
            metrics TEXT,
            FOREIGN KEY (run_id) REFERENCES orchestrator_runs(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS task_history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id INTEGER NOT NULL,
            task_description TEXT NOT NULL,
            status TEXT NOT NULL,
            start_time TIMESTAMP,
            end_time TIMESTAMP,
            iteration_count INTEGER DEFAULT 0,
            error_message TEXT,
            FOREIGN KEY (run_id) REFERENCES orchestrator_runs(id)
        )
        """,
        # Indices for better query performance
        """
        CREATE INDEX IF NOT EXISTS idx_runs_orchestrator_id
        ON orchestrator_runs(orchestrator_id)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_runs_start_time
        ON orchestrator_runs(start_time DESC)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_iterations_run_id
        ON iteration_history(run_id)
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_tasks_run_id
        ON task_history(run_id)
        """,
    )

    def __init__(self, db_path: Optional[Path] = None):
        """Initialize database manager.

        Args:
            db_path: Path to SQLite database file (default: ~/.ralph/history.db)
        """
        if db_path is None:
            config_dir = Path.home() / ".ralph"
            config_dir.mkdir(exist_ok=True)
            db_path = config_dir / "history.db"
        self.db_path = db_path
        self._lock = threading.Lock()
        self._init_database()
        logger.info(f"Database initialized at {self.db_path}")

    @contextmanager
    def _get_connection(self):
        """Open a fresh connection for one unit of work and always close it."""
        conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        conn.row_factory = sqlite3.Row  # Enable column access by name
        try:
            yield conn
        finally:
            conn.close()

    def _init_database(self):
        """Initialize database schema."""
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                for statement in self._SCHEMA:
                    cursor.execute(statement)
                conn.commit()

    def create_run(self, orchestrator_id: str, prompt_path: str,
                   max_iterations: Optional[int] = None,
                   metadata: Optional[Dict[str, Any]] = None) -> int:
        """Create a new orchestrator run entry.

        Args:
            orchestrator_id: Unique ID of the orchestrator
            prompt_path: Path to the prompt file
            max_iterations: Maximum iterations for this run
            metadata: Additional metadata to store (serialized as JSON)

        Returns:
            ID of the created run
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO orchestrator_runs
                    (orchestrator_id, prompt_path, start_time, status, max_iterations, metadata)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    orchestrator_id,
                    prompt_path,
                    datetime.now().isoformat(),
                    "running",
                    max_iterations,
                    json.dumps(metadata) if metadata else None
                ))
                conn.commit()
                return cursor.lastrowid

    def update_run_status(self, run_id: int, status: str,
                          error_message: Optional[str] = None,
                          total_iterations: Optional[int] = None):
        """Update the status of an orchestrator run.

        Args:
            run_id: ID of the run to update
            status: New status (running, completed, failed, stopped, paused)
            error_message: Error message if failed
            total_iterations: Total iterations completed
        """
        # Column names below are fixed literals; only values are bound.
        assignments = ["status = ?"]
        params: List[Any] = [status]
        if status in self._TERMINAL_RUN_STATUSES:
            assignments.append("end_time = ?")
            params.append(datetime.now().isoformat())
        if error_message is not None:
            assignments.append("error_message = ?")
            params.append(error_message)
        if total_iterations is not None:
            assignments.append("total_iterations = ?")
            params.append(total_iterations)
        params.append(run_id)

        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(f"""
                    UPDATE orchestrator_runs
                    SET {', '.join(assignments)}
                    WHERE id = ?
                """, params)
                conn.commit()

    def add_iteration(self, run_id: int, iteration_number: int,
                      current_task: Optional[str] = None,
                      metrics: Optional[Dict[str, Any]] = None) -> int:
        """Add a new iteration entry.

        Args:
            run_id: ID of the parent run
            iteration_number: Iteration number
            current_task: Current task being executed
            metrics: Performance metrics for this iteration (serialized as JSON)

        Returns:
            ID of the created iteration
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO iteration_history
                    (run_id, iteration_number, start_time, status, current_task, metrics)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    run_id,
                    iteration_number,
                    datetime.now().isoformat(),
                    "running",
                    current_task,
                    json.dumps(metrics) if metrics else None
                ))
                conn.commit()
                return cursor.lastrowid

    def update_iteration(self, iteration_id: int, status: str,
                         agent_output: Optional[str] = None,
                         error_message: Optional[str] = None):
        """Update an iteration entry.

        All four columns are overwritten, so omitted arguments store NULL.

        Args:
            iteration_id: ID of the iteration to update
            status: New status (running, completed, failed)
            agent_output: Output from the agent
            error_message: Error message if failed
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE iteration_history
                    SET status = ?, end_time = ?, agent_output = ?, error_message = ?
                    WHERE id = ?
                """, (
                    status,
                    datetime.now().isoformat() if status != "running" else None,
                    agent_output,
                    error_message,
                    iteration_id
                ))
                conn.commit()

    def add_task(self, run_id: int, task_description: str) -> int:
        """Add a task entry.

        Args:
            run_id: ID of the parent run
            task_description: Description of the task

        Returns:
            ID of the created task
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO task_history (run_id, task_description, status)
                    VALUES (?, ?, ?)
                """, (run_id, task_description, "pending"))
                conn.commit()
                return cursor.lastrowid

    def update_task_status(self, task_id: int, status: str,
                           error_message: Optional[str] = None):
        """Update task status.

        Args:
            task_id: ID of the task to update
            status: New status (pending, in_progress, completed, failed)
            error_message: Error message if failed
        """
        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                now = datetime.now().isoformat()
                if status == "in_progress":
                    # Starting work stamps start_time.
                    cursor.execute("""
                        UPDATE task_history
                        SET status = ?, start_time = ?
                        WHERE id = ?
                    """, (status, now, task_id))
                elif status in ["completed", "failed"]:
                    # Finishing stamps end_time and records any error.
                    cursor.execute("""
                        UPDATE task_history
                        SET status = ?, end_time = ?, error_message = ?
                        WHERE id = ?
                    """, (status, now, error_message, task_id))
                else:
                    cursor.execute("""
                        UPDATE task_history
                        SET status = ?
                        WHERE id = ?
                    """, (status, task_id))
                conn.commit()

    def get_recent_runs(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get recent orchestrator runs.

        Args:
            limit: Maximum number of runs to return

        Returns:
            List of run dictionaries, newest first, with metadata decoded
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT * FROM orchestrator_runs
                ORDER BY start_time DESC
                LIMIT ?
            """, (limit,))
            runs = []
            for row in cursor.fetchall():
                run = dict(row)
                if run.get('metadata'):
                    run['metadata'] = json.loads(run['metadata'])
                runs.append(run)
            return runs

    def get_run_details(self, run_id: int) -> Optional[Dict[str, Any]]:
        """Get detailed information about a specific run.

        Args:
            run_id: ID of the run

        Returns:
            Run details with ``iterations`` and ``tasks`` lists, or None when
            no run has that ID
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()

            # Get run info
            cursor.execute("SELECT * FROM orchestrator_runs WHERE id = ?", (run_id,))
            run_row = cursor.fetchone()
            if not run_row:
                return None
            run = dict(run_row)
            if run.get('metadata'):
                run['metadata'] = json.loads(run['metadata'])

            # Get iterations
            cursor.execute("""
                SELECT * FROM iteration_history
                WHERE run_id = ?
                ORDER BY iteration_number
            """, (run_id,))
            iterations = []
            for iteration_row in cursor.fetchall():
                iteration = dict(iteration_row)
                if iteration.get('metrics'):
                    iteration['metrics'] = json.loads(iteration['metrics'])
                iterations.append(iteration)
            run['iterations'] = iterations

            # Get tasks
            cursor.execute("""
                SELECT * FROM task_history
                WHERE run_id = ?
                ORDER BY id
            """, (run_id,))
            run['tasks'] = [dict(task_row) for task_row in cursor.fetchall()]
            return run

    def get_statistics(self) -> Dict[str, Any]:
        """Get database statistics.

        Returns:
            Dictionary with total_runs, runs_by_status, total_iterations,
            total_tasks, avg_iterations_per_run and success_rate (percentage
            of completed among completed + failed runs)
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            stats = {}

            # Total runs
            cursor.execute("SELECT COUNT(*) FROM orchestrator_runs")
            stats['total_runs'] = cursor.fetchone()[0]

            # Runs by status
            cursor.execute("""
                SELECT status, COUNT(*)
                FROM orchestrator_runs
                GROUP BY status
            """)
            stats['runs_by_status'] = dict(cursor.fetchall())

            # Total iterations
            cursor.execute("SELECT COUNT(*) FROM iteration_history")
            stats['total_iterations'] = cursor.fetchone()[0]

            # Total tasks
            cursor.execute("SELECT COUNT(*) FROM task_history")
            stats['total_tasks'] = cursor.fetchone()[0]

            # Average iterations per run (runs with no iterations excluded)
            cursor.execute("""
                SELECT AVG(total_iterations)
                FROM orchestrator_runs
                WHERE total_iterations > 0
            """)
            result = cursor.fetchone()[0]
            stats['avg_iterations_per_run'] = round(result, 2) if result else 0

            # Success rate; NULLIF keeps an empty set from dividing by zero
            cursor.execute("""
                SELECT
                    COUNT(CASE WHEN status = 'completed' THEN 1 END) * 100.0 /
                    NULLIF(COUNT(*), 0)
                FROM orchestrator_runs
                WHERE status IN ('completed', 'failed')
            """)
            result = cursor.fetchone()[0]
            stats['success_rate'] = round(result, 2) if result else 0
            return stats

    def cleanup_old_records(self, days: int = 30):
        """Remove records older than specified days.

        Deletes the old runs together with their iterations and tasks.

        Args:
            days: Number of days to keep
        """
        # Local import: this module only imports `datetime` from datetime.
        from datetime import timedelta

        # start_time is written with datetime.now().isoformat() (local time),
        # so the cutoff is computed the same way rather than with SQLite's
        # datetime('now'), which is UTC.
        cutoff = (datetime.now() - timedelta(days=days)).isoformat()
        # Sub-selects avoid binding one parameter per run ID, which can exceed
        # SQLite's bound-variable limit on large histories.
        old_runs = "SELECT id FROM orchestrator_runs WHERE start_time < ?"

        with self._lock:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                # Children first, while the parent rows still identify them.
                cursor.execute(
                    f"DELETE FROM iteration_history WHERE run_id IN ({old_runs})",
                    (cutoff,),
                )
                cursor.execute(
                    f"DELETE FROM task_history WHERE run_id IN ({old_runs})",
                    (cutoff,),
                )
                cursor.execute(
                    "DELETE FROM orchestrator_runs WHERE start_time < ?",
                    (cutoff,),
                )
                removed = cursor.rowcount
                conn.commit()
                if removed:
                    logger.info(f"Cleaned up {removed} old runs")

View File

@@ -0,0 +1,280 @@
# ABOUTME: Implements rate limiting for API endpoints to prevent abuse
# ABOUTME: Uses token bucket algorithm with configurable limits per endpoint
"""Rate limiting implementation for the Ralph web server."""
import asyncio
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple
from functools import wraps
import logging
from fastapi import Request, status
from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__)
class RateLimiter:
"""Token bucket rate limiter implementation.
Uses a token bucket algorithm to limit requests per IP address.
Tokens are replenished at a fixed rate up to a maximum capacity.
"""
def __init__(
self,
capacity: int = 100,
refill_rate: float = 10.0,
refill_period: float = 1.0,
block_duration: float = 60.0
):
"""Initialize the rate limiter.
Args:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens to add per refill period
refill_period: Time in seconds between refills
block_duration: Time in seconds to block an IP after exhausting tokens
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.refill_period = refill_period
self.block_duration = block_duration
# Store buckets per IP address
self.buckets: Dict[str, Tuple[float, float, float]] = defaultdict(
lambda: (float(capacity), time.time(), 0.0)
)
self.blocked_ips: Dict[str, float] = {}
# Lock for thread-safe access
self._lock = asyncio.Lock()
async def check_rate_limit(self, identifier: str) -> Tuple[bool, Optional[int]]:
"""Check if a request is allowed under the rate limit.
Args:
identifier: Unique identifier for the client (e.g., IP address)
Returns:
Tuple of (allowed, retry_after_seconds)
"""
async with self._lock:
current_time = time.time()
# Check if IP is blocked
if identifier in self.blocked_ips:
block_end = self.blocked_ips[identifier]
if current_time < block_end:
retry_after = int(block_end - current_time)
return False, retry_after
else:
# Unblock the IP
del self.blocked_ips[identifier]
# Get or create bucket for this identifier
tokens, last_refill, consecutive_violations = self.buckets[identifier]
# Calculate tokens to add based on time elapsed
time_elapsed = current_time - last_refill
tokens_to_add = (time_elapsed / self.refill_period) * self.refill_rate
tokens = min(self.capacity, tokens + tokens_to_add)
if tokens >= 1:
# Consume a token
tokens -= 1
consecutive_violations = 0
self.buckets[identifier] = (tokens, current_time, consecutive_violations)
return True, None
else:
# No tokens available
consecutive_violations += 1
# Block IP if too many consecutive violations
if consecutive_violations >= 5:
block_end = current_time + self.block_duration
self.blocked_ips[identifier] = block_end
del self.buckets[identifier]
return False, int(self.block_duration)
self.buckets[identifier] = (tokens, current_time, consecutive_violations)
retry_after = int(self.refill_period / self.refill_rate)
return False, retry_after
async def cleanup_old_buckets(self, max_age: float = 3600.0):
"""Remove old inactive buckets to prevent memory growth.
Args:
max_age: Maximum age in seconds for inactive buckets
"""
async with self._lock:
current_time = time.time()
to_remove = []
for identifier, (_tokens, last_refill, _) in self.buckets.items():
if current_time - last_refill > max_age:
to_remove.append(identifier)
for identifier in to_remove:
del self.buckets[identifier]
# Clean up expired blocks
to_remove = []
for identifier, block_end in self.blocked_ips.items():
if current_time >= block_end:
to_remove.append(identifier)
for identifier in to_remove:
del self.blocked_ips[identifier]
if to_remove:
logger.info(f"Cleaned up {len(to_remove)} expired rate limit entries")
class RateLimitConfig:
    """Rate limit tiers and the shared limiter instance for each of them."""

    # Per-category RateLimiter settings: `capacity` is the burst size and
    # `refill_rate` tokens are added every `refill_period` seconds.
    LIMITS = {
        "auth": {"capacity": 10, "refill_rate": 1.0, "refill_period": 60.0},  # burst 10, then 1 per minute
        "api": {"capacity": 100, "refill_rate": 10.0, "refill_period": 1.0},  # burst 100, then 10 per second
        "websocket": {"capacity": 10, "refill_rate": 1.0, "refill_period": 10.0},  # burst 10, then 1 per 10 seconds
        "static": {"capacity": 200, "refill_rate": 20.0, "refill_period": 1.0},  # burst 200, then 20 per second
        "admin": {"capacity": 50, "refill_rate": 5.0, "refill_period": 1.0},  # burst 50, then 5 per second
    }

    @classmethod
    def get_limiter(cls, category: str) -> RateLimiter:
        """Return the limiter for a category, creating it on first use.

        Args:
            category: The category of endpoints to limit; a name missing
                from LIMITS gets its own limiter built with the "api" settings

        Returns:
            The RateLimiter instance cached for that category
        """
        # The cache is attached to the class lazily, on the first call.
        registry = getattr(cls, "_limiters", None)
        if registry is None:
            registry = cls._limiters = {}
        try:
            return registry[category]
        except KeyError:
            settings = cls.LIMITS.get(category, cls.LIMITS["api"])
            limiter = registry[category] = RateLimiter(**settings)
            return limiter
def rate_limit(category: str = "api"):
    """Build a decorator that rate-limits an async FastAPI endpoint.

    The wrapped endpoint must take the request as its first positional
    argument. A rejected call receives a 429 JSON response instead of
    running the endpoint.

    Args:
        category: The rate limit category to apply
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Prefer the first X-Forwarded-For entry (set by proxies), then
            # the socket peer, then a shared "unknown" bucket.
            # NOTE(review): the header is client-supplied - confirm it is only
            # honoured behind a trusted proxy.
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                client_ip = forwarded_for.split(",")[0].strip()
            elif request.client:
                client_ip = request.client.host
            else:
                client_ip = "unknown"

            verdict = await RateLimitConfig.get_limiter(category).check_rate_limit(client_ip)
            allowed, retry_after = verdict
            if allowed:
                # Call the original function
                return await func(request, *args, **kwargs)

            logger.warning(f"Rate limit exceeded for {client_ip} on {category} endpoint")
            payload = {
                "detail": "Rate limit exceeded",
                "retry_after": retry_after
            }
            rejection = JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content=payload
            )
            if retry_after:
                rejection.headers["Retry-After"] = str(retry_after)
            return rejection
        return wrapper
    return decorator
async def setup_rate_limit_cleanup():
    """Start the background task that periodically prunes rate limit state.

    Must be awaited from a running event loop, because it schedules a task.

    Returns:
        The cleanup task
    """
    async def _sweep_forever():
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                for tier in RateLimitConfig.LIMITS:
                    await RateLimitConfig.get_limiter(tier).cleanup_old_buckets()
            except Exception as exc:
                # Keep the loop alive: log the failure and try again next round.
                logger.error(f"Error in rate limit cleanup: {exc}")

    return asyncio.create_task(_sweep_forever())
# Middleware for global rate limiting
async def rate_limit_middleware(request: Request, call_next):
    """Apply a per-client rate limit to every request.

    The limit tier is chosen from the URL path prefix; anything unmatched
    falls under "api".

    Args:
        request: The incoming request
        call_next: The next middleware or endpoint

    Returns:
        The downstream response, or a 429 JSON response when the client is
        over its limit
    """
    path = request.url.path

    # First matching prefix wins.
    tiers = (
        ("/api/auth", "auth"),
        ("/api/admin", "admin"),
        ("/ws", "websocket"),
        ("/static", "static"),
    )
    category = next(
        (tier for prefix, tier in tiers if path.startswith(prefix)),
        "api",
    )

    # Client identity: first X-Forwarded-For entry, else socket peer, else "unknown".
    # NOTE(review): the header is client-supplied - confirm it is only
    # honoured behind a trusted proxy.
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        client_ip = forwarded_for.split(",")[0].strip()
    elif request.client:
        client_ip = request.client.host
    else:
        client_ip = "unknown"

    verdict = await RateLimitConfig.get_limiter(category).check_rate_limit(client_ip)
    allowed, retry_after = verdict
    if allowed:
        # Continue with the request
        return await call_next(request)

    logger.warning(f"Rate limit exceeded for {client_ip} on {path}")
    payload = {
        "detail": "Rate limit exceeded",
        "retry_after": retry_after
    }
    rejection = JSONResponse(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        content=payload
    )
    if retry_after:
        rejection.headers["Retry-After"] = str(retry_after)
    return rejection

View File

@@ -0,0 +1,712 @@
# ABOUTME: FastAPI web server for Ralph Orchestrator monitoring dashboard
# ABOUTME: Provides REST API endpoints and WebSocket connections for real-time updates
"""FastAPI web server for Ralph Orchestrator monitoring."""
import json
import time
import asyncio
import logging
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends, status
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import psutil
from ..orchestrator import RalphOrchestrator
from .auth import (
auth_manager, LoginRequest, TokenResponse,
get_current_user, require_admin
)
from .database import DatabaseManager
from .rate_limit import rate_limit_middleware, setup_rate_limit_cleanup
logger = logging.getLogger(__name__)
class PromptUpdateRequest(BaseModel):
    """Request model for updating orchestrator prompt."""
    # New prompt text.
    # NOTE(review): presumably replaces the orchestrator's prompt file contents - confirm in the route handler.
    content: str
class OrchestratorMonitor:
"""Monitors and manages orchestrator instances."""
def __init__(self):
self.active_orchestrators: Dict[str, RalphOrchestrator] = {}
self.execution_history: List[Dict[str, Any]] = []
self.websocket_clients: List[WebSocket] = []
self.metrics_cache: Dict[str, Any] = {}
self.system_metrics_task: Optional[asyncio.Task] = None
self.database = DatabaseManager()
self.active_runs: Dict[str, int] = {} # Maps orchestrator_id to run_id
self.active_iterations: Dict[str, int] = {} # Maps orchestrator_id to iteration_id
async def start_monitoring(self):
"""Start background monitoring tasks."""
if not self.system_metrics_task:
self.system_metrics_task = asyncio.create_task(self._monitor_system_metrics())
async def stop_monitoring(self):
"""Stop background monitoring tasks."""
if self.system_metrics_task:
self.system_metrics_task.cancel()
try:
await self.system_metrics_task
except asyncio.CancelledError:
pass
async def _monitor_system_metrics(self):
"""Monitor system metrics continuously."""
while True:
try:
# Collect system metrics
metrics = {
"timestamp": datetime.now().isoformat(),
"cpu_percent": psutil.cpu_percent(interval=1),
"memory": {
"total": psutil.virtual_memory().total,
"available": psutil.virtual_memory().available,
"percent": psutil.virtual_memory().percent
},
"active_processes": len(psutil.pids()),
"orchestrators": len(self.active_orchestrators)
}
self.metrics_cache["system"] = metrics
# Broadcast to WebSocket clients
await self._broadcast_to_clients({
"type": "system_metrics",
"data": metrics
})
await asyncio.sleep(5) # Update every 5 seconds
except Exception as e:
logger.error(f"Error monitoring system metrics: {e}")
await asyncio.sleep(5)
async def _broadcast_to_clients(self, message: Dict[str, Any]):
"""Broadcast message to all connected WebSocket clients."""
disconnected_clients = []
for client in self.websocket_clients:
try:
await client.send_json(message)
except Exception:
disconnected_clients.append(client)
# Remove disconnected clients
for client in disconnected_clients:
if client in self.websocket_clients:
self.websocket_clients.remove(client)
def _schedule_broadcast(self, message: Dict[str, Any]):
"""Schedule a broadcast to clients, handling both sync and async contexts."""
try:
# Check if there's a running event loop (raises RuntimeError if not)
asyncio.get_running_loop()
# If we're in an async context, schedule the broadcast
asyncio.create_task(self._broadcast_to_clients(message))
except RuntimeError:
# No event loop running - we're in a sync context (e.g., during testing)
# The broadcast will be skipped in this case
pass
async def broadcast_update(self, message: Dict[str, Any]):
"""Public method to broadcast updates to WebSocket clients."""
await self._broadcast_to_clients(message)
def register_orchestrator(self, orchestrator_id: str, orchestrator: RalphOrchestrator):
"""Register an orchestrator instance."""
self.active_orchestrators[orchestrator_id] = orchestrator
# Create a new run in the database
try:
run_id = self.database.create_run(
orchestrator_id=orchestrator_id,
prompt_path=str(orchestrator.prompt_file),
max_iterations=orchestrator.max_iterations,
metadata={
"primary_tool": orchestrator.primary_tool,
"max_runtime": orchestrator.max_runtime
}
)
self.active_runs[orchestrator_id] = run_id
# Extract and store tasks if available
if hasattr(orchestrator, 'task_queue'):
for task in orchestrator.task_queue:
self.database.add_task(run_id, task['description'])
except Exception as e:
logger.error(f"Error creating database run for orchestrator {orchestrator_id}: {e}")
self._schedule_broadcast({
"type": "orchestrator_registered",
"data": {"id": orchestrator_id, "timestamp": datetime.now().isoformat()}
})
def unregister_orchestrator(self, orchestrator_id: str):
"""Unregister an orchestrator instance."""
if orchestrator_id in self.active_orchestrators:
# Update database run status
if orchestrator_id in self.active_runs:
try:
orchestrator = self.active_orchestrators[orchestrator_id]
status = "completed" if not orchestrator.stop_requested else "stopped"
total_iterations = orchestrator.metrics.total_iterations if hasattr(orchestrator, 'metrics') else 0
self.database.update_run_status(
self.active_runs[orchestrator_id],
status=status,
total_iterations=total_iterations
)
del self.active_runs[orchestrator_id]
except Exception as e:
logger.error(f"Error updating database run for orchestrator {orchestrator_id}: {e}")
# Remove from active orchestrators
del self.active_orchestrators[orchestrator_id]
# Remove active iteration tracking if exists
if orchestrator_id in self.active_iterations:
del self.active_iterations[orchestrator_id]
self._schedule_broadcast({
"type": "orchestrator_unregistered",
"data": {"id": orchestrator_id, "timestamp": datetime.now().isoformat()}
})
def get_orchestrator_status(self, orchestrator_id: str) -> Dict[str, Any]:
"""Get status of a specific orchestrator."""
if orchestrator_id not in self.active_orchestrators:
return None
orchestrator = self.active_orchestrators[orchestrator_id]
# Try to use the new get_orchestrator_state method if it exists
if hasattr(orchestrator, 'get_orchestrator_state'):
state = orchestrator.get_orchestrator_state()
state['id'] = orchestrator_id # Override with our ID
return state
else:
# Fallback to old method for compatibility
return {
"id": orchestrator_id,
"status": "running" if not orchestrator.stop_requested else "stopping",
"metrics": orchestrator.metrics.to_dict(),
"cost": orchestrator.cost_tracker.get_summary() if orchestrator.cost_tracker else None,
"config": {
"primary_tool": orchestrator.primary_tool,
"max_iterations": orchestrator.max_iterations,
"max_runtime": orchestrator.max_runtime,
"prompt_file": str(orchestrator.prompt_file)
}
}
def get_all_orchestrators_status(self) -> List[Dict[str, Any]]:
"""Get status of all orchestrators."""
return [
self.get_orchestrator_status(orch_id)
for orch_id in self.active_orchestrators
]
def start_iteration(self, orchestrator_id: str, iteration_number: int,
current_task: Optional[str] = None) -> Optional[int]:
"""Start tracking a new iteration.
Args:
orchestrator_id: ID of the orchestrator
iteration_number: Iteration number
current_task: Current task being executed
Returns:
Iteration ID if successful, None otherwise
"""
if orchestrator_id not in self.active_runs:
return None
try:
iteration_id = self.database.add_iteration(
run_id=self.active_runs[orchestrator_id],
iteration_number=iteration_number,
current_task=current_task,
metrics=None # Can be enhanced to include metrics
)
self.active_iterations[orchestrator_id] = iteration_id
return iteration_id
except Exception as e:
logger.error(f"Error starting iteration for orchestrator {orchestrator_id}: {e}")
return None
def end_iteration(self, orchestrator_id: str, status: str = "completed",
                  agent_output: Optional[str] = None, error_message: Optional[str] = None):
    """Close the iteration currently tracked for an orchestrator.

    Does nothing when no iteration is being tracked for the orchestrator.

    Args:
        orchestrator_id: ID of the orchestrator
        status: Status of the iteration (completed, failed)
        agent_output: Output from the agent
        error_message: Error message if failed
    """
    open_iterations = self.active_iterations
    if orchestrator_id not in open_iterations:
        return

    try:
        self.database.update_iteration(
            iteration_id=open_iterations[orchestrator_id],
            status=status,
            agent_output=agent_output,
            error_message=error_message,
        )
        # Only forget the iteration once the database update has succeeded.
        del open_iterations[orchestrator_id]
    except Exception as exc:
        logger.error(f"Error ending iteration for orchestrator {orchestrator_id}: {exc}")
class WebMonitor:
    """Web monitoring server for Ralph Orchestrator.

    Builds a FastAPI application that exposes REST endpoints and a WebSocket
    feed on top of an ``OrchestratorMonitor`` instance.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 8080, enable_auth: bool = True):
        """Create the monitor and build its FastAPI application.

        Args:
            host: Address the server binds to
            port: Port the server listens on
            enable_auth: When True, protected routes require authentication
        """
        self.host = host
        self.port = port
        self.enable_auth = enable_auth
        self.monitor = OrchestratorMonitor()
        self.app = None
        self._setup_app()

    def _setup_app(self):
        """Setup FastAPI application."""

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # Startup
            await self.monitor.start_monitoring()
            try:
                # Start rate limit cleanup task
                cleanup_task = await setup_rate_limit_cleanup()
                try:
                    yield
                finally:
                    # Shutdown: always cancel the cleanup task, even when the
                    # application exits through an exception.
                    cleanup_task.cancel()
                    try:
                        await cleanup_task
                    except asyncio.CancelledError:
                        pass
            finally:
                await self.monitor.stop_monitoring()

        self.app = FastAPI(
            title="Ralph Orchestrator Monitor",
            description="Real-time monitoring for Ralph AI Orchestrator",
            version="1.0.0",
            lifespan=lifespan
        )

        # Add CORS middleware
        # NOTE(review): wildcard origins combined with allow_credentials=True is
        # a very permissive policy; consider restricting allow_origins.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Add rate limiting middleware
        self.app.middleware("http")(rate_limit_middleware)

        # Mount static files directory if it exists
        static_dir = Path(__file__).parent / "static"
        if static_dir.exists():
            self.app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

        # Setup routes
        self._setup_routes()

    def _setup_routes(self):
        """Setup API routes."""
        # Create dependency for auth if enabled
        auth_dependency = Depends(get_current_user) if self.enable_auth else None
        # Dependency lists shared by the route decorators below.
        protected = [auth_dependency] if self.enable_auth else []
        admin_only = [Depends(require_admin)] if self.enable_auth else []

        def _require_orchestrator(orchestrator_id: str):
            """Return the registered orchestrator or raise a 404."""
            if orchestrator_id not in self.monitor.active_orchestrators:
                raise HTTPException(status_code=404, detail="Orchestrator not found")
            return self.monitor.active_orchestrators[orchestrator_id]

        # Authentication endpoints (public)
        @self.app.post("/api/auth/login", response_model=TokenResponse)
        async def login(request: LoginRequest):
            """Login and receive an access token."""
            user = auth_manager.authenticate_user(request.username, request.password)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Incorrect username or password",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            access_token = auth_manager.create_access_token(
                data={"sub": user["username"]}
            )
            return TokenResponse(
                access_token=access_token,
                expires_in=auth_manager.access_token_expire_minutes * 60
            )

        @self.app.get("/api/auth/verify")
        async def verify_token(current_user: Dict[str, Any] = Depends(get_current_user)):
            """Verify the current token is valid."""
            return {
                "valid": True,
                "username": current_user["username"],
                "is_admin": current_user["user"].get("is_admin", False)
            }

        # Public endpoints - HTML pages
        @self.app.get("/login.html")
        async def login_page():
            """Serve the login page."""
            html_file = Path(__file__).parent / "static" / "login.html"
            if html_file.exists():
                return FileResponse(html_file, media_type="text/html")
            return HTMLResponse(content="<h1>Login page not found</h1>", status_code=404)

        @self.app.get("/")
        async def index():
            """Serve the main dashboard."""
            html_file = Path(__file__).parent / "static" / "index.html"
            if html_file.exists():
                return FileResponse(html_file, media_type="text/html")
            # Return a basic HTML page if static file doesn't exist yet
            return HTMLResponse(content="""
<!DOCTYPE html>
<html>
<head>
<title>Ralph Orchestrator Monitor</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; }
h1 { color: #333; }
.status { padding: 10px; margin: 10px 0; background: #f0f0f0; border-radius: 5px; }
</style>
</head>
<body>
<h1>Ralph Orchestrator Monitor</h1>
<div id="status" class="status">
<p>Web monitor is running. Dashboard file not found.</p>
<p>API Endpoints:</p>
<ul>
<li><a href="/api/status">/api/status</a> - System status</li>
<li><a href="/api/orchestrators">/api/orchestrators</a> - Active orchestrators</li>
<li><a href="/api/metrics">/api/metrics</a> - System metrics</li>
<li><a href="/docs">/docs</a> - API documentation</li>
</ul>
</div>
</body>
</html>
""")

        @self.app.get("/api/health")
        async def health_check():
            """Health check endpoint."""
            return {
                "status": "healthy",
                "timestamp": datetime.now().isoformat()
            }

        @self.app.get("/api/status", dependencies=protected)
        async def get_status():
            """Get overall system status."""
            return {
                "status": "online",
                "timestamp": datetime.now().isoformat(),
                "active_orchestrators": len(self.monitor.active_orchestrators),
                "connected_clients": len(self.monitor.websocket_clients),
                "system_metrics": self.monitor.metrics_cache.get("system", {})
            }

        @self.app.get("/api/orchestrators", dependencies=protected)
        async def get_orchestrators():
            """Get all active orchestrators."""
            return {
                "orchestrators": self.monitor.get_all_orchestrators_status(),
                "count": len(self.monitor.active_orchestrators)
            }

        @self.app.get("/api/orchestrators/{orchestrator_id}", dependencies=protected)
        async def get_orchestrator(orchestrator_id: str):
            """Get specific orchestrator status."""
            # Named so it does not shadow the `status` module used by the
            # other handlers in this method.
            orchestrator_status = self.monitor.get_orchestrator_status(orchestrator_id)
            if not orchestrator_status:
                raise HTTPException(status_code=404, detail="Orchestrator not found")
            return orchestrator_status

        @self.app.get("/api/orchestrators/{orchestrator_id}/tasks", dependencies=protected)
        async def get_orchestrator_tasks(orchestrator_id: str):
            """Get task queue status for an orchestrator."""
            orchestrator = _require_orchestrator(orchestrator_id)
            return {
                "orchestrator_id": orchestrator_id,
                "tasks": orchestrator.get_task_status()
            }

        @self.app.post("/api/orchestrators/{orchestrator_id}/pause", dependencies=protected)
        async def pause_orchestrator(orchestrator_id: str):
            """Pause an orchestrator."""
            orchestrator = _require_orchestrator(orchestrator_id)
            # NOTE(review): this sets the same flag the status report shows as
            # "stopping"; confirm the orchestrator treats it as a pause.
            orchestrator.stop_requested = True
            return {"status": "paused", "orchestrator_id": orchestrator_id}

        @self.app.post("/api/orchestrators/{orchestrator_id}/resume", dependencies=protected)
        async def resume_orchestrator(orchestrator_id: str):
            """Resume an orchestrator."""
            orchestrator = _require_orchestrator(orchestrator_id)
            orchestrator.stop_requested = False
            return {"status": "resumed", "orchestrator_id": orchestrator_id}

        @self.app.get("/api/orchestrators/{orchestrator_id}/prompt", dependencies=protected)
        async def get_orchestrator_prompt(orchestrator_id: str):
            """Get the current prompt for an orchestrator."""
            orchestrator = _require_orchestrator(orchestrator_id)
            prompt_file = orchestrator.prompt_file
            if not prompt_file.exists():
                raise HTTPException(status_code=404, detail="Prompt file not found")
            content = prompt_file.read_text()
            return {
                "orchestrator_id": orchestrator_id,
                "prompt_file": str(prompt_file),
                "content": content,
                "last_modified": prompt_file.stat().st_mtime
            }

        @self.app.post("/api/orchestrators/{orchestrator_id}/prompt", dependencies=protected)
        async def update_orchestrator_prompt(orchestrator_id: str, request: PromptUpdateRequest):
            """Update the prompt for an orchestrator."""
            orchestrator = _require_orchestrator(orchestrator_id)
            prompt_file = orchestrator.prompt_file
            try:
                # Create backup before updating
                backup_file = prompt_file.with_suffix(f".{int(time.time())}.backup")
                if prompt_file.exists():
                    backup_file.write_text(prompt_file.read_text())

                # Write the new content
                prompt_file.write_text(request.content)

                # Notify the orchestrator of the update
                if hasattr(orchestrator, '_reload_prompt'):
                    orchestrator._reload_prompt()

                # Broadcast update to WebSocket clients
                await self.monitor._broadcast_to_clients({
                    "type": "prompt_updated",
                    "data": {
                        "orchestrator_id": orchestrator_id,
                        "timestamp": datetime.now().isoformat()
                    }
                })
                return {
                    "status": "success",
                    "orchestrator_id": orchestrator_id,
                    "backup_file": str(backup_file),
                    "timestamp": datetime.now().isoformat()
                }
            except Exception as e:
                logger.error(f"Error updating prompt: {e}")
                raise HTTPException(status_code=500, detail=f"Failed to update prompt: {str(e)}") from e

        @self.app.get("/api/metrics", dependencies=protected)
        async def get_metrics():
            """Get system metrics."""
            return self.monitor.metrics_cache

        @self.app.get("/api/history", dependencies=protected)
        async def get_history(limit: int = 50):
            """Get execution history from database.
            Args:
            limit: Maximum number of runs to return (default 50)
            """
            try:
                # Get recent runs from database
                return self.monitor.database.get_recent_runs(limit=limit)
            except Exception as db_error:
                logger.error(f"Error fetching history from database: {db_error}")

            # Fallback to file-based history if database fails
            metrics_dir = Path(".agent") / "metrics"
            history = []
            if metrics_dir.exists():
                for metrics_file in sorted(metrics_dir.glob("metrics_*.json")):
                    try:
                        data = json.loads(metrics_file.read_text())
                        data["filename"] = metrics_file.name
                        history.append(data)
                    except Exception as file_error:
                        logger.error(f"Error reading metrics file {metrics_file}: {file_error}")
            # Honour the caller's limit here too; a non-positive limit yields
            # no entries (history[-0:] would otherwise return everything).
            recent = history[-limit:] if limit > 0 else []
            return {"history": recent}

        @self.app.get("/api/history/{run_id}", dependencies=protected)
        async def get_run_details(run_id: int):
            """Get detailed information about a specific run.
            Args:
            run_id: ID of the run to retrieve
            """
            run_details = self.monitor.database.get_run_details(run_id)
            if not run_details:
                raise HTTPException(status_code=404, detail="Run not found")
            return run_details

        @self.app.get("/api/statistics", dependencies=protected)
        async def get_statistics():
            """Get database statistics."""
            return self.monitor.database.get_statistics()

        @self.app.post("/api/database/cleanup", dependencies=protected)
        async def cleanup_database(days: int = 30):
            """Clean up old records from the database.
            Args:
            days: Number of days of history to keep (default 30)
            """
            try:
                self.monitor.database.cleanup_old_records(days=days)
                return {"status": "success", "message": f"Cleaned up records older than {days} days"}
            except Exception as e:
                logger.error(f"Error cleaning up database: {e}")
                raise HTTPException(status_code=500, detail=str(e)) from e

        # Admin endpoints for user management
        # NOTE(review): username/password are plain `str` parameters here, which
        # FastAPI reads from the query string; consider a request body model so
        # credentials do not end up in URLs and access logs.
        @self.app.post("/api/admin/users", dependencies=admin_only)
        async def add_user(username: str, password: str, is_admin: bool = False):
            """Add a new user (admin only)."""
            if auth_manager.add_user(username, password, is_admin):
                return {"status": "success", "message": f"User {username} created"}
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="User already exists"
            )

        @self.app.delete("/api/admin/users/{username}", dependencies=admin_only)
        async def remove_user(username: str):
            """Remove a user (admin only)."""
            if auth_manager.remove_user(username):
                return {"status": "success", "message": f"User {username} removed"}
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Cannot remove user"
            )

        # NOTE(review): old_password/new_password are also query parameters;
        # see the note on the admin endpoints above.
        @self.app.post("/api/auth/change-password", dependencies=protected)
        async def change_password(
            old_password: str,
            new_password: str,
            current_user: Dict[str, Any] = Depends(get_current_user) if self.enable_auth else None
        ):
            """Change the current user's password."""
            if not self.enable_auth:
                raise HTTPException(status_code=404, detail="Authentication not enabled")

            # Verify old password
            user = auth_manager.authenticate_user(current_user["username"], old_password)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Incorrect old password"
                )

            # Update password
            if auth_manager.update_password(current_user["username"], new_password):
                return {"status": "success", "message": "Password updated"}
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to update password"
            )

        @self.app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
            """WebSocket endpoint for real-time updates."""
            # Verify token if auth is enabled
            if self.enable_auth:
                if not token:
                    # Auth is enabled but no token provided
                    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                    return
                try:
                    auth_manager.verify_token(token)
                except HTTPException:
                    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                    return

            await websocket.accept()
            self.monitor.websocket_clients.append(websocket)
            try:
                # Send initial state
                await websocket.send_json({
                    "type": "initial_state",
                    "data": {
                        "orchestrators": self.monitor.get_all_orchestrators_status(),
                        "system_metrics": self.monitor.metrics_cache.get("system", {})
                    }
                })
                while True:
                    # Keep connection alive and handle incoming messages
                    data = await websocket.receive_text()
                    # Handle ping/pong or other commands if needed
                    if data == "ping":
                        await websocket.send_text("pong")
            except WebSocketDisconnect:
                logger.info("WebSocket client disconnected")
            finally:
                # Unregister on every exit path, and tolerate the client having
                # already been removed from the list elsewhere.
                if websocket in self.monitor.websocket_clients:
                    self.monitor.websocket_clients.remove(websocket)

    def run(self):
        """Run the web server."""
        logger.info(f"Starting web monitor on {self.host}:{self.port}")
        uvicorn.run(self.app, host=self.host, port=self.port)

    async def arun(self):
        """Run the web server asynchronously."""
        logger.info(f"Starting web monitor on {self.host}:{self.port}")
        config = uvicorn.Config(app=self.app, host=self.host, port=self.port)
        server = uvicorn.Server(config)
        await server.serve()

    def register_orchestrator(self, orchestrator_id: str, orchestrator: RalphOrchestrator):
        """Register an orchestrator with the monitor."""
        self.monitor.register_orchestrator(orchestrator_id, orchestrator)

    def unregister_orchestrator(self, orchestrator_id: str):
        """Unregister an orchestrator from the monitor."""
        self.monitor.unregister_orchestrator(orchestrator_id)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,320 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Ralph Orchestrator - Login</title>
<style>
    /* ---- Theme palette ---- */
    :root {
        --primary-color: #4a5568;
        --secondary-color: #718096;
        --success-color: #48bb78;
        --warning-color: #ed8936;
        --danger-color: #f56565;
        --background: #f7fafc;
        --surface: #ffffff;
        --text-primary: #2d3748;
        --text-secondary: #718096;
        --border-color: #e2e8f0;
    }

    /* ---- Reset ---- */
    * {
        margin: 0;
        padding: 0;
        box-sizing: border-box;
    }

    /* ---- Page: gradient backdrop with the card centred in the viewport ---- */
    body {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
        background: linear-gradient(135deg, var(--primary-color) 0%, var(--secondary-color) 100%);
        color: var(--text-primary);
        min-height: 100vh;
        display: flex;
        align-items: center;
        justify-content: center;
        padding: 20px;
    }

    /* ---- Login card ---- */
    .login-container {
        background: var(--surface);
        border-radius: 12px;
        box-shadow: 0 10px 40px rgba(0, 0, 0, 0.15);
        width: 100%;
        max-width: 400px;
        padding: 40px;
    }

    .login-header {
        text-align: center;
        margin-bottom: 30px;
    }

    .login-header h1 {
        color: var(--primary-color);
        font-size: 28px;
        margin-bottom: 10px;
    }

    .login-header p {
        color: var(--text-secondary);
        font-size: 14px;
    }

    /* ---- Form fields ---- */
    .form-group {
        margin-bottom: 20px;
    }

    label {
        display: block;
        margin-bottom: 8px;
        color: var(--text-primary);
        font-weight: 500;
        font-size: 14px;
    }

    input[type="text"],
    input[type="password"] {
        width: 100%;
        padding: 12px;
        border: 1px solid var(--border-color);
        border-radius: 6px;
        font-size: 14px;
        transition: border-color 0.3s, box-shadow 0.3s;
        background: var(--surface);
        color: var(--text-primary);
    }

    input[type="text"]:focus,
    input[type="password"]:focus {
        outline: none;
        border-color: var(--primary-color);
        box-shadow: 0 0 0 3px rgba(74, 85, 104, 0.1);
    }

    /* ---- Submit button and its states ---- */
    .btn-login {
        width: 100%;
        padding: 12px;
        background: var(--primary-color);
        color: white;
        border: none;
        border-radius: 6px;
        font-size: 16px;
        font-weight: 600;
        cursor: pointer;
        transition: background-color 0.3s, transform 0.2s;
    }

    .btn-login:hover {
        background: #2d3748;
        transform: translateY(-1px);
    }

    .btn-login:active {
        transform: translateY(0);
    }

    .btn-login:disabled {
        background: var(--secondary-color);
        cursor: not-allowed;
        transform: none;
    }

    /* ---- Status banners: hidden until the `show` class is added ---- */
    .error-message {
        background: rgba(245, 101, 101, 0.1);
        border: 1px solid var(--danger-color);
        color: var(--danger-color);
        padding: 10px;
        border-radius: 6px;
        margin-bottom: 20px;
        font-size: 14px;
        display: none;
    }

    .error-message.show {
        display: block;
    }

    .success-message {
        background: rgba(72, 187, 120, 0.1);
        border: 1px solid var(--success-color);
        color: var(--success-color);
        padding: 10px;
        border-radius: 6px;
        margin-bottom: 20px;
        font-size: 14px;
        display: none;
    }

    .success-message.show {
        display: block;
    }

    .form-footer {
        margin-top: 20px;
        text-align: center;
        color: var(--text-secondary);
        font-size: 12px;
    }

    /* ---- Spinner placed inside the button while a login is in flight ---- */
    .loading-spinner {
        display: inline-block;
        width: 16px;
        height: 16px;
        border: 2px solid rgba(255, 255, 255, 0.3);
        border-radius: 50%;
        border-top-color: white;
        animation: spin 0.6s linear infinite;
        margin-left: 8px;
        vertical-align: middle;
    }

    @keyframes spin {
        to { transform: rotate(360deg); }
    }

    /* ---- Credentials hint box ---- */
    .auth-info {
        background: rgba(74, 85, 104, 0.05);
        border-left: 3px solid var(--primary-color);
        padding: 12px;
        margin-top: 20px;
        font-size: 13px;
        color: var(--text-secondary);
    }

    .auth-info strong {
        color: var(--text-primary);
    }

    /* ---- Narrow screens ---- */
    @media (max-width: 480px) {
        .login-container {
            padding: 30px 20px;
        }

        .login-header h1 {
            font-size: 24px;
        }
    }
</style>
</head>
<body>
<div class="login-container">
    <div class="login-header">
        <h1>🤖 Ralph Orchestrator</h1>
        <p>Monitoring Dashboard Login</p>
    </div>

    <!-- Status banners: the page script fills these in and toggles the `show`
         class. The ARIA roles let assistive technology announce the text. -->
    <div id="errorMessage" class="error-message" role="alert"></div>
    <div id="successMessage" class="success-message" role="status"></div>

    <!-- Submission is handled by the page script (no native form post).
         The autocomplete hints let browsers and password managers fill the fields. -->
    <form id="loginForm">
        <div class="form-group">
            <label for="username">Username</label>
            <input type="text" id="username" name="username" autocomplete="username" required autofocus>
        </div>
        <div class="form-group">
            <label for="password">Password</label>
            <input type="password" id="password" name="password" autocomplete="current-password" required>
        </div>
        <button type="submit" class="btn-login" id="loginButton">
            Login
        </button>
    </form>

    <!-- NOTE(review): this box shows the default admin password to every
         unauthenticated visitor; consider removing it from deployed builds. -->
    <div class="auth-info">
        <strong>Default Credentials:</strong><br>
        Username: admin<br>
        Password: Set via RALPH_WEB_PASSWORD env variable<br>
        (Default: ralph-admin-2024)
    </div>
    <div class="form-footer">
        Secure access to Ralph Orchestrator monitoring
    </div>
</div>
<script>
// Check if already authenticated
const token = localStorage.getItem('ralph_auth_token');
if (token) {
// Verify token is still valid
fetch('/api/auth/verify', {
method: 'POST',
headers: {
'Authorization': `Bearer ${token}`
}
}).then(response => {
if (response.ok) {
// Token is valid, redirect to dashboard
window.location.href = '/';
} else {
// Token is invalid, remove it
localStorage.removeItem('ralph_auth_token');
}
});
}
// Handle login form submission
document.getElementById('loginForm').addEventListener('submit', async (e) => {
e.preventDefault();
const username = document.getElementById('username').value;
const password = document.getElementById('password').value;
const loginButton = document.getElementById('loginButton');
const errorMessage = document.getElementById('errorMessage');
const successMessage = document.getElementById('successMessage');
// Reset messages
errorMessage.classList.remove('show');
successMessage.classList.remove('show');
// Disable button and show loading
loginButton.disabled = true;
loginButton.innerHTML = 'Logging in<span class="loading-spinner"></span>';
try {
const response = await fetch('/api/auth/login', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ username, password })
});
const data = await response.json();
if (response.ok) {
// Store token
localStorage.setItem('ralph_auth_token', data.access_token);
// Show success message
successMessage.textContent = 'Login successful! Redirecting...';
successMessage.classList.add('show');
// Redirect to dashboard
setTimeout(() => {
window.location.href = '/';
}, 1000);
} else {
// Show error
errorMessage.textContent = data.detail || 'Login failed. Please check your credentials.';
errorMessage.classList.add('show');
// Re-enable button
loginButton.disabled = false;
loginButton.textContent = 'Login';
}
} catch (error) {
// Show error
errorMessage.textContent = 'Connection error. Please try again.';
errorMessage.classList.add('show');
// Re-enable button
loginButton.disabled = false;
loginButton.textContent = 'Login';
}
});
</script>
</body>
</html>