Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
# ABOUTME: Tool adapter interfaces and implementations
|
||||
# ABOUTME: Provides unified interface for Claude, Q Chat, Gemini, ACP, and other tools
|
||||
|
||||
"""Tool adapters for Ralph Orchestrator."""
|
||||
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from .claude import ClaudeAdapter
|
||||
from .qchat import QChatAdapter
|
||||
from .kiro import KiroAdapter
|
||||
from .gemini import GeminiAdapter
|
||||
from .acp import ACPAdapter
|
||||
from .acp_handlers import ACPHandlers, PermissionRequest, PermissionResult, Terminal
|
||||
|
||||
__all__ = [
|
||||
"ToolAdapter",
|
||||
"ToolResponse",
|
||||
"ClaudeAdapter",
|
||||
"QChatAdapter",
|
||||
"KiroAdapter",
|
||||
"GeminiAdapter",
|
||||
"ACPAdapter",
|
||||
"ACPHandlers",
|
||||
"PermissionRequest",
|
||||
"PermissionResult",
|
||||
"Terminal",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,984 @@
|
||||
# ABOUTME: ACP Adapter for Agent Client Protocol integration
|
||||
# ABOUTME: Provides subprocess-based communication with ACP-compliant agents like Gemini CLI
|
||||
|
||||
"""ACP (Agent Client Protocol) adapter for Ralph Orchestrator.
|
||||
|
||||
This adapter enables Ralph to use any ACP-compliant agent (like Gemini CLI)
|
||||
as a backend for task execution. It manages the subprocess lifecycle,
|
||||
handles the initialization handshake, and routes session messages.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from .acp_client import ACPClient, ACPClientError
|
||||
from .acp_models import ACPAdapterConfig, ACPSession, UpdatePayload
|
||||
from .acp_handlers import ACPHandlers
|
||||
from ..output.console import RalphConsole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ACP Protocol version this adapter supports (integer per spec)
|
||||
ACP_PROTOCOL_VERSION = 1
|
||||
|
||||
|
||||
class ACPAdapter(ToolAdapter):
|
||||
"""Adapter for ACP-compliant agents like Gemini CLI.
|
||||
|
||||
Manages subprocess lifecycle, initialization handshake, and session
|
||||
message routing for Agent Client Protocol communication.
|
||||
|
||||
Attributes:
|
||||
agent_command: Command to spawn the agent (default: gemini).
|
||||
agent_args: Additional arguments for agent command.
|
||||
timeout: Request timeout in seconds.
|
||||
permission_mode: How to handle permission requests.
|
||||
"""
|
||||
|
||||
_TOOL_FIELD_ALIASES = {
|
||||
"toolName": ("toolName", "tool_name", "name", "tool"),
|
||||
"toolCallId": ("toolCallId", "tool_call_id", "id"),
|
||||
"arguments": ("arguments", "args", "parameters", "params", "input"),
|
||||
"status": ("status",),
|
||||
"result": ("result",),
|
||||
"error": ("error",),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_command: str = "gemini",
|
||||
agent_args: Optional[list[str]] = None,
|
||||
timeout: int = 300,
|
||||
permission_mode: str = "auto_approve",
|
||||
permission_allowlist: Optional[list[str]] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Initialize ACPAdapter.
|
||||
|
||||
Args:
|
||||
agent_command: Command to spawn the agent (default: gemini).
|
||||
agent_args: Additional command-line arguments.
|
||||
timeout: Request timeout in seconds (default: 300).
|
||||
permission_mode: Permission handling mode (default: auto_approve).
|
||||
permission_allowlist: Patterns for allowlist mode.
|
||||
verbose: Enable verbose streaming output (default: False).
|
||||
"""
|
||||
self.agent_command = agent_command
|
||||
self.agent_args = agent_args or []
|
||||
self.timeout = timeout
|
||||
self.permission_mode = permission_mode
|
||||
self.permission_allowlist = permission_allowlist or []
|
||||
self.verbose = verbose
|
||||
self._current_verbose = verbose # Per-request verbose flag
|
||||
|
||||
# Console for verbose output
|
||||
self._console = RalphConsole()
|
||||
|
||||
# State
|
||||
self._client: Optional[ACPClient] = None
|
||||
self._session_id: Optional[str] = None
|
||||
self._initialized = False
|
||||
self._session: Optional[ACPSession] = None
|
||||
|
||||
# Create permission handlers
|
||||
self._handlers = ACPHandlers(
|
||||
permission_mode=permission_mode,
|
||||
permission_allowlist=self.permission_allowlist,
|
||||
on_permission_log=self._log_permission,
|
||||
)
|
||||
|
||||
# Thread synchronization
|
||||
self._lock = threading.Lock()
|
||||
self._shutdown_requested = False
|
||||
|
||||
# Signal handlers
|
||||
self._original_sigint = None
|
||||
self._original_sigterm = None
|
||||
|
||||
# Call parent init - this will call check_availability()
|
||||
super().__init__("acp")
|
||||
|
||||
# Register signal handlers
|
||||
self._register_signal_handlers()
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ACPAdapterConfig) -> "ACPAdapter":
|
||||
"""Create ACPAdapter from configuration object.
|
||||
|
||||
Args:
|
||||
config: ACPAdapterConfig with adapter settings.
|
||||
|
||||
Returns:
|
||||
Configured ACPAdapter instance.
|
||||
"""
|
||||
return cls(
|
||||
agent_command=config.agent_command,
|
||||
agent_args=config.agent_args,
|
||||
timeout=config.timeout,
|
||||
permission_mode=config.permission_mode,
|
||||
permission_allowlist=config.permission_allowlist,
|
||||
)
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if the agent command is available.
|
||||
|
||||
Returns:
|
||||
True if agent command exists in PATH, False otherwise.
|
||||
"""
|
||||
return shutil.which(self.agent_command) is not None
|
||||
|
||||
def _register_signal_handlers(self) -> None:
|
||||
"""Register signal handlers for graceful shutdown."""
|
||||
try:
|
||||
self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
|
||||
self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
except ValueError as e:
|
||||
logger.warning("Cannot register signal handlers (not in main thread): %s. Graceful shutdown via Ctrl+C will not work.", e)
|
||||
|
||||
def _restore_signal_handlers(self) -> None:
|
||||
"""Restore original signal handlers."""
|
||||
try:
|
||||
if self._original_sigint is not None:
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if self._original_sigterm is not None:
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Failed to restore signal handlers: %s", e)
|
||||
|
||||
def _signal_handler(self, signum: int, frame) -> None:
|
||||
"""Handle shutdown signals.
|
||||
|
||||
Terminates running subprocess synchronously (signal-safe),
|
||||
then propagates to original handler (orchestrator).
|
||||
|
||||
Args:
|
||||
signum: Signal number.
|
||||
frame: Current stack frame.
|
||||
"""
|
||||
with self._lock:
|
||||
self._shutdown_requested = True
|
||||
|
||||
# Kill subprocess synchronously (signal-safe)
|
||||
self.kill_subprocess_sync()
|
||||
|
||||
# Propagate signal to original handler (orchestrator's handler)
|
||||
original = self._original_sigint if signum == signal.SIGINT else self._original_sigterm
|
||||
if original and callable(original):
|
||||
original(signum, frame)
|
||||
|
||||
def kill_subprocess_sync(self) -> None:
|
||||
"""Synchronously kill the agent subprocess (signal-safe).
|
||||
|
||||
This method is safe to call from signal handlers.
|
||||
Uses non-blocking approach with immediate force kill after 2 seconds.
|
||||
"""
|
||||
if self._client and self._client._process:
|
||||
try:
|
||||
process = self._client._process
|
||||
if process.returncode is None:
|
||||
# Try graceful termination first
|
||||
process.terminate()
|
||||
|
||||
# Non-blocking poll with timeout
|
||||
import time
|
||||
start = time.time()
|
||||
timeout = 2.0
|
||||
|
||||
while time.time() - start < timeout:
|
||||
if process.poll() is not None:
|
||||
# Process terminated successfully
|
||||
return
|
||||
time.sleep(0.01) # Brief sleep to avoid busy-wait
|
||||
|
||||
# Timeout reached, force kill
|
||||
try:
|
||||
process.kill()
|
||||
# Brief wait to ensure kill completes
|
||||
time.sleep(0.1)
|
||||
process.poll()
|
||||
except Exception as e:
|
||||
logger.debug("Exception during subprocess kill: %s", e)
|
||||
except Exception as e:
|
||||
logger.debug("Exception during subprocess kill: %s", e)
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Initialize ACP connection with agent.
|
||||
|
||||
Performs the ACP initialization handshake:
|
||||
1. Start ACPClient subprocess
|
||||
2. Send initialize request with protocol version
|
||||
3. Receive and validate initialize response
|
||||
4. Send session/new request
|
||||
5. Store session_id
|
||||
|
||||
Raises:
|
||||
ACPClientError: If initialization fails.
|
||||
"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# Build effective args, auto-adding ACP flags for known agents
|
||||
effective_args = list(self.agent_args)
|
||||
|
||||
# Gemini CLI requires --experimental-acp flag to enter ACP mode
|
||||
# Also add --yolo to auto-approve internal tool executions
|
||||
# And --allowed-tools to enable native Gemini tools
|
||||
agent_basename = os.path.basename(self.agent_command)
|
||||
if agent_basename == "gemini":
|
||||
if "--experimental-acp" not in effective_args:
|
||||
logger.info("Auto-adding --experimental-acp flag for Gemini CLI")
|
||||
effective_args.append("--experimental-acp")
|
||||
if "--yolo" not in effective_args:
|
||||
logger.info("Auto-adding --yolo flag for Gemini CLI tool execution")
|
||||
effective_args.append("--yolo")
|
||||
# Enable native Gemini tools for ACP mode
|
||||
# Note: Excluding write_file and run_shell_command - they have bugs in ACP mode
|
||||
# Gemini should fall back to ACP's fs/write_text_file and terminal/create
|
||||
if "--allowed-tools" not in effective_args:
|
||||
logger.info("Auto-adding --allowed-tools for Gemini CLI native tools")
|
||||
effective_args.extend([
|
||||
"--allowed-tools",
|
||||
"list_directory",
|
||||
"read_many_files",
|
||||
"read_file",
|
||||
"web_fetch",
|
||||
"google_web_search",
|
||||
])
|
||||
|
||||
# Create and start client
|
||||
self._client = ACPClient(
|
||||
command=self.agent_command,
|
||||
args=effective_args,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
await self._client.start()
|
||||
|
||||
# Register notification handler for session updates
|
||||
self._client.on_notification(self._handle_notification)
|
||||
|
||||
# Register request handler for permission requests
|
||||
self._client.on_request(self._handle_request)
|
||||
|
||||
try:
|
||||
# Send initialize request (per ACP spec)
|
||||
init_future = self._client.send_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": ACP_PROTOCOL_VERSION,
|
||||
"clientCapabilities": {
|
||||
"fs": {
|
||||
"readTextFile": True,
|
||||
"writeTextFile": True,
|
||||
},
|
||||
"terminal": True,
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ralph-orchestrator",
|
||||
"title": "Ralph Orchestrator",
|
||||
"version": "1.2.3",
|
||||
},
|
||||
},
|
||||
)
|
||||
init_response = await asyncio.wait_for(init_future, timeout=self.timeout)
|
||||
|
||||
# Validate response
|
||||
if "protocolVersion" not in init_response:
|
||||
raise ACPClientError("Invalid initialize response: missing protocolVersion")
|
||||
|
||||
# Create new session (cwd and mcpServers are required per ACP spec)
|
||||
session_future = self._client.send_request(
|
||||
"session/new",
|
||||
{
|
||||
"cwd": os.getcwd(),
|
||||
"mcpServers": [], # No MCP servers by default
|
||||
},
|
||||
)
|
||||
session_response = await asyncio.wait_for(session_future, timeout=self.timeout)
|
||||
|
||||
# Store session ID
|
||||
self._session_id = session_response.get("sessionId")
|
||||
if not self._session_id:
|
||||
raise ACPClientError("Invalid session/new response: missing sessionId")
|
||||
|
||||
# Create session state tracker
|
||||
self._session = ACPSession(session_id=self._session_id)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await self._client.stop()
|
||||
raise ACPClientError("Initialization timed out")
|
||||
except Exception:
|
||||
await self._client.stop()
|
||||
raise
|
||||
|
||||
def _handle_notification(self, method: str, params: dict) -> None:
|
||||
"""Handle notifications from agent.
|
||||
|
||||
Args:
|
||||
method: Notification method name.
|
||||
params: Notification parameters.
|
||||
"""
|
||||
if method == "session/update" and self._session:
|
||||
# Handle both notification formats:
|
||||
# Format 1 (flat): {"kind": "agent_message_chunk", "content": "..."}
|
||||
# Format 2 (nested): {"update": {"sessionUpdate": "agent_message_chunk", "content": {...}}}
|
||||
if "update" in params:
|
||||
# Nested format (Gemini)
|
||||
update = params["update"]
|
||||
kind = update.get("sessionUpdate", "")
|
||||
content_obj = update.get("content")
|
||||
content = None
|
||||
flat_params = {"kind": kind, "content": content}
|
||||
allow_name_id = kind in ("tool_call", "tool_call_update")
|
||||
# Extract text content if it's an object
|
||||
if isinstance(content_obj, dict):
|
||||
if "text" in content_obj:
|
||||
content = content_obj.get("text", "")
|
||||
flat_params["content"] = content
|
||||
elif isinstance(content_obj, list):
|
||||
for entry in content_obj:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
self._merge_tool_fields(flat_params, entry, allow_name_id=allow_name_id)
|
||||
nested_tool_call = entry.get("toolCall") or entry.get("tool_call")
|
||||
if isinstance(nested_tool_call, dict):
|
||||
self._merge_tool_fields(
|
||||
flat_params,
|
||||
nested_tool_call,
|
||||
allow_name_id=allow_name_id,
|
||||
)
|
||||
nested_tool = entry.get("tool")
|
||||
if isinstance(nested_tool, dict):
|
||||
self._merge_tool_fields(
|
||||
flat_params,
|
||||
nested_tool,
|
||||
allow_name_id=allow_name_id,
|
||||
)
|
||||
else:
|
||||
content = str(content_obj) if content_obj else ""
|
||||
flat_params["content"] = content
|
||||
self._merge_tool_fields(flat_params, update, allow_name_id=allow_name_id)
|
||||
if isinstance(content_obj, dict):
|
||||
self._merge_tool_fields(flat_params, content_obj, allow_name_id=allow_name_id)
|
||||
nested_tool_call = content_obj.get("toolCall") or content_obj.get("tool_call")
|
||||
if isinstance(nested_tool_call, dict):
|
||||
self._merge_tool_fields(
|
||||
flat_params,
|
||||
nested_tool_call,
|
||||
allow_name_id=allow_name_id,
|
||||
)
|
||||
nested_tool = content_obj.get("tool")
|
||||
if isinstance(nested_tool, dict):
|
||||
self._merge_tool_fields(
|
||||
flat_params,
|
||||
nested_tool,
|
||||
allow_name_id=allow_name_id,
|
||||
)
|
||||
nested_tool_call = update.get("toolCall") or update.get("tool_call")
|
||||
if isinstance(nested_tool_call, dict):
|
||||
self._merge_tool_fields(
|
||||
flat_params,
|
||||
nested_tool_call,
|
||||
allow_name_id=allow_name_id,
|
||||
)
|
||||
nested_tool = update.get("tool")
|
||||
if isinstance(nested_tool, dict):
|
||||
self._merge_tool_fields(
|
||||
flat_params,
|
||||
nested_tool,
|
||||
allow_name_id=allow_name_id,
|
||||
)
|
||||
payload = UpdatePayload.from_dict(flat_params)
|
||||
payload._raw = update
|
||||
payload._raw_flat = flat_params
|
||||
else:
|
||||
# Flat format
|
||||
payload = UpdatePayload.from_dict(params)
|
||||
payload._raw = params
|
||||
|
||||
# Stream to console if verbose; always show tool calls
|
||||
if self._current_verbose:
|
||||
self._stream_update(payload, show_details=True)
|
||||
elif payload.kind == "tool_call":
|
||||
self._stream_update(payload, show_details=False)
|
||||
|
||||
self._session.process_update(payload)
|
||||
|
||||
def _merge_tool_fields(
|
||||
self,
|
||||
target: dict,
|
||||
source: dict,
|
||||
*,
|
||||
allow_name_id: bool = False,
|
||||
) -> None:
|
||||
"""Merge tool call fields from source into target with alias support."""
|
||||
for canonical, aliases in self._TOOL_FIELD_ALIASES.items():
|
||||
if not allow_name_id and canonical in ("toolName", "toolCallId"):
|
||||
aliases = tuple(
|
||||
alias for alias in aliases if alias not in ("name", "id")
|
||||
)
|
||||
if canonical in target and target[canonical] not in (None, ""):
|
||||
continue
|
||||
for key in aliases:
|
||||
if key in source:
|
||||
value = source[key]
|
||||
if key == "tool" and canonical == "toolName":
|
||||
if isinstance(value, dict):
|
||||
value = value.get("name") or value.get("toolName") or value.get("tool_name")
|
||||
elif not isinstance(value, str):
|
||||
value = None
|
||||
if value is None or value == "":
|
||||
continue
|
||||
target[canonical] = value
|
||||
break
|
||||
|
||||
def _format_agent_label(self) -> str:
|
||||
"""Return the ACP agent command with arguments for display."""
|
||||
if not self.agent_args:
|
||||
return self.agent_command
|
||||
return " ".join([self.agent_command, *self.agent_args])
|
||||
|
||||
def _format_payload_value(self, value: object, limit: int = 200) -> str:
|
||||
"""Format payload values for console output."""
|
||||
if value is None:
|
||||
return ""
|
||||
value_str = str(value)
|
||||
if len(value_str) > limit:
|
||||
return value_str[: limit - 3] + "..."
|
||||
return value_str
|
||||
|
||||
def _format_payload_error(self, error: object) -> str:
|
||||
"""Extract a readable error string from ACP error payloads."""
|
||||
if error is None:
|
||||
return ""
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message") or error.get("error") or error.get("detail")
|
||||
code = error.get("code")
|
||||
data = error.get("data")
|
||||
parts = []
|
||||
if message:
|
||||
parts.append(message)
|
||||
if code is not None:
|
||||
parts.append(f"code={code}")
|
||||
if data and not message:
|
||||
parts.append(str(data))
|
||||
if parts:
|
||||
return self._format_payload_value(" ".join(parts), limit=200)
|
||||
return self._format_payload_value(error, limit=200)
|
||||
|
||||
def _get_raw_payload(self, payload: UpdatePayload) -> dict | None:
|
||||
raw = getattr(payload, "_raw", None)
|
||||
return raw if isinstance(raw, dict) else None
|
||||
|
||||
def _extract_tool_field(self, raw: dict | None, key: str) -> object:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
value = raw.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
for nested_key in ("toolCall", "tool_call", "tool"):
|
||||
nested = raw.get(nested_key)
|
||||
if isinstance(nested, dict) and key in nested:
|
||||
nested_value = nested.get(key)
|
||||
if nested_value not in (None, ""):
|
||||
return nested_value
|
||||
return None
|
||||
|
||||
def _extract_tool_name_from_meta(self, raw: dict | None) -> str | None:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
meta = raw.get("_meta")
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
for key in ("codex", "claudeCode", "agent", "acp"):
|
||||
entry = meta.get(key)
|
||||
if isinstance(entry, dict):
|
||||
for name_key in ("toolName", "tool_name", "name", "tool"):
|
||||
value = entry.get(name_key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
for name_key in ("toolName", "tool_name", "name", "tool"):
|
||||
value = meta.get(name_key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
def _extract_tool_response(self, raw: dict | None) -> object:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
meta = raw.get("_meta")
|
||||
if not isinstance(meta, dict):
|
||||
return None
|
||||
for key in ("codex", "claudeCode", "agent", "acp"):
|
||||
entry = meta.get(key)
|
||||
if isinstance(entry, dict) and "toolResponse" in entry:
|
||||
return entry.get("toolResponse")
|
||||
if "toolResponse" in meta:
|
||||
return meta.get("toolResponse")
|
||||
return None
|
||||
|
||||
def _stream_update(self, payload: UpdatePayload, show_details: bool = True) -> None:
|
||||
"""Stream session update to console.
|
||||
|
||||
Args:
|
||||
payload: The update payload to stream.
|
||||
show_details: Include detailed info (arguments, results, progress).
|
||||
"""
|
||||
kind = payload.kind
|
||||
|
||||
if kind == "agent_message_chunk":
|
||||
# Stream agent output text
|
||||
if payload.content:
|
||||
self._console.print_message(payload.content)
|
||||
|
||||
elif kind == "agent_thought_chunk":
|
||||
# Stream agent internal reasoning (dimmed)
|
||||
if payload.content:
|
||||
if self._console.console:
|
||||
self._console.console.print(
|
||||
f"[dim italic]{payload.content}[/dim italic]",
|
||||
end="",
|
||||
)
|
||||
else:
|
||||
print(payload.content, end="")
|
||||
|
||||
elif kind == "tool_call":
|
||||
# Show tool call start
|
||||
tool_name = payload.tool_name
|
||||
raw_update = self._get_raw_payload(payload)
|
||||
meta_tool_name = self._extract_tool_name_from_meta(raw_update)
|
||||
title = self._extract_tool_field(raw_update, "title")
|
||||
kind = self._extract_tool_field(raw_update, "kind")
|
||||
raw_input = self._extract_tool_field(raw_update, "rawInput")
|
||||
if raw_input is None:
|
||||
raw_input = self._extract_tool_field(raw_update, "input")
|
||||
tool_name = tool_name or meta_tool_name or title or kind or "unknown"
|
||||
tool_id = payload.tool_call_id or "unknown"
|
||||
self._console.print_separator()
|
||||
self._console.print_status(f"TOOL CALL: {tool_name}", style="cyan bold")
|
||||
self._console.print_info(f"ID: {tool_id[:12]}...")
|
||||
self._console.print_info(f"Agent: {self._format_agent_label()}")
|
||||
if show_details:
|
||||
if title and title != tool_name:
|
||||
self._console.print_info(f"Title: {title}")
|
||||
if kind:
|
||||
self._console.print_info(f"Kind: {kind}")
|
||||
if tool_name == "unknown":
|
||||
raw_str = self._format_payload_value(raw_update, limit=300)
|
||||
if raw_str:
|
||||
self._console.print_info(f"Update: {raw_str}")
|
||||
if payload.arguments or raw_input:
|
||||
input_value = payload.arguments or raw_input
|
||||
if isinstance(input_value, dict):
|
||||
self._console.print_info("Arguments:")
|
||||
for key, value in input_value.items():
|
||||
value_str = str(value)
|
||||
if len(value_str) > 100:
|
||||
value_str = value_str[:97] + "..."
|
||||
self._console.print_info(f" - {key}: {value_str}")
|
||||
else:
|
||||
input_str = self._format_payload_value(input_value, limit=300)
|
||||
if input_str:
|
||||
self._console.print_info(f"Input: {input_str}")
|
||||
|
||||
elif kind == "tool_call_update":
|
||||
if not show_details:
|
||||
return
|
||||
# Show tool call status update
|
||||
tool_id = payload.tool_call_id or "unknown"
|
||||
status = payload.status or "unknown"
|
||||
tool_name = payload.tool_name
|
||||
tool_args = None
|
||||
tool_call = None
|
||||
if self._session and payload.tool_call_id:
|
||||
tool_call = self._session.get_tool_call(payload.tool_call_id)
|
||||
if tool_call:
|
||||
tool_name = tool_name or tool_call.tool_name
|
||||
tool_args = tool_call.arguments or None
|
||||
raw_update = self._get_raw_payload(payload)
|
||||
meta_tool_name = self._extract_tool_name_from_meta(raw_update)
|
||||
title = self._extract_tool_field(raw_update, "title")
|
||||
kind = self._extract_tool_field(raw_update, "kind")
|
||||
raw_input = self._extract_tool_field(raw_update, "rawInput")
|
||||
raw_output = self._extract_tool_field(raw_update, "rawOutput")
|
||||
tool_name = tool_name or meta_tool_name or title
|
||||
display_name = tool_name or kind or "unknown"
|
||||
if display_name == "unknown":
|
||||
status_label = f"Tool call {tool_id[:12]}..."
|
||||
else:
|
||||
status_label = f"Tool {display_name} ({tool_id[:12]}...)"
|
||||
|
||||
if status == "completed":
|
||||
self._console.print_success(
|
||||
f"{status_label} completed"
|
||||
)
|
||||
result_value = payload.result
|
||||
if result_value is None and tool_call:
|
||||
result_value = tool_call.result
|
||||
if result_value is None:
|
||||
result_value = raw_output
|
||||
if result_value is None:
|
||||
result_value = self._extract_tool_response(raw_update)
|
||||
result_str = self._format_payload_value(result_value)
|
||||
if result_str:
|
||||
self._console.print_info(f"Result: {result_str}")
|
||||
elif status == "failed":
|
||||
self._console.print_error(
|
||||
f"{status_label} failed"
|
||||
)
|
||||
if display_name == "unknown":
|
||||
raw_str = self._format_payload_value(raw_update, limit=300)
|
||||
if raw_str:
|
||||
self._console.print_info(f"Update: {raw_str}")
|
||||
error_str = self._format_payload_error(payload.error)
|
||||
if not error_str and payload.result is not None:
|
||||
error_str = self._format_payload_value(payload.result)
|
||||
if not error_str and raw_output is not None:
|
||||
error_str = self._format_payload_value(raw_output)
|
||||
if not error_str:
|
||||
error_str = self._format_payload_value(
|
||||
self._extract_tool_response(raw_update)
|
||||
)
|
||||
if error_str:
|
||||
self._console.print_error(f"Error: {error_str}")
|
||||
if tool_args or raw_input:
|
||||
if tool_args is None:
|
||||
tool_args = raw_input
|
||||
self._console.print_info("Arguments:")
|
||||
if isinstance(tool_args, dict):
|
||||
for key, value in tool_args.items():
|
||||
value_str = str(value)
|
||||
if len(value_str) > 100:
|
||||
value_str = value_str[:97] + "..."
|
||||
self._console.print_info(f" - {key}: {value_str}")
|
||||
else:
|
||||
arg_str = self._format_payload_value(tool_args, limit=300)
|
||||
if arg_str:
|
||||
self._console.print_info(f" - {arg_str}")
|
||||
elif status == "running":
|
||||
self._console.print_status(
|
||||
f"{status_label} running",
|
||||
style="yellow",
|
||||
)
|
||||
progress_value = payload.result or payload.content
|
||||
if progress_value is None:
|
||||
progress_value = raw_output
|
||||
progress_str = self._format_payload_value(progress_value, limit=200)
|
||||
if progress_str:
|
||||
self._console.print_info(f"Progress: {progress_str}")
|
||||
else:
|
||||
self._console.print_status(
|
||||
f"{status_label} {status}",
|
||||
style="yellow",
|
||||
)
|
||||
if title and title != display_name:
|
||||
self._console.print_info(f"Title: {title}")
|
||||
if kind:
|
||||
self._console.print_info(f"Kind: {kind}")
|
||||
|
||||
def _handle_request(self, method: str, params: dict) -> dict:
|
||||
"""Handle requests from agent.
|
||||
|
||||
Routes requests to appropriate handlers:
|
||||
- session/request_permission: Permission checks
|
||||
- fs/read_text_file: File read operations
|
||||
- fs/write_text_file: File write operations
|
||||
- terminal/*: Terminal operations
|
||||
|
||||
Args:
|
||||
method: Request method name.
|
||||
params: Request parameters.
|
||||
|
||||
Returns:
|
||||
Response result dict.
|
||||
"""
|
||||
logger.info("ACP REQUEST: method=%s", method)
|
||||
if method == "session/request_permission":
|
||||
# Permission handler already returns ACP-compliant format
|
||||
return self._handle_permission_request(params)
|
||||
|
||||
# File operations - return raw result (client wraps in JSON-RPC)
|
||||
if method == "fs/read_text_file":
|
||||
return self._handlers.handle_read_file(params)
|
||||
if method == "fs/write_text_file":
|
||||
return self._handlers.handle_write_file(params)
|
||||
|
||||
# Terminal operations - return raw result (client wraps in JSON-RPC)
|
||||
if method == "terminal/create":
|
||||
return self._handlers.handle_terminal_create(params)
|
||||
if method == "terminal/output":
|
||||
return self._handlers.handle_terminal_output(params)
|
||||
if method == "terminal/wait_for_exit":
|
||||
return self._handlers.handle_terminal_wait_for_exit(params)
|
||||
if method == "terminal/kill":
|
||||
return self._handlers.handle_terminal_kill(params)
|
||||
if method == "terminal/release":
|
||||
return self._handlers.handle_terminal_release(params)
|
||||
|
||||
# Unknown request - log and return error
|
||||
logger.warning("Unknown ACP request method: %s with params: %s", method, params)
|
||||
return {"error": {"code": -32601, "message": f"Method not found: {method}"}}
|
||||
|
||||
def _handle_permission_request(self, params: dict) -> dict:
|
||||
"""Handle permission request from agent.
|
||||
|
||||
Delegates to ACPHandlers which supports multiple modes:
|
||||
- auto_approve: Always approve
|
||||
- deny_all: Always deny
|
||||
- allowlist: Check against configured patterns
|
||||
- interactive: Prompt user (if terminal available)
|
||||
|
||||
Args:
|
||||
params: Permission request parameters.
|
||||
|
||||
Returns:
|
||||
Response with approved: True/False.
|
||||
"""
|
||||
return self._handlers.handle_request_permission(params)
|
||||
|
||||
def _log_permission(self, message: str) -> None:
|
||||
"""Log permission decision.
|
||||
|
||||
Args:
|
||||
message: Permission decision message.
|
||||
"""
|
||||
logger.info(message)
|
||||
|
||||
def get_permission_history(self) -> list:
|
||||
"""Get permission decision history.
|
||||
|
||||
Returns:
|
||||
List of (request, result) tuples.
|
||||
"""
|
||||
return self._handlers.get_history()
|
||||
|
||||
def get_permission_stats(self) -> dict:
|
||||
"""Get permission decision statistics.
|
||||
|
||||
Returns:
|
||||
Dict with approved_count and denied_count.
|
||||
"""
|
||||
return {
|
||||
"approved_count": self._handlers.get_approved_count(),
|
||||
"denied_count": self._handlers.get_denied_count(),
|
||||
}
|
||||
|
||||
async def _execute_prompt(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute a prompt through the ACP agent.
|
||||
|
||||
Sends session/prompt request with messages array and waits for response.
|
||||
Session updates (streaming output, thoughts, tool calls) are processed
|
||||
through _handle_notification during the request.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to execute.
|
||||
**kwargs: Additional arguments (verbose: bool).
|
||||
|
||||
Returns:
|
||||
ToolResponse with execution result.
|
||||
"""
|
||||
# Get verbose from kwargs (per-call override) without mutating instance state
|
||||
verbose = kwargs.get("verbose", self.verbose)
|
||||
# Store for use in _handle_notification during this request
|
||||
self._current_verbose = verbose
|
||||
|
||||
# Reset session state for new prompt (preserve session_id)
|
||||
if self._session:
|
||||
self._session.reset()
|
||||
|
||||
# Print header if verbose
|
||||
if verbose:
|
||||
self._console.print_header(f"ACP AGENT ({self.agent_command})")
|
||||
self._console.print_status("Processing prompt...")
|
||||
|
||||
# Build prompt array per ACP spec (ContentBlock format)
|
||||
prompt_blocks = [{"type": "text", "text": prompt}]
|
||||
|
||||
# Send session/prompt request
|
||||
try:
|
||||
prompt_future = self._client.send_request(
|
||||
"session/prompt",
|
||||
{
|
||||
"sessionId": self._session_id,
|
||||
"prompt": prompt_blocks,
|
||||
},
|
||||
)
|
||||
|
||||
# Wait for response with timeout
|
||||
response = await asyncio.wait_for(prompt_future, timeout=self.timeout)
|
||||
|
||||
# Check for error stop reason
|
||||
stop_reason = response.get("stopReason", "unknown")
|
||||
if stop_reason == "error":
|
||||
error_obj = response.get("error", {})
|
||||
error_msg = error_obj.get("message", "Unknown error from agent")
|
||||
if verbose:
|
||||
self._console.print_separator()
|
||||
self._console.print_error(f"Agent error: {error_msg}")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=self._session.output if self._session else "",
|
||||
error=error_msg,
|
||||
metadata={
|
||||
"tool": "acp",
|
||||
"agent": self.agent_command,
|
||||
"session_id": self._session_id,
|
||||
"stop_reason": stop_reason,
|
||||
},
|
||||
)
|
||||
|
||||
# Build successful response
|
||||
output = self._session.output if self._session else ""
|
||||
if verbose:
|
||||
self._console.print_separator()
|
||||
tool_count = len(self._session.tool_calls) if self._session else 0
|
||||
self._console.print_success(f"Agent completed (tools: {tool_count})")
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=output,
|
||||
metadata={
|
||||
"tool": "acp",
|
||||
"agent": self.agent_command,
|
||||
"session_id": self._session_id,
|
||||
"stop_reason": stop_reason,
|
||||
"tool_calls_count": len(self._session.tool_calls) if self._session else 0,
|
||||
"has_thoughts": bool(self._session.thoughts) if self._session else False,
|
||||
},
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if verbose:
|
||||
self._console.print_separator()
|
||||
self._console.print_error(f"Timeout after {self.timeout}s")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=self._session.output if self._session else "",
|
||||
error=f"Prompt execution timed out after {self.timeout} seconds",
|
||||
metadata={
|
||||
"tool": "acp",
|
||||
"agent": self.agent_command,
|
||||
"session_id": self._session_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def _shutdown(self) -> None:
|
||||
"""Shutdown the ACP connection.
|
||||
|
||||
Stops the client and cleans up state.
|
||||
"""
|
||||
# Kill all running terminals first
|
||||
if self._handlers:
|
||||
for terminal_id in list(self._handlers._terminals.keys()):
|
||||
try:
|
||||
self._handlers.handle_terminal_kill({"terminalId": terminal_id})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to kill terminal %s: %s", terminal_id, e)
|
||||
|
||||
if self._client:
|
||||
await self._client.stop()
|
||||
self._client = None
|
||||
|
||||
self._initialized = False
|
||||
self._session_id = None
|
||||
self._session = None
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute the prompt synchronously.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to execute.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
ToolResponse with execution result.
|
||||
"""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"ACP adapter not available: {self.agent_command} not found",
|
||||
)
|
||||
|
||||
# Run async method in new event loop
|
||||
try:
|
||||
return asyncio.run(self.aexecute(prompt, **kwargs))
|
||||
except Exception as e:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute the prompt asynchronously.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to execute.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
ToolResponse with execution result.
|
||||
"""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"ACP adapter not available: {self.agent_command} not found",
|
||||
)
|
||||
|
||||
try:
|
||||
# Initialize if needed
|
||||
if not self._initialized:
|
||||
await self._initialize()
|
||||
|
||||
# Enhance prompt with orchestration instructions
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
|
||||
# Execute prompt
|
||||
return await self._execute_prompt(enhanced_prompt, **kwargs)
|
||||
|
||||
except ACPClientError as e:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"ACP error: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
|
||||
"""Estimate execution cost.
|
||||
|
||||
ACP doesn't provide billing information, so returns 0.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to estimate.
|
||||
|
||||
Returns:
|
||||
Always 0.0 (no billing info from ACP).
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Cleanup on deletion."""
|
||||
self._restore_signal_handlers()
|
||||
|
||||
# Best-effort cleanup
|
||||
if self._client:
|
||||
try:
|
||||
self.kill_subprocess_sync()
|
||||
except Exception as e:
|
||||
logger.debug("Exception during cleanup in __del__: %s", e)
|
||||
@@ -0,0 +1,352 @@
|
||||
# ABOUTME: ACPClient manages subprocess lifecycle for ACP agents
|
||||
# ABOUTME: Handles async message routing, request tracking, and graceful shutdown
|
||||
|
||||
"""ACPClient subprocess manager for ACP (Agent Client Protocol)."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from .acp_protocol import ACPProtocol, MessageType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default asyncio StreamReader line limit (in bytes) for ACP agent stdout/stderr.
|
||||
# Some ACP agents can emit large single-line JSON-RPC frames (e.g., big tool payloads).
|
||||
# The asyncio default (~64KiB) can raise LimitOverrunError and break the ACP session.
|
||||
DEFAULT_ACP_STREAM_LIMIT = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
class ACPClientError(Exception):
|
||||
"""Exception raised by ACPClient operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ACPClient:
|
||||
"""Manages subprocess lifecycle and async message routing for ACP agents.
|
||||
|
||||
Spawns an agent subprocess, handles JSON-RPC message serialization,
|
||||
routes responses to pending requests, and invokes callbacks for
|
||||
notifications and incoming requests.
|
||||
|
||||
Attributes:
|
||||
command: The command to spawn the agent.
|
||||
args: Additional command-line arguments.
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: str,
|
||||
args: Optional[list[str]] = None,
|
||||
timeout: int = 300,
|
||||
stream_limit: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Initialize ACPClient.
|
||||
|
||||
Args:
|
||||
command: The command to spawn the agent (e.g., "gemini").
|
||||
args: Additional command-line arguments.
|
||||
timeout: Request timeout in seconds (default: 300).
|
||||
"""
|
||||
self.command = command
|
||||
self.args = args or []
|
||||
self.timeout = timeout
|
||||
# Limit used by asyncio for readline()/readuntil(); must be > max ACP frame line length.
|
||||
env_limit = os.environ.get("RALPH_ACP_STREAM_LIMIT")
|
||||
if stream_limit is not None:
|
||||
self.stream_limit = stream_limit
|
||||
elif env_limit:
|
||||
try:
|
||||
self.stream_limit = int(env_limit)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Invalid RALPH_ACP_STREAM_LIMIT=%r; using default %d",
|
||||
env_limit,
|
||||
DEFAULT_ACP_STREAM_LIMIT,
|
||||
)
|
||||
self.stream_limit = DEFAULT_ACP_STREAM_LIMIT
|
||||
else:
|
||||
self.stream_limit = DEFAULT_ACP_STREAM_LIMIT
|
||||
|
||||
self._protocol = ACPProtocol()
|
||||
self._process: Optional[asyncio.subprocess.Process] = None
|
||||
self._read_task: Optional[asyncio.Task] = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
|
||||
# Pending requests: id -> Future
|
||||
self._pending_requests: dict[int, asyncio.Future] = {}
|
||||
|
||||
# Notification handlers
|
||||
self._notification_handlers: list[Callable[[str, dict], None]] = []
|
||||
|
||||
# Request handlers (for incoming requests from agent)
|
||||
self._request_handlers: list[Callable[[str, dict], Any]] = []
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if subprocess is running.
|
||||
|
||||
Returns:
|
||||
True if subprocess is running, False otherwise.
|
||||
"""
|
||||
return self._process is not None and self._process.returncode is None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the agent subprocess.
|
||||
|
||||
Spawns the subprocess with stdin/stdout/stderr pipes and starts
|
||||
the read loop task.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If already running.
|
||||
FileNotFoundError: If command not found.
|
||||
"""
|
||||
if self.is_running:
|
||||
raise RuntimeError("ACPClient is already running")
|
||||
|
||||
cmd = [self.command] + self.args
|
||||
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
limit=self.stream_limit,
|
||||
)
|
||||
|
||||
# Start the read loop
|
||||
self._read_task = asyncio.create_task(self._read_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the agent subprocess.
|
||||
|
||||
Terminates the subprocess gracefully with 2 second timeout, then kills if necessary.
|
||||
Cancels the read loop task and all pending requests.
|
||||
"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Cancel read loop first
|
||||
if self._read_task and not self._read_task.done():
|
||||
self._read_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._read_task, timeout=0.5)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Read task cancelled during shutdown")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Read task cancellation timed out")
|
||||
|
||||
# Terminate subprocess with 2 second timeout
|
||||
if self._process:
|
||||
try:
|
||||
self._process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(self._process.wait(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
# Force kill if graceful termination fails
|
||||
logger.warning("Process did not terminate gracefully, killing")
|
||||
self._process.kill()
|
||||
try:
|
||||
await asyncio.wait_for(self._process.wait(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Process did not die after kill signal")
|
||||
except ProcessLookupError:
|
||||
logger.debug("Process already terminated")
|
||||
|
||||
self._process = None
|
||||
self._read_task = None
|
||||
|
||||
# Cancel all pending requests
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
async def _read_loop(self) -> None:
|
||||
"""Continuously read stdout and route messages.
|
||||
|
||||
Reads newline-delimited JSON-RPC messages from subprocess stdout.
|
||||
"""
|
||||
if not self._process or not self._process.stdout:
|
||||
return
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
line = await self._process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
|
||||
message_str = line.decode().strip()
|
||||
if message_str:
|
||||
await self._handle_message(message_str)
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected during shutdown
|
||||
except Exception as e:
|
||||
logger.error("ACP read loop failed: %s", e, exc_info=True)
|
||||
finally:
|
||||
# Cancel all pending requests when read loop exits (subprocess died or cancelled)
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(ACPClientError("Agent subprocess terminated"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
async def _handle_message(self, message_str: str) -> None:
|
||||
"""Handle a received JSON-RPC message.
|
||||
|
||||
Routes message to appropriate handler based on type.
|
||||
|
||||
Args:
|
||||
message_str: Raw JSON string.
|
||||
"""
|
||||
parsed = self._protocol.parse_message(message_str)
|
||||
msg_type = parsed.get("type")
|
||||
|
||||
if msg_type == MessageType.RESPONSE:
|
||||
# Route to pending request
|
||||
request_id = parsed.get("id")
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(parsed.get("result"))
|
||||
|
||||
elif msg_type == MessageType.ERROR:
|
||||
# Route error to pending request
|
||||
request_id = parsed.get("id")
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
error = parsed.get("error", {})
|
||||
error_msg = error.get("message", "Unknown error")
|
||||
future.set_exception(ACPClientError(error_msg))
|
||||
|
||||
elif msg_type == MessageType.NOTIFICATION:
|
||||
# Invoke notification handlers
|
||||
method = parsed.get("method", "")
|
||||
params = parsed.get("params", {})
|
||||
for handler in self._notification_handlers:
|
||||
try:
|
||||
handler(method, params)
|
||||
except Exception as e:
|
||||
logger.error("Notification handler failed for method=%s: %s", method, e, exc_info=True)
|
||||
|
||||
elif msg_type == MessageType.REQUEST:
|
||||
# Invoke request handlers and send response
|
||||
request_id = parsed.get("id")
|
||||
method = parsed.get("method", "")
|
||||
params = parsed.get("params", {})
|
||||
|
||||
for handler in self._request_handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
result = await handler(method, params)
|
||||
else:
|
||||
result = handler(method, params)
|
||||
|
||||
# Check if handler returned an error (dict with "error" key)
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
error_info = result["error"]
|
||||
error_code = error_info.get("code", -32603) if isinstance(error_info, dict) else -32603
|
||||
error_msg = error_info.get("message", str(error_info)) if isinstance(error_info, dict) else str(error_info)
|
||||
response = self._protocol.create_error_response(request_id, error_code, error_msg)
|
||||
else:
|
||||
response = self._protocol.create_response(request_id, result)
|
||||
await self._write_message(response)
|
||||
break # Only first handler responds
|
||||
except Exception as e:
|
||||
# Send error response
|
||||
error_response = self._protocol.create_error_response(
|
||||
request_id, -32603, str(e)
|
||||
)
|
||||
await self._write_message(error_response)
|
||||
break
|
||||
|
||||
async def _write_message(self, message: str) -> None:
|
||||
"""Write a JSON-RPC message to subprocess stdin.
|
||||
|
||||
Args:
|
||||
message: JSON string to write.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not running.
|
||||
"""
|
||||
if not self.is_running or not self._process or not self._process.stdin:
|
||||
raise RuntimeError("ACPClient is not running")
|
||||
|
||||
async with self._write_lock:
|
||||
self._process.stdin.write((message + "\n").encode())
|
||||
await self._process.stdin.drain()
|
||||
|
||||
async def _do_send(self, request_id: int, message: str) -> None:
|
||||
"""Helper to send message and handle write errors.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
message: The JSON-RPC message to send.
|
||||
"""
|
||||
try:
|
||||
await self._write_message(message)
|
||||
except Exception as e:
|
||||
future = self._pending_requests.pop(request_id, None)
|
||||
if future and not future.done():
|
||||
future.set_exception(ACPClientError(f"Failed to send request: {e}"))
|
||||
|
||||
def send_request(
|
||||
self, method: str, params: dict[str, Any]
|
||||
) -> asyncio.Future[Any]:
|
||||
"""Send a JSON-RPC request and return Future for response.
|
||||
|
||||
Args:
|
||||
method: The RPC method name.
|
||||
params: The request parameters.
|
||||
|
||||
Returns:
|
||||
Future that resolves with the response result.
|
||||
"""
|
||||
request_id, message = self._protocol.create_request(method, params)
|
||||
|
||||
# Create future for response
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Any] = loop.create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
# Schedule write with error handling
|
||||
asyncio.create_task(self._do_send(request_id, message))
|
||||
|
||||
return future
|
||||
|
||||
async def send_notification(
|
||||
self, method: str, params: dict[str, Any]
|
||||
) -> None:
|
||||
"""Send a JSON-RPC notification (no response expected).
|
||||
|
||||
Args:
|
||||
method: The notification method name.
|
||||
params: The notification parameters.
|
||||
"""
|
||||
message = self._protocol.create_notification(method, params)
|
||||
await self._write_message(message)
|
||||
|
||||
def on_notification(
|
||||
self, handler: Callable[[str, dict], None]
|
||||
) -> None:
|
||||
"""Register a notification handler.
|
||||
|
||||
Args:
|
||||
handler: Callback invoked with (method, params) for each notification.
|
||||
"""
|
||||
self._notification_handlers.append(handler)
|
||||
|
||||
def on_request(
|
||||
self, handler: Callable[[str, dict], Any]
|
||||
) -> None:
|
||||
"""Register a request handler for incoming requests from agent.
|
||||
|
||||
Handler should return the response result. Can be sync or async.
|
||||
|
||||
Args:
|
||||
handler: Callback invoked with (method, params), returns result.
|
||||
"""
|
||||
self._request_handlers.append(handler)
|
||||
@@ -0,0 +1,889 @@
|
||||
# ABOUTME: ACP handlers for permission requests and file/terminal operations
|
||||
# ABOUTME: Provides permission_mode handling (auto_approve, deny_all, allowlist, interactive)
|
||||
# ABOUTME: Implements fs/read_text_file and fs/write_text_file handlers with security
|
||||
# ABOUTME: Implements terminal/* handlers for command execution
|
||||
|
||||
"""ACP Handlers for permission requests and agent-to-host operations.
|
||||
|
||||
This module provides the ACPHandlers class which manages permission requests
|
||||
from ACP-compliant agents and handles file operations. It supports:
|
||||
|
||||
Permission modes:
|
||||
- auto_approve: Approve all requests automatically
|
||||
- deny_all: Deny all requests
|
||||
- allowlist: Only approve requests matching configured patterns
|
||||
- interactive: Prompt user for each request (requires terminal)
|
||||
|
||||
File operations:
|
||||
- fs/read_text_file: Read file content with security validation
|
||||
- fs/write_text_file: Write file content with security validation
|
||||
|
||||
Terminal operations:
|
||||
- terminal/create: Create a new terminal with command
|
||||
- terminal/output: Read output from a terminal
|
||||
- terminal/wait_for_exit: Wait for terminal process to exit
|
||||
- terminal/kill: Kill a terminal process
|
||||
- terminal/release: Release terminal resources
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Terminal:
|
||||
"""Represents a terminal subprocess.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the terminal.
|
||||
process: The subprocess.Popen instance.
|
||||
output_buffer: Accumulated output from stdout/stderr.
|
||||
"""
|
||||
|
||||
id: str
|
||||
process: subprocess.Popen
|
||||
output_buffer: str = ""
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the process is still running."""
|
||||
return self.process.poll() is None
|
||||
|
||||
@property
|
||||
def exit_code(self) -> Optional[int]:
|
||||
"""Get the exit code if process has exited."""
|
||||
return self.process.poll()
|
||||
|
||||
def read_output(self) -> str:
|
||||
"""Read any available output without blocking.
|
||||
|
||||
Returns:
|
||||
New output since last read.
|
||||
"""
|
||||
import select
|
||||
|
||||
new_output = ""
|
||||
|
||||
# Try to read from stdout and stderr
|
||||
for stream in [self.process.stdout, self.process.stderr]:
|
||||
if stream is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Non-blocking read using select
|
||||
while True:
|
||||
ready, _, _ = select.select([stream], [], [], 0)
|
||||
if not ready:
|
||||
break
|
||||
chunk = stream.read(4096)
|
||||
if chunk:
|
||||
new_output += chunk
|
||||
else:
|
||||
break
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Error reading terminal output: %s", e)
|
||||
break
|
||||
|
||||
self.output_buffer += new_output
|
||||
return new_output
|
||||
|
||||
def kill(self) -> None:
|
||||
"""Kill the subprocess."""
|
||||
if self.is_running:
|
||||
self.process.terminate()
|
||||
try:
|
||||
self.process.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
self.process.wait()
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> int:
|
||||
"""Wait for the process to exit.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds.
|
||||
|
||||
Returns:
|
||||
Exit code of the process.
|
||||
|
||||
Raises:
|
||||
subprocess.TimeoutExpired: If timeout is reached.
|
||||
"""
|
||||
return self.process.wait(timeout=timeout)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermissionRequest:
|
||||
"""Represents a permission request from an agent.
|
||||
|
||||
Attributes:
|
||||
operation: The operation being requested (e.g., 'fs/read_text_file').
|
||||
path: Optional path for filesystem operations.
|
||||
command: Optional command for terminal operations.
|
||||
arguments: Full arguments dict from the request.
|
||||
"""
|
||||
|
||||
operation: str
|
||||
path: Optional[str] = None
|
||||
command: Optional[str] = None
|
||||
arguments: dict = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_params(cls, params: dict) -> "PermissionRequest":
|
||||
"""Create PermissionRequest from request parameters.
|
||||
|
||||
Args:
|
||||
params: Permission request parameters from agent.
|
||||
|
||||
Returns:
|
||||
Parsed PermissionRequest instance.
|
||||
"""
|
||||
return cls(
|
||||
operation=params.get("operation", ""),
|
||||
path=params.get("path"),
|
||||
command=params.get("command"),
|
||||
arguments=params,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermissionResult:
|
||||
"""Result of a permission decision.
|
||||
|
||||
Attributes:
|
||||
approved: Whether the request was approved.
|
||||
reason: Optional reason for the decision.
|
||||
mode: Permission mode that made the decision.
|
||||
"""
|
||||
|
||||
approved: bool
|
||||
reason: Optional[str] = None
|
||||
mode: str = "unknown"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to ACP response format.
|
||||
|
||||
Returns:
|
||||
Dict with 'approved' key for ACP response.
|
||||
"""
|
||||
return {"approved": self.approved}
|
||||
|
||||
|
||||
class ACPHandlers:
|
||||
"""Handles ACP permission requests with configurable modes.
|
||||
|
||||
Supports four permission modes:
|
||||
- auto_approve: Always approve (useful for trusted environments)
|
||||
- deny_all: Always deny (useful for testing)
|
||||
- allowlist: Only approve operations matching configured patterns
|
||||
- interactive: Prompt user for each request
|
||||
|
||||
Attributes:
|
||||
permission_mode: Current permission mode.
|
||||
allowlist: List of allowed operation patterns.
|
||||
on_permission_log: Optional callback for logging decisions.
|
||||
"""
|
||||
|
||||
# Valid permission modes
|
||||
VALID_MODES = ("auto_approve", "deny_all", "allowlist", "interactive")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
permission_mode: str = "auto_approve",
|
||||
permission_allowlist: Optional[list[str]] = None,
|
||||
on_permission_log: Optional[Callable[[str], None]] = None,
|
||||
) -> None:
|
||||
"""Initialize ACPHandlers.
|
||||
|
||||
Args:
|
||||
permission_mode: Permission handling mode (default: auto_approve).
|
||||
permission_allowlist: List of allowed operation patterns for allowlist mode.
|
||||
on_permission_log: Optional callback for logging permission decisions.
|
||||
|
||||
Raises:
|
||||
ValueError: If permission_mode is not valid.
|
||||
"""
|
||||
if permission_mode not in self.VALID_MODES:
|
||||
raise ValueError(
|
||||
f"Invalid permission_mode: {permission_mode}. "
|
||||
f"Must be one of: {', '.join(self.VALID_MODES)}"
|
||||
)
|
||||
|
||||
self.permission_mode = permission_mode
|
||||
self.allowlist = permission_allowlist or []
|
||||
self.on_permission_log = on_permission_log
|
||||
|
||||
# Track permission history for debugging
|
||||
self._history: list[tuple[PermissionRequest, PermissionResult]] = []
|
||||
|
||||
# Track active terminals
|
||||
self._terminals: dict[str, Terminal] = {}
|
||||
|
||||
def handle_request_permission(self, params: dict) -> dict:
|
||||
"""Handle a permission request from an agent.
|
||||
|
||||
Returns ACP-compliant response with nested outcome structure.
|
||||
The agent handles tool execution after receiving permission.
|
||||
|
||||
Args:
|
||||
params: Permission request parameters including options list.
|
||||
|
||||
Returns:
|
||||
Dict with result.outcome.outcome (selected/cancelled) and optionId.
|
||||
"""
|
||||
request = PermissionRequest.from_params(params)
|
||||
result = self._evaluate_permission(request)
|
||||
|
||||
# Log the decision
|
||||
self._log_decision(request, result)
|
||||
|
||||
# Store in history
|
||||
self._history.append((request, result))
|
||||
|
||||
# Extract options from params to find the appropriate optionId
|
||||
options = params.get("options", [])
|
||||
|
||||
if result.approved:
|
||||
# Find first "allow" option to use as optionId
|
||||
selected_option_id = None
|
||||
for option in options:
|
||||
if option.get("type") == "allow":
|
||||
selected_option_id = option.get("id")
|
||||
break
|
||||
|
||||
# Fallback to first option if no "allow" type found
|
||||
if not selected_option_id and options:
|
||||
selected_option_id = options[0].get("id", "proceed_once")
|
||||
elif not selected_option_id:
|
||||
# Default if no options provided
|
||||
selected_option_id = "proceed_once"
|
||||
|
||||
# Return raw result (client wraps in JSON-RPC response)
|
||||
return {
|
||||
"outcome": {
|
||||
"outcome": "selected",
|
||||
"optionId": selected_option_id
|
||||
}
|
||||
}
|
||||
else:
|
||||
# Permission denied - return cancelled outcome
|
||||
return {
|
||||
"outcome": {
|
||||
"outcome": "cancelled"
|
||||
}
|
||||
}
|
||||
|
||||
def _evaluate_permission(self, request: PermissionRequest) -> PermissionResult:
|
||||
"""Evaluate a permission request based on current mode.
|
||||
|
||||
Args:
|
||||
request: The permission request to evaluate.
|
||||
|
||||
Returns:
|
||||
PermissionResult with decision and reason.
|
||||
"""
|
||||
if self.permission_mode == "auto_approve":
|
||||
return PermissionResult(
|
||||
approved=True,
|
||||
reason="auto_approve mode",
|
||||
mode="auto_approve",
|
||||
)
|
||||
|
||||
if self.permission_mode == "deny_all":
|
||||
return PermissionResult(
|
||||
approved=False,
|
||||
reason="deny_all mode",
|
||||
mode="deny_all",
|
||||
)
|
||||
|
||||
if self.permission_mode == "allowlist":
|
||||
return self._evaluate_allowlist(request)
|
||||
|
||||
if self.permission_mode == "interactive":
|
||||
return self._evaluate_interactive(request)
|
||||
|
||||
# Fallback - should not reach here
|
||||
return PermissionResult(
|
||||
approved=False,
|
||||
reason="unknown mode",
|
||||
mode="unknown",
|
||||
)
|
||||
|
||||
def _evaluate_allowlist(self, request: PermissionRequest) -> PermissionResult:
|
||||
"""Evaluate permission against allowlist patterns.
|
||||
|
||||
Patterns can be:
|
||||
- Exact match: 'fs/read_text_file'
|
||||
- Glob pattern: 'fs/*' (matches any fs operation)
|
||||
- Regex pattern: '/^terminal\\/.*$/' (surrounded by slashes)
|
||||
|
||||
Args:
|
||||
request: The permission request to evaluate.
|
||||
|
||||
Returns:
|
||||
PermissionResult with decision.
|
||||
"""
|
||||
operation = request.operation
|
||||
|
||||
for pattern in self.allowlist:
|
||||
if self._matches_pattern(operation, pattern):
|
||||
return PermissionResult(
|
||||
approved=True,
|
||||
reason=f"matches allowlist pattern: {pattern}",
|
||||
mode="allowlist",
|
||||
)
|
||||
|
||||
return PermissionResult(
|
||||
approved=False,
|
||||
reason="no matching allowlist pattern",
|
||||
mode="allowlist",
|
||||
)
|
||||
|
||||
def _matches_pattern(self, operation: str, pattern: str) -> bool:
|
||||
"""Check if an operation matches a pattern.
|
||||
|
||||
Args:
|
||||
operation: The operation name to check.
|
||||
pattern: Pattern to match against.
|
||||
|
||||
Returns:
|
||||
True if operation matches pattern.
|
||||
"""
|
||||
# Check for regex pattern (surrounded by slashes)
|
||||
if pattern.startswith("/") and pattern.endswith("/"):
|
||||
try:
|
||||
regex_pattern = pattern[1:-1]
|
||||
return bool(re.match(regex_pattern, operation))
|
||||
except re.error as e:
|
||||
logger.warning("Invalid regex pattern '%s' in permission allowlist: %s", pattern, e)
|
||||
return False
|
||||
|
||||
# Check for glob pattern
|
||||
if "*" in pattern or "?" in pattern:
|
||||
return fnmatch.fnmatch(operation, pattern)
|
||||
|
||||
# Exact match
|
||||
return operation == pattern
|
||||
|
||||
def _evaluate_interactive(self, request: PermissionRequest) -> PermissionResult:
|
||||
"""Evaluate permission interactively by prompting user.
|
||||
|
||||
Falls back to deny_all if no terminal is available.
|
||||
|
||||
Args:
|
||||
request: The permission request to evaluate.
|
||||
|
||||
Returns:
|
||||
PermissionResult with user's decision.
|
||||
"""
|
||||
# Check if we have a terminal
|
||||
if not sys.stdin.isatty():
|
||||
return PermissionResult(
|
||||
approved=False,
|
||||
reason="no terminal available for interactive mode",
|
||||
mode="interactive",
|
||||
)
|
||||
|
||||
# Format the prompt
|
||||
prompt = self._format_interactive_prompt(request)
|
||||
|
||||
try:
|
||||
print(prompt, file=sys.stderr)
|
||||
response = input("[y/N]: ").strip().lower()
|
||||
|
||||
if response in ("y", "yes"):
|
||||
return PermissionResult(
|
||||
approved=True,
|
||||
reason="user approved",
|
||||
mode="interactive",
|
||||
)
|
||||
else:
|
||||
return PermissionResult(
|
||||
approved=False,
|
||||
reason="user denied",
|
||||
mode="interactive",
|
||||
)
|
||||
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return PermissionResult(
|
||||
approved=False,
|
||||
reason="input interrupted",
|
||||
mode="interactive",
|
||||
)
|
||||
|
||||
def _format_interactive_prompt(self, request: PermissionRequest) -> str:
|
||||
"""Format an interactive permission prompt.
|
||||
|
||||
Args:
|
||||
request: The permission request to display.
|
||||
|
||||
Returns:
|
||||
Formatted prompt string.
|
||||
"""
|
||||
lines = [
|
||||
"",
|
||||
"=" * 60,
|
||||
f"Permission Request: {request.operation}",
|
||||
"=" * 60,
|
||||
]
|
||||
|
||||
if request.path:
|
||||
lines.append(f" Path: {request.path}")
|
||||
if request.command:
|
||||
lines.append(f" Command: {request.command}")
|
||||
|
||||
# Add other arguments
|
||||
for key, value in request.arguments.items():
|
||||
if key not in ("operation", "path", "command"):
|
||||
lines.append(f" {key}: {value}")
|
||||
|
||||
lines.extend([
|
||||
"=" * 60,
|
||||
"Approve this operation?",
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _log_decision(
|
||||
self, request: PermissionRequest, result: PermissionResult
|
||||
) -> None:
|
||||
"""Log a permission decision.
|
||||
|
||||
Args:
|
||||
request: The permission request.
|
||||
result: The permission decision.
|
||||
"""
|
||||
if self.on_permission_log:
|
||||
status = "APPROVED" if result.approved else "DENIED"
|
||||
message = (
|
||||
f"Permission {status}: {request.operation} "
|
||||
f"[mode={result.mode}, reason={result.reason}]"
|
||||
)
|
||||
self.on_permission_log(message)
|
||||
|
||||
def get_history(self) -> list[tuple[PermissionRequest, PermissionResult]]:
|
||||
"""Get permission decision history.
|
||||
|
||||
Returns:
|
||||
List of (request, result) tuples.
|
||||
"""
|
||||
return self._history.copy()
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""Clear permission decision history."""
|
||||
self._history.clear()
|
||||
|
||||
def get_approved_count(self) -> int:
|
||||
"""Get count of approved permissions.
|
||||
|
||||
Returns:
|
||||
Number of approved permission requests.
|
||||
"""
|
||||
return sum(1 for _, result in self._history if result.approved)
|
||||
|
||||
def get_denied_count(self) -> int:
|
||||
"""Get count of denied permissions.
|
||||
|
||||
Returns:
|
||||
Number of denied permission requests.
|
||||
"""
|
||||
return sum(1 for _, result in self._history if not result.approved)
|
||||
|
||||
# =========================================================================
|
||||
# File Operation Handlers
|
||||
# =========================================================================
|
||||
|
||||
def handle_read_file(self, params: dict) -> dict:
|
||||
"""Handle fs/read_text_file request from agent.
|
||||
|
||||
Reads file content with security validation to prevent path traversal.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'path' key.
|
||||
|
||||
Returns:
|
||||
Dict with 'content' on success, or 'error' on failure.
|
||||
"""
|
||||
path_str = params.get("path")
|
||||
|
||||
if not path_str:
|
||||
return {"error": {"code": -32602, "message": "Missing required parameter: path"}}
|
||||
|
||||
try:
|
||||
# Resolve the path
|
||||
path = Path(path_str)
|
||||
|
||||
# Security: require absolute path
|
||||
if not path.is_absolute():
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": f"Path must be absolute: {path_str}",
|
||||
}
|
||||
}
|
||||
|
||||
# Resolve symlinks and normalize
|
||||
resolved_path = path.resolve()
|
||||
|
||||
# Check if file exists - return null content for non-existent files
|
||||
# (this allows agents to check file existence without error)
|
||||
if not resolved_path.exists():
|
||||
return {"content": None, "exists": False}
|
||||
|
||||
# Check if it's a file (not directory)
|
||||
if not resolved_path.is_file():
|
||||
return {
|
||||
"error": {
|
||||
"code": -32002,
|
||||
"message": f"Path is not a file: {path_str}",
|
||||
}
|
||||
}
|
||||
|
||||
# Read file content
|
||||
content = resolved_path.read_text(encoding="utf-8")
|
||||
|
||||
return {"content": content}
|
||||
|
||||
except PermissionError:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32003,
|
||||
"message": f"Permission denied: {path_str}",
|
||||
}
|
||||
}
|
||||
except UnicodeDecodeError:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32004,
|
||||
"message": f"File is not valid UTF-8 text: {path_str}",
|
||||
}
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32000,
|
||||
"message": f"Failed to read file: {e}",
|
||||
}
|
||||
}
|
||||
|
||||
def handle_write_file(self, params: dict) -> dict:
|
||||
"""Handle fs/write_text_file request from agent.
|
||||
|
||||
Writes content to file with security validation.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'path' and 'content' keys.
|
||||
|
||||
Returns:
|
||||
Dict with 'success: True' on success, or 'error' on failure.
|
||||
"""
|
||||
path_str = params.get("path")
|
||||
content = params.get("content")
|
||||
|
||||
if not path_str:
|
||||
return {"error": {"code": -32602, "message": "Missing required parameter: path"}}
|
||||
|
||||
if content is None:
|
||||
return {"error": {"code": -32602, "message": "Missing required parameter: content"}}
|
||||
|
||||
try:
|
||||
# Resolve the path
|
||||
path = Path(path_str)
|
||||
|
||||
# Security: require absolute path
|
||||
if not path.is_absolute():
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": f"Path must be absolute: {path_str}",
|
||||
}
|
||||
}
|
||||
|
||||
# Resolve symlinks and normalize
|
||||
resolved_path = path.resolve()
|
||||
|
||||
# Check if path exists and is a directory (can't write to directory)
|
||||
if resolved_path.exists() and resolved_path.is_dir():
|
||||
return {
|
||||
"error": {
|
||||
"code": -32002,
|
||||
"message": f"Path is a directory: {path_str}",
|
||||
}
|
||||
}
|
||||
|
||||
# Create parent directories if needed
|
||||
resolved_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write file content
|
||||
resolved_path.write_text(content, encoding="utf-8")
|
||||
|
||||
return {"success": True}
|
||||
|
||||
except PermissionError:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32003,
|
||||
"message": f"Permission denied: {path_str}",
|
||||
}
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32000,
|
||||
"message": f"Failed to write file: {e}",
|
||||
}
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Terminal Operation Handlers
|
||||
# =========================================================================
|
||||
|
||||
def handle_terminal_create(self, params: dict) -> dict:
|
||||
"""Handle terminal/create request from agent.
|
||||
|
||||
Creates a new terminal subprocess for command execution.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'command' (list of strings) and
|
||||
optional 'cwd' (working directory).
|
||||
|
||||
Returns:
|
||||
Dict with 'terminalId' on success, or 'error' on failure.
|
||||
"""
|
||||
command = params.get("command")
|
||||
|
||||
if command is None:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Missing required parameter: command",
|
||||
}
|
||||
}
|
||||
|
||||
if not isinstance(command, list):
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "command must be a list of strings",
|
||||
}
|
||||
}
|
||||
|
||||
if len(command) == 0:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "command list cannot be empty",
|
||||
}
|
||||
}
|
||||
|
||||
cwd = params.get("cwd")
|
||||
|
||||
try:
|
||||
# Create subprocess with pipes for stdout/stderr
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
stdin=subprocess.DEVNULL,
|
||||
cwd=cwd,
|
||||
text=True,
|
||||
bufsize=0,
|
||||
)
|
||||
|
||||
# Generate unique terminal ID
|
||||
terminal_id = str(uuid.uuid4())
|
||||
|
||||
# Create terminal instance
|
||||
terminal = Terminal(id=terminal_id, process=process)
|
||||
self._terminals[terminal_id] = terminal
|
||||
|
||||
return {"terminalId": terminal_id}
|
||||
|
||||
except FileNotFoundError:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32001,
|
||||
"message": f"Command not found: {command[0]}",
|
||||
}
|
||||
}
|
||||
except PermissionError:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32003,
|
||||
"message": f"Permission denied executing: {command[0]}",
|
||||
}
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32000,
|
||||
"message": f"Failed to create terminal: {e}",
|
||||
}
|
||||
}
|
||||
|
||||
def handle_terminal_output(self, params: dict) -> dict:
|
||||
"""Handle terminal/output request from agent.
|
||||
|
||||
Reads available output from a terminal.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'terminalId'.
|
||||
|
||||
Returns:
|
||||
Dict with 'output' and 'done' on success, or 'error' on failure.
|
||||
"""
|
||||
terminal_id = params.get("terminalId")
|
||||
|
||||
if not terminal_id:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Missing required parameter: terminalId",
|
||||
}
|
||||
}
|
||||
|
||||
terminal = self._terminals.get(terminal_id)
|
||||
if not terminal:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32001,
|
||||
"message": f"Terminal not found: {terminal_id}",
|
||||
}
|
||||
}
|
||||
|
||||
# Read any new output
|
||||
terminal.read_output()
|
||||
|
||||
return {
|
||||
"output": terminal.output_buffer,
|
||||
"done": not terminal.is_running,
|
||||
}
|
||||
|
||||
def handle_terminal_wait_for_exit(self, params: dict) -> dict:
|
||||
"""Handle terminal/wait_for_exit request from agent.
|
||||
|
||||
Waits for a terminal process to exit.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'terminalId' and optional 'timeout'.
|
||||
|
||||
Returns:
|
||||
Dict with 'exitCode' on success, or 'error' on failure/timeout.
|
||||
"""
|
||||
terminal_id = params.get("terminalId")
|
||||
|
||||
if not terminal_id:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Missing required parameter: terminalId",
|
||||
}
|
||||
}
|
||||
|
||||
terminal = self._terminals.get(terminal_id)
|
||||
if not terminal:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32001,
|
||||
"message": f"Terminal not found: {terminal_id}",
|
||||
}
|
||||
}
|
||||
|
||||
timeout = params.get("timeout")
|
||||
|
||||
try:
|
||||
exit_code = terminal.wait(timeout=timeout)
|
||||
# Read any remaining output
|
||||
terminal.read_output()
|
||||
return {"exitCode": exit_code}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32000,
|
||||
"message": f"Wait timed out after {timeout}s",
|
||||
}
|
||||
}
|
||||
|
||||
def handle_terminal_kill(self, params: dict) -> dict:
|
||||
"""Handle terminal/kill request from agent.
|
||||
|
||||
Kills a terminal process.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'terminalId'.
|
||||
|
||||
Returns:
|
||||
Dict with 'success: True' on success, or 'error' on failure.
|
||||
"""
|
||||
terminal_id = params.get("terminalId")
|
||||
|
||||
if not terminal_id:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Missing required parameter: terminalId",
|
||||
}
|
||||
}
|
||||
|
||||
terminal = self._terminals.get(terminal_id)
|
||||
if not terminal:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32001,
|
||||
"message": f"Terminal not found: {terminal_id}",
|
||||
}
|
||||
}
|
||||
|
||||
terminal.kill()
|
||||
return {"success": True}
|
||||
|
||||
def handle_terminal_release(self, params: dict) -> dict:
|
||||
"""Handle terminal/release request from agent.
|
||||
|
||||
Releases terminal resources, killing the process if still running.
|
||||
|
||||
Args:
|
||||
params: Request parameters with 'terminalId'.
|
||||
|
||||
Returns:
|
||||
Dict with 'success: True' on success, or 'error' on failure.
|
||||
"""
|
||||
terminal_id = params.get("terminalId")
|
||||
|
||||
if not terminal_id:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32602,
|
||||
"message": "Missing required parameter: terminalId",
|
||||
}
|
||||
}
|
||||
|
||||
terminal = self._terminals.get(terminal_id)
|
||||
if not terminal:
|
||||
return {
|
||||
"error": {
|
||||
"code": -32001,
|
||||
"message": f"Terminal not found: {terminal_id}",
|
||||
}
|
||||
}
|
||||
|
||||
# Kill the process if still running before releasing
|
||||
if terminal.is_running:
|
||||
try:
|
||||
terminal.kill()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to kill terminal %s during release: %s", terminal_id, e)
|
||||
|
||||
# Clean up the terminal
|
||||
del self._terminals[terminal_id]
|
||||
return {"success": True}
|
||||
@@ -0,0 +1,536 @@
|
||||
# ABOUTME: Typed dataclasses for ACP (Agent Client Protocol) messages
|
||||
# ABOUTME: Defines request/response types, session state, and configuration
|
||||
|
||||
"""Typed data models for ACP messages and session state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ralph_orchestrator.main import AdapterConfig
|
||||
|
||||
# Get logger for this module
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateKind(str, Enum):
|
||||
"""Types of session updates."""
|
||||
|
||||
AGENT_MESSAGE_CHUNK = "agent_message_chunk"
|
||||
AGENT_THOUGHT_CHUNK = "agent_thought_chunk"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_CALL_UPDATE = "tool_call_update"
|
||||
PLAN = "plan"
|
||||
|
||||
|
||||
class ToolCallStatus(str, Enum):
|
||||
"""Status of a tool call."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PermissionMode(str, Enum):
|
||||
"""Permission modes for ACP operations."""
|
||||
|
||||
AUTO_APPROVE = "auto_approve"
|
||||
DENY_ALL = "deny_all"
|
||||
ALLOWLIST = "allowlist"
|
||||
INTERACTIVE = "interactive"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPRequest:
|
||||
"""JSON-RPC 2.0 request message.
|
||||
|
||||
Attributes:
|
||||
id: Request identifier for matching response.
|
||||
method: The RPC method to invoke.
|
||||
params: Method parameters.
|
||||
"""
|
||||
|
||||
id: int
|
||||
method: str
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ACPRequest":
|
||||
"""Parse ACPRequest from dict.
|
||||
|
||||
Args:
|
||||
data: Dict with id, method, and optional params.
|
||||
|
||||
Returns:
|
||||
ACPRequest instance.
|
||||
|
||||
Raises:
|
||||
KeyError: If id or method is missing.
|
||||
"""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
method=data["method"],
|
||||
params=data.get("params", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPNotification:
|
||||
"""JSON-RPC 2.0 notification (no response expected).
|
||||
|
||||
Attributes:
|
||||
method: The notification method.
|
||||
params: Method parameters.
|
||||
"""
|
||||
|
||||
method: str
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ACPNotification":
|
||||
"""Parse ACPNotification from dict.
|
||||
|
||||
Args:
|
||||
data: Dict with method and optional params.
|
||||
|
||||
Returns:
|
||||
ACPNotification instance.
|
||||
"""
|
||||
return cls(
|
||||
method=data["method"],
|
||||
params=data.get("params", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPResponse:
|
||||
"""JSON-RPC 2.0 success response.
|
||||
|
||||
Attributes:
|
||||
id: Request identifier this response matches.
|
||||
result: The response result data.
|
||||
"""
|
||||
|
||||
id: int
|
||||
result: Any
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ACPResponse":
|
||||
"""Parse ACPResponse from dict.
|
||||
|
||||
Args:
|
||||
data: Dict with id and result.
|
||||
|
||||
Returns:
|
||||
ACPResponse instance.
|
||||
"""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
result=data.get("result"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPErrorObject:
|
||||
"""JSON-RPC 2.0 error object.
|
||||
|
||||
Attributes:
|
||||
code: Error code (negative integers for standard errors).
|
||||
message: Human-readable error message.
|
||||
data: Optional additional error data.
|
||||
"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: Any = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ACPErrorObject":
|
||||
"""Parse ACPErrorObject from dict.
|
||||
|
||||
Args:
|
||||
data: Dict with code, message, and optional data.
|
||||
|
||||
Returns:
|
||||
ACPErrorObject instance.
|
||||
"""
|
||||
return cls(
|
||||
code=data["code"],
|
||||
message=data["message"],
|
||||
data=data.get("data"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPError:
|
||||
"""JSON-RPC 2.0 error response.
|
||||
|
||||
Attributes:
|
||||
id: Request identifier this error matches.
|
||||
error: The error object with code and message.
|
||||
"""
|
||||
|
||||
id: int
|
||||
error: ACPErrorObject
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ACPError":
|
||||
"""Parse ACPError from dict.
|
||||
|
||||
Args:
|
||||
data: Dict with id and error object.
|
||||
|
||||
Returns:
|
||||
ACPError instance.
|
||||
"""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
error=ACPErrorObject.from_dict(data["error"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdatePayload:
|
||||
"""Payload for session/update notifications.
|
||||
|
||||
Handles different update kinds:
|
||||
- agent_message_chunk: Text output from agent
|
||||
- agent_thought_chunk: Internal reasoning (verbose mode)
|
||||
- tool_call: Agent requesting tool execution
|
||||
- tool_call_update: Status update for tool call
|
||||
- plan: Agent's execution plan
|
||||
|
||||
Attributes:
|
||||
kind: The update type (see UpdateKind enum for valid values).
|
||||
content: Text content (for message/thought chunks).
|
||||
tool_name: Name of tool being called.
|
||||
tool_call_id: Unique identifier for tool call.
|
||||
arguments: Tool call arguments.
|
||||
status: Tool call status (see ToolCallStatus enum for valid values).
|
||||
result: Tool call result data.
|
||||
error: Tool call error message.
|
||||
"""
|
||||
|
||||
kind: str # Valid values: UpdateKind enum members
|
||||
content: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
arguments: Optional[dict[str, Any]] = None
|
||||
status: Optional[str] = None
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "UpdatePayload":
|
||||
"""Parse UpdatePayload from dict.
|
||||
|
||||
Handles camelCase to snake_case conversion for ACP fields.
|
||||
|
||||
Args:
|
||||
data: Dict with kind and kind-specific fields.
|
||||
|
||||
Returns:
|
||||
UpdatePayload instance.
|
||||
"""
|
||||
kind = data["kind"]
|
||||
tool_name = data.get("toolName")
|
||||
tool_call_id = data.get("toolCallId")
|
||||
arguments = data.get("arguments")
|
||||
if kind in (UpdateKind.TOOL_CALL, UpdateKind.TOOL_CALL_UPDATE):
|
||||
if tool_name is None:
|
||||
tool_name = data.get("tool_name") or data.get("name")
|
||||
if tool_name is None and isinstance(data.get("tool"), str):
|
||||
tool_name = data.get("tool")
|
||||
if tool_call_id is None:
|
||||
tool_call_id = data.get("tool_call_id") or data.get("id")
|
||||
if kind == UpdateKind.TOOL_CALL and arguments is None:
|
||||
arguments = data.get("args") or data.get("parameters") or data.get("params")
|
||||
|
||||
return cls(
|
||||
kind=kind,
|
||||
content=data.get("content"),
|
||||
tool_name=tool_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
status=data.get("status"),
|
||||
result=data.get("result"),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionUpdate:
|
||||
"""Wrapper for session/update notification.
|
||||
|
||||
Attributes:
|
||||
method: Should be "session/update".
|
||||
payload: The update payload.
|
||||
"""
|
||||
|
||||
method: str
|
||||
payload: UpdatePayload
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionUpdate":
|
||||
"""Parse SessionUpdate from notification dict.
|
||||
|
||||
Args:
|
||||
data: Dict with method and params.
|
||||
|
||||
Returns:
|
||||
SessionUpdate instance.
|
||||
"""
|
||||
return cls(
|
||||
method=data["method"],
|
||||
payload=UpdatePayload.from_dict(data["params"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""Tracks a tool execution within a session.
|
||||
|
||||
Attributes:
|
||||
tool_call_id: Unique identifier for this call.
|
||||
tool_name: Name of the tool being called.
|
||||
arguments: Arguments passed to the tool.
|
||||
status: Current status (pending/running/completed/failed).
|
||||
result: Result data if completed.
|
||||
error: Error message if failed.
|
||||
"""
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
arguments: dict[str, Any]
|
||||
status: str = "pending"
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ToolCall":
|
||||
"""Parse ToolCall from dict.
|
||||
|
||||
Args:
|
||||
data: Dict with toolCallId, toolName, arguments.
|
||||
|
||||
Returns:
|
||||
ToolCall instance.
|
||||
"""
|
||||
return cls(
|
||||
tool_call_id=data["toolCallId"],
|
||||
tool_name=data["toolName"],
|
||||
arguments=data.get("arguments", {}),
|
||||
status=data.get("status", "pending"),
|
||||
result=data.get("result"),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPSession:
|
||||
"""Accumulates session state during prompt execution.
|
||||
|
||||
Tracks output chunks, thoughts, and tool calls for building
|
||||
the final response.
|
||||
|
||||
Attributes:
|
||||
session_id: Unique session identifier.
|
||||
output: Accumulated agent output text.
|
||||
thoughts: Accumulated agent thoughts (verbose).
|
||||
tool_calls: List of tool calls in this session.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
output: str = ""
|
||||
thoughts: str = ""
|
||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||
|
||||
def append_output(self, text: str) -> None:
|
||||
"""Append text to accumulated output.
|
||||
|
||||
Args:
|
||||
text: Text chunk to append.
|
||||
"""
|
||||
self.output += text
|
||||
|
||||
def append_thought(self, text: str) -> None:
|
||||
"""Append text to accumulated thoughts.
|
||||
|
||||
Args:
|
||||
text: Thought chunk to append.
|
||||
"""
|
||||
self.thoughts += text
|
||||
|
||||
def add_tool_call(self, tool_call: ToolCall) -> None:
|
||||
"""Add a new tool call to track.
|
||||
|
||||
Args:
|
||||
tool_call: The ToolCall to track.
|
||||
"""
|
||||
self.tool_calls.append(tool_call)
|
||||
|
||||
def get_tool_call(self, tool_call_id: str) -> Optional[ToolCall]:
|
||||
"""Find a tool call by ID.
|
||||
|
||||
Args:
|
||||
tool_call_id: The ID to look up.
|
||||
|
||||
Returns:
|
||||
ToolCall if found, None otherwise.
|
||||
"""
|
||||
for tc in self.tool_calls:
|
||||
if tc.tool_call_id == tool_call_id:
|
||||
return tc
|
||||
return None
|
||||
|
||||
def process_update(self, payload: UpdatePayload) -> None:
|
||||
"""Process a session update payload.
|
||||
|
||||
Routes update to appropriate handler based on kind.
|
||||
|
||||
Args:
|
||||
payload: The update payload to process.
|
||||
"""
|
||||
if payload.kind == "agent_message_chunk":
|
||||
if payload.content:
|
||||
self.append_output(payload.content)
|
||||
elif payload.kind == "agent_thought_chunk":
|
||||
if payload.content:
|
||||
self.append_thought(payload.content)
|
||||
elif payload.kind == "tool_call":
|
||||
tool_call = ToolCall(
|
||||
tool_call_id=payload.tool_call_id or "",
|
||||
tool_name=payload.tool_name or "",
|
||||
arguments=payload.arguments or {},
|
||||
)
|
||||
self.add_tool_call(tool_call)
|
||||
elif payload.kind == "tool_call_update":
|
||||
if payload.tool_call_id:
|
||||
tc = self.get_tool_call(payload.tool_call_id)
|
||||
if tc:
|
||||
if payload.status:
|
||||
tc.status = payload.status
|
||||
if payload.result is not None:
|
||||
tc.result = payload.result
|
||||
if payload.error:
|
||||
tc.error = payload.error
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset session state for a new prompt.
|
||||
|
||||
Preserves session_id but clears accumulated data.
|
||||
"""
|
||||
self.output = ""
|
||||
self.thoughts = ""
|
||||
self.tool_calls = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class ACPAdapterConfig:
|
||||
"""Configuration for the ACP adapter.
|
||||
|
||||
Attributes:
|
||||
agent_command: Command to spawn the agent (default: gemini).
|
||||
agent_args: Additional arguments for agent command.
|
||||
timeout: Request timeout in seconds.
|
||||
permission_mode: How to handle permission requests.
|
||||
- auto_approve: Approve all requests.
|
||||
- deny_all: Deny all requests.
|
||||
- allowlist: Check against permission_allowlist.
|
||||
- interactive: Prompt user for each request.
|
||||
permission_allowlist: Patterns to allow in allowlist mode.
|
||||
"""
|
||||
|
||||
agent_command: str = "gemini"
|
||||
agent_args: list[str] = field(default_factory=list)
|
||||
timeout: int = 300
|
||||
permission_mode: str = "auto_approve"
|
||||
permission_allowlist: list[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ACPAdapterConfig":
|
||||
"""Parse ACPAdapterConfig from dict.
|
||||
|
||||
Uses defaults for missing keys.
|
||||
|
||||
Args:
|
||||
data: Configuration dict.
|
||||
|
||||
Returns:
|
||||
ACPAdapterConfig instance.
|
||||
"""
|
||||
return cls(
|
||||
agent_command=data.get("agent_command", "gemini"),
|
||||
agent_args=data.get("agent_args", []),
|
||||
timeout=data.get("timeout", 300),
|
||||
permission_mode=data.get("permission_mode", "auto_approve"),
|
||||
permission_allowlist=data.get("permission_allowlist", []),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_adapter_config(cls, adapter_config: "AdapterConfig") -> "ACPAdapterConfig":
|
||||
"""Create ACPAdapterConfig from AdapterConfig with env var overrides.
|
||||
|
||||
Extracts ACP-specific settings from AdapterConfig.tool_permissions
|
||||
and applies environment variable overrides.
|
||||
|
||||
Environment Variables:
|
||||
RALPH_ACP_AGENT: Override agent_command
|
||||
RALPH_ACP_PERMISSION_MODE: Override permission_mode
|
||||
RALPH_ACP_TIMEOUT: Override timeout (integer)
|
||||
|
||||
Args:
|
||||
adapter_config: General adapter configuration.
|
||||
|
||||
Returns:
|
||||
ACPAdapterConfig with ACP-specific settings.
|
||||
"""
|
||||
# Start with tool_permissions or empty dict
|
||||
tool_perms = adapter_config.tool_permissions or {}
|
||||
|
||||
# Get base values from tool_permissions
|
||||
agent_command = tool_perms.get("agent_command", "gemini")
|
||||
agent_args = tool_perms.get("agent_args", [])
|
||||
timeout = tool_perms.get("timeout", adapter_config.timeout)
|
||||
permission_mode = tool_perms.get("permission_mode", "auto_approve")
|
||||
permission_allowlist = tool_perms.get("permission_allowlist", [])
|
||||
|
||||
# Apply environment variable overrides
|
||||
if env_agent := os.environ.get("RALPH_ACP_AGENT"):
|
||||
agent_command = env_agent
|
||||
|
||||
if env_mode := os.environ.get("RALPH_ACP_PERMISSION_MODE"):
|
||||
valid_modes = {"auto_approve", "deny_all", "allowlist", "interactive"}
|
||||
if env_mode in valid_modes:
|
||||
permission_mode = env_mode
|
||||
else:
|
||||
_logger.warning(
|
||||
"Invalid RALPH_ACP_PERMISSION_MODE value '%s'. Valid modes: %s. Using default: %s",
|
||||
env_mode,
|
||||
", ".join(valid_modes),
|
||||
permission_mode,
|
||||
)
|
||||
|
||||
if env_timeout := os.environ.get("RALPH_ACP_TIMEOUT"):
|
||||
try:
|
||||
timeout = int(env_timeout)
|
||||
except ValueError:
|
||||
_logger.warning(
|
||||
"Invalid RALPH_ACP_TIMEOUT value '%s' - must be integer. Using default: %d",
|
||||
env_timeout,
|
||||
timeout,
|
||||
)
|
||||
|
||||
return cls(
|
||||
agent_command=agent_command,
|
||||
agent_args=agent_args,
|
||||
timeout=timeout,
|
||||
permission_mode=permission_mode,
|
||||
permission_allowlist=permission_allowlist,
|
||||
)
|
||||
@@ -0,0 +1,214 @@
|
||||
# ABOUTME: JSON-RPC 2.0 protocol handler for ACP (Agent Client Protocol)
|
||||
# ABOUTME: Handles message serialization, parsing, and protocol state
|
||||
|
||||
"""JSON-RPC 2.0 protocol handling for ACP."""
|
||||
|
||||
import json
|
||||
from enum import Enum, auto
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of JSON-RPC 2.0 messages."""
|
||||
|
||||
REQUEST = auto() # Has id and method
|
||||
NOTIFICATION = auto() # Has method but no id
|
||||
RESPONSE = auto() # Has id and result
|
||||
ERROR = auto() # Has id and error
|
||||
PARSE_ERROR = auto() # Failed to parse JSON
|
||||
INVALID = auto() # Invalid JSON-RPC message
|
||||
|
||||
|
||||
class ACPErrorCodes:
|
||||
"""Standard JSON-RPC 2.0 and ACP-specific error codes."""
|
||||
|
||||
# Standard JSON-RPC 2.0 error codes
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
|
||||
# ACP-specific error codes
|
||||
PERMISSION_DENIED = -32001
|
||||
FILE_NOT_FOUND = -32002
|
||||
FILE_ACCESS_ERROR = -32003
|
||||
TERMINAL_ERROR = -32004
|
||||
|
||||
|
||||
class ACPProtocol:
|
||||
"""JSON-RPC 2.0 protocol handler for ACP.
|
||||
|
||||
Handles serialization and deserialization of JSON-RPC messages
|
||||
for the Agent Client Protocol.
|
||||
|
||||
Attributes:
|
||||
_request_id: Auto-incrementing request ID counter.
|
||||
"""
|
||||
|
||||
JSONRPC_VERSION = "2.0"
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize protocol handler with request ID counter at 0."""
|
||||
self._request_id: int = 0
|
||||
|
||||
def create_request(self, method: str, params: dict[str, Any]) -> tuple[int, str]:
|
||||
"""Create a JSON-RPC 2.0 request message.
|
||||
|
||||
Args:
|
||||
method: The RPC method name (e.g., "session/prompt").
|
||||
params: The request parameters.
|
||||
|
||||
Returns:
|
||||
Tuple of (request_id, json_string) for tracking and sending.
|
||||
"""
|
||||
self._request_id += 1
|
||||
request_id = self._request_id
|
||||
|
||||
message = {
|
||||
"jsonrpc": self.JSONRPC_VERSION,
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
return request_id, json.dumps(message)
|
||||
|
||||
def create_notification(self, method: str, params: dict[str, Any]) -> str:
|
||||
"""Create a JSON-RPC 2.0 notification message (no id, no response expected).
|
||||
|
||||
Args:
|
||||
method: The RPC method name.
|
||||
params: The notification parameters.
|
||||
|
||||
Returns:
|
||||
JSON string of the notification.
|
||||
"""
|
||||
message = {
|
||||
"jsonrpc": self.JSONRPC_VERSION,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
return json.dumps(message)
|
||||
|
||||
def parse_message(self, data: str) -> dict[str, Any]:
|
||||
"""Parse an incoming JSON-RPC 2.0 message.
|
||||
|
||||
Determines the message type and validates structure.
|
||||
|
||||
Args:
|
||||
data: Raw JSON string to parse.
|
||||
|
||||
Returns:
|
||||
Dict with 'type' key indicating MessageType and parsed fields.
|
||||
On error, includes 'error' key with description.
|
||||
"""
|
||||
# Try to parse JSON
|
||||
try:
|
||||
message = json.loads(data)
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
"type": MessageType.PARSE_ERROR,
|
||||
"error": f"JSON parse error: {e}",
|
||||
}
|
||||
|
||||
# Validate jsonrpc version field
|
||||
if message.get("jsonrpc") != self.JSONRPC_VERSION:
|
||||
return {
|
||||
"type": MessageType.INVALID,
|
||||
"error": f"Invalid or missing jsonrpc field. Expected '2.0', got '{message.get('jsonrpc')}'",
|
||||
}
|
||||
|
||||
# Determine message type based on fields
|
||||
has_id = "id" in message
|
||||
has_method = "method" in message
|
||||
has_result = "result" in message
|
||||
has_error = "error" in message
|
||||
|
||||
if has_error and has_id:
|
||||
# Error response
|
||||
return {
|
||||
"type": MessageType.ERROR,
|
||||
"id": message["id"],
|
||||
"error": message["error"],
|
||||
}
|
||||
elif has_result and has_id:
|
||||
# Success response
|
||||
return {
|
||||
"type": MessageType.RESPONSE,
|
||||
"id": message["id"],
|
||||
"result": message["result"],
|
||||
}
|
||||
elif has_method and has_id:
|
||||
# Request (has id, expects response)
|
||||
return {
|
||||
"type": MessageType.REQUEST,
|
||||
"id": message["id"],
|
||||
"method": message["method"],
|
||||
"params": message.get("params", {}),
|
||||
}
|
||||
elif has_method and not has_id:
|
||||
# Notification (no id, no response expected)
|
||||
return {
|
||||
"type": MessageType.NOTIFICATION,
|
||||
"method": message["method"],
|
||||
"params": message.get("params", {}),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"type": MessageType.INVALID,
|
||||
"error": "Invalid JSON-RPC message structure",
|
||||
}
|
||||
|
||||
def create_response(self, request_id: int, result: Any) -> str:
|
||||
"""Create a JSON-RPC 2.0 success response.
|
||||
|
||||
Args:
|
||||
request_id: The ID from the original request.
|
||||
result: The result data to return.
|
||||
|
||||
Returns:
|
||||
JSON string of the response.
|
||||
"""
|
||||
message = {
|
||||
"jsonrpc": self.JSONRPC_VERSION,
|
||||
"id": request_id,
|
||||
"result": result,
|
||||
}
|
||||
|
||||
return json.dumps(message)
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
request_id: int,
|
||||
code: int,
|
||||
message: str,
|
||||
data: Any = None,
|
||||
) -> str:
|
||||
"""Create a JSON-RPC 2.0 error response.
|
||||
|
||||
Args:
|
||||
request_id: The ID from the original request.
|
||||
code: The error code (use ACPErrorCodes constants).
|
||||
message: Human-readable error message.
|
||||
data: Optional additional error data.
|
||||
|
||||
Returns:
|
||||
JSON string of the error response.
|
||||
"""
|
||||
error_obj: dict[str, Any] = {
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
error_obj["data"] = data
|
||||
|
||||
response = {
|
||||
"jsonrpc": self.JSONRPC_VERSION,
|
||||
"id": request_id,
|
||||
"error": error_obj,
|
||||
}
|
||||
|
||||
return json.dumps(response)
|
||||
@@ -0,0 +1,189 @@
|
||||
# ABOUTME: Abstract base class for tool adapters
|
||||
# ABOUTME: Defines the interface all tool adapters must implement
|
||||
|
||||
"""Base adapter interface for AI tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResponse:
|
||||
"""Response from a tool execution."""
|
||||
|
||||
success: bool
|
||||
output: str
|
||||
error: Optional[str] = None
|
||||
tokens_used: Optional[int] = None
|
||||
cost: Optional[float] = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
class ToolAdapter(ABC):
|
||||
"""Abstract base class for tool adapters."""
|
||||
|
||||
def __init__(self, name: str, config=None):
|
||||
self.name = name
|
||||
self.config = config or type('Config', (), {
|
||||
'enabled': True, 'timeout': 300, 'max_retries': 3,
|
||||
'args': [], 'env': {}
|
||||
})()
|
||||
self.completion_promise: Optional[str] = None
|
||||
self.available = self.check_availability()
|
||||
|
||||
@abstractmethod
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if the tool is available and properly configured."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute the tool with the given prompt."""
|
||||
pass
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Async execute the tool with the given prompt.
|
||||
|
||||
Default implementation runs sync execute in thread pool.
|
||||
Subclasses can override for native async support.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
# Create a function that can be called with no arguments for run_in_executor
|
||||
def execute_with_args():
|
||||
return self.execute(prompt, **kwargs)
|
||||
return await loop.run_in_executor(None, execute_with_args)
|
||||
|
||||
def execute_with_file(self, prompt_file: Path, **kwargs) -> ToolResponse:
|
||||
"""Execute the tool with a prompt file."""
|
||||
if not prompt_file.exists():
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"Prompt file {prompt_file} not found"
|
||||
)
|
||||
|
||||
with open(prompt_file, 'r') as f:
|
||||
prompt = f.read()
|
||||
|
||||
return self.execute(prompt, **kwargs)
|
||||
|
||||
async def aexecute_with_file(self, prompt_file: Path, **kwargs) -> ToolResponse:
|
||||
"""Async execute the tool with a prompt file."""
|
||||
if not prompt_file.exists():
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"Prompt file {prompt_file} not found"
|
||||
)
|
||||
|
||||
# Use asyncio.to_thread to avoid blocking the event loop with file I/O
|
||||
prompt = await asyncio.to_thread(prompt_file.read_text, encoding='utf-8')
|
||||
|
||||
return await self.aexecute(prompt, **kwargs)
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
|
||||
"""Estimate the cost of executing this prompt."""
|
||||
# Default implementation - subclasses can override
|
||||
return 0.0
|
||||
|
||||
def _enhance_prompt_with_instructions(self, prompt: str, completion_promise: Optional[str] = None) -> str:
|
||||
"""Enhance prompt with orchestration context and instructions.
|
||||
|
||||
Args:
|
||||
prompt: The original prompt
|
||||
completion_promise: Optional string that signals task completion.
|
||||
If not provided, uses self.completion_promise.
|
||||
|
||||
Returns:
|
||||
Enhanced prompt with orchestration instructions
|
||||
"""
|
||||
# Resolve completion promise: argument takes precedence, then instance state
|
||||
promise = completion_promise or self.completion_promise
|
||||
|
||||
# Check if instructions already exist in the prompt
|
||||
instruction_markers = [
|
||||
"ORCHESTRATION CONTEXT:",
|
||||
"IMPORTANT INSTRUCTIONS:",
|
||||
"Implement only ONE small, focused task"
|
||||
]
|
||||
|
||||
# If any marker exists, assume instructions are already present
|
||||
instructions_present = False
|
||||
for marker in instruction_markers:
|
||||
if marker in prompt:
|
||||
instructions_present = True
|
||||
break
|
||||
|
||||
enhanced_prompt = prompt
|
||||
|
||||
if not instructions_present:
|
||||
# Add orchestration context and instructions
|
||||
orchestration_instructions = """
|
||||
ORCHESTRATION CONTEXT:
|
||||
You are running within the Ralph Orchestrator loop. This system will call you repeatedly
|
||||
for multiple iterations until the overall task is complete. Each iteration is a separate
|
||||
execution where you should make incremental progress.
|
||||
|
||||
The final output must be well-tested, documented, and production ready.
|
||||
|
||||
IMPORTANT INSTRUCTIONS:
|
||||
1. Implement only ONE small, focused task from this prompt per iteration.
|
||||
- Each iteration is independent - focus on a single atomic change
|
||||
- The orchestrator will handle calling you again for the next task
|
||||
- Mark subtasks complete as you finish them
|
||||
- You must commit your changes after each iteration, for checkpointing.
|
||||
2. Use the .agent/workspace/ directory for any temporary files or workspaces if not already instructed in the prompt.
|
||||
3. Follow this workflow for implementing features:
|
||||
- Explore: Research and understand the codebase
|
||||
- Plan: Design your implementation approach
|
||||
- Implement: Use Test-Driven Development (TDD) - write tests first, then code
|
||||
- Commit: Commit your changes with clear messages
|
||||
4. When you complete a subtask, document it in the prompt file so the next iteration knows what's done.
|
||||
5. For maximum efficiency, whenever you need to perform multiple independent operations, invoke all relevant tools simultaneously rather than sequentially.
|
||||
6. If you create any temporary new files, scripts, or helper files for iteration, clean up these files by removing them at the end of the task.
|
||||
|
||||
## Agent Scratchpad
|
||||
Before starting your work, check if .agent/scratchpad.md exists in the current working directory.
|
||||
If it does, read it to understand what was accomplished in previous iterations and continue from there.
|
||||
|
||||
At the end of your iteration, update .agent/scratchpad.md with:
|
||||
- What you accomplished this iteration
|
||||
- What remains to be done
|
||||
- Any important context or decisions made
|
||||
- Current blockers or issues (if any)
|
||||
|
||||
Do NOT restart from scratch if the scratchpad shows previous progress. Continue where the previous iteration left off.
|
||||
|
||||
Create the .agent/ directory if it doesn't exist.
|
||||
|
||||
---
|
||||
ORIGINAL PROMPT:
|
||||
|
||||
"""
|
||||
enhanced_prompt = orchestration_instructions + prompt
|
||||
|
||||
# Inject completion promise if requested and not present
|
||||
if promise:
|
||||
# Check if promise is already in the text (simple check)
|
||||
# We check both the raw promise and the section header to be safe,
|
||||
# but mainly we want to avoid duplicating the specific instruction.
|
||||
if promise not in enhanced_prompt:
|
||||
promise_section = f"""
|
||||
|
||||
## Completion Promise
|
||||
When you have completed the task, output this exact line:
|
||||
{promise}
|
||||
"""
|
||||
enhanced_prompt += promise_section
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name} (available: {self.available})"
|
||||
@@ -0,0 +1,616 @@
|
||||
# ABOUTME: Claude SDK adapter implementation
|
||||
# ABOUTME: Provides integration with Anthropic's Claude via Python SDK
|
||||
# ABOUTME: Supports inheriting user's Claude Code settings (MCP servers, CLAUDE.md, etc.)
|
||||
|
||||
"""Claude SDK adapter for Ralph Orchestrator."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
from typing import Optional
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from ..error_formatter import ClaudeErrorFormatter
|
||||
from ..output.console import RalphConsole
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, query
|
||||
CLAUDE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
# Fallback to old package name for backwards compatibility
|
||||
try:
|
||||
from claude_code_sdk import ClaudeCodeOptions as ClaudeAgentOptions, query
|
||||
CLAUDE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
CLAUDE_SDK_AVAILABLE = False
|
||||
query = None
|
||||
ClaudeAgentOptions = None
|
||||
|
||||
|
||||
class ClaudeAdapter(ToolAdapter):
|
||||
"""Adapter for Claude using the Python SDK."""
|
||||
|
||||
# Default max buffer size: 10MB (handles large screenshots from chrome-devtools-mcp)
|
||||
DEFAULT_MAX_BUFFER_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# Default model: Claude Opus 4.5 (most intelligent model)
|
||||
DEFAULT_MODEL = "claude-opus-4-5-20251101"
|
||||
|
||||
# Model pricing (per million tokens)
|
||||
MODEL_PRICING = {
|
||||
"claude-opus-4-5-20251101": {"input": 5.0, "output": 25.0},
|
||||
"claude-sonnet-4-5-20250929": {"input": 3.0, "output": 15.0},
|
||||
"claude-haiku-4-5-20251001": {"input": 1.0, "output": 5.0},
|
||||
# Legacy models
|
||||
"claude-3-opus": {"input": 15.0, "output": 75.0},
|
||||
"claude-3-sonnet": {"input": 3.0, "output": 15.0},
|
||||
"claude-3-haiku": {"input": 0.25, "output": 1.25},
|
||||
}
|
||||
|
||||
def __init__(self, verbose: bool = False, max_buffer_size: int = None,
|
||||
inherit_user_settings: bool = True, cli_path: str = None,
|
||||
model: str = None):
|
||||
super().__init__("claude")
|
||||
self.sdk_available = CLAUDE_SDK_AVAILABLE
|
||||
self._system_prompt = None
|
||||
self._allowed_tools = None
|
||||
self._disallowed_tools = None
|
||||
self._enable_all_tools = False
|
||||
self._enable_web_search = True # Enable WebSearch by default
|
||||
self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
|
||||
self.verbose = verbose
|
||||
# Enable loading user's Claude Code settings (including MCP servers) by default
|
||||
self._inherit_user_settings = inherit_user_settings
|
||||
# Optional path to user's Claude Code CLI (uses bundled CLI if not specified)
|
||||
self._cli_path = cli_path
|
||||
self._model = model or self.DEFAULT_MODEL
|
||||
self._subprocess_pid: Optional[int] = None
|
||||
self._console = RalphConsole()
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if Claude SDK is available and properly configured."""
|
||||
# Claude Code SDK works without API key - it uses the local environment
|
||||
return CLAUDE_SDK_AVAILABLE
|
||||
|
||||
def configure(self,
|
||||
system_prompt: Optional[str] = None,
|
||||
allowed_tools: Optional[list] = None,
|
||||
disallowed_tools: Optional[list] = None,
|
||||
enable_all_tools: bool = False,
|
||||
enable_web_search: bool = True,
|
||||
inherit_user_settings: Optional[bool] = None,
|
||||
cli_path: Optional[str] = None,
|
||||
model: Optional[str] = None):
|
||||
"""Configure the Claude adapter with custom options.
|
||||
|
||||
Args:
|
||||
system_prompt: Custom system prompt for Claude
|
||||
allowed_tools: List of allowed tools for Claude to use (if None and enable_all_tools=True, all tools are enabled)
|
||||
disallowed_tools: List of disallowed tools
|
||||
enable_all_tools: If True and allowed_tools is None, enables all native Claude tools
|
||||
enable_web_search: If True, explicitly enables WebSearch tool (default: True)
|
||||
inherit_user_settings: If True, load user's Claude Code settings including MCP servers (default: True)
|
||||
cli_path: Path to user's Claude Code CLI (uses bundled CLI if not specified)
|
||||
model: Model to use (default: claude-opus-4-5-20251101)
|
||||
"""
|
||||
self._system_prompt = system_prompt
|
||||
self._allowed_tools = allowed_tools
|
||||
self._disallowed_tools = disallowed_tools
|
||||
self._enable_all_tools = enable_all_tools
|
||||
self._enable_web_search = enable_web_search
|
||||
|
||||
# Update user settings inheritance if specified
|
||||
if inherit_user_settings is not None:
|
||||
self._inherit_user_settings = inherit_user_settings
|
||||
|
||||
# Update CLI path if specified
|
||||
if cli_path is not None:
|
||||
self._cli_path = cli_path
|
||||
|
||||
# Update model if specified
|
||||
if model is not None:
|
||||
self._model = model
|
||||
|
||||
# If web search is enabled and we have an allowed tools list, add WebSearch to it
|
||||
if enable_web_search and allowed_tools is not None and 'WebSearch' not in allowed_tools:
|
||||
self._allowed_tools = allowed_tools + ['WebSearch']
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute Claude with the given prompt synchronously.
|
||||
|
||||
This is a blocking wrapper around the async implementation.
|
||||
"""
|
||||
try:
|
||||
# Create new event loop if needed
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(self.aexecute(prompt, **kwargs))
|
||||
else:
|
||||
# If loop is already running, schedule as task
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, self.aexecute(prompt, **kwargs))
|
||||
return future.result()
|
||||
except Exception as e:
|
||||
# Use error formatter for user-friendly error messages
|
||||
error_msg = ClaudeErrorFormatter.format_error_from_exception(
|
||||
iteration=kwargs.get('iteration', 0),
|
||||
exception=e
|
||||
)
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(error_msg)
|
||||
)
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute Claude with the given prompt asynchronously."""
|
||||
if not self.available:
|
||||
logger.debug("Claude SDK not available")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="Claude SDK is not available"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get configuration from kwargs or use defaults
|
||||
prompt_file = kwargs.get('prompt_file', 'PROMPT.md')
|
||||
|
||||
# Build options for Claude Code
|
||||
options_dict = {}
|
||||
|
||||
# Set system prompt with orchestration context
|
||||
system_prompt = kwargs.get('system_prompt', self._system_prompt)
|
||||
if not system_prompt:
|
||||
# Create a default system prompt with orchestration context
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
system_prompt = (
|
||||
f"You are helping complete a task. "
|
||||
f"The task is described in the file '{prompt_file}'. "
|
||||
f"Please edit this file directly to add your solution and progress updates."
|
||||
)
|
||||
# Use the enhanced prompt as the main prompt
|
||||
prompt = enhanced_prompt
|
||||
else:
|
||||
# If custom system prompt provided, still enhance the main prompt
|
||||
prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
options_dict['system_prompt'] = system_prompt
|
||||
|
||||
# Set tool restrictions if provided
|
||||
# If enable_all_tools is True and no allowed_tools specified, don't set any restrictions
|
||||
enable_all_tools = kwargs.get('enable_all_tools', self._enable_all_tools)
|
||||
enable_web_search = kwargs.get('enable_web_search', self._enable_web_search)
|
||||
allowed_tools = kwargs.get('allowed_tools', self._allowed_tools)
|
||||
disallowed_tools = kwargs.get('disallowed_tools', self._disallowed_tools)
|
||||
|
||||
# Add WebSearch to allowed tools if web search is enabled
|
||||
if enable_web_search and allowed_tools is not None and 'WebSearch' not in allowed_tools:
|
||||
allowed_tools = allowed_tools + ['WebSearch']
|
||||
|
||||
# Only set tool restrictions if we're not enabling all tools or if specific tools are provided
|
||||
if not enable_all_tools or allowed_tools:
|
||||
if allowed_tools:
|
||||
options_dict['allowed_tools'] = allowed_tools
|
||||
|
||||
if disallowed_tools:
|
||||
options_dict['disallowed_tools'] = disallowed_tools
|
||||
|
||||
# If enable_all_tools is True and no allowed_tools, Claude will have access to all native tools
|
||||
if enable_all_tools and not allowed_tools:
|
||||
if self.verbose:
|
||||
logger.debug("Enabling all native Claude tools (including WebSearch)")
|
||||
|
||||
# Set permission mode - default to bypassPermissions for smoother operation
|
||||
permission_mode = kwargs.get('permission_mode', 'bypassPermissions')
|
||||
options_dict['permission_mode'] = permission_mode
|
||||
if self.verbose:
|
||||
logger.debug(f"Permission mode: {permission_mode}")
|
||||
|
||||
# Set current working directory to ensure files are created in the right place
|
||||
import os
|
||||
cwd = kwargs.get('cwd', os.getcwd())
|
||||
options_dict['cwd'] = cwd
|
||||
if self.verbose:
|
||||
logger.debug(f"Working directory: {cwd}")
|
||||
|
||||
# Set max buffer size for handling large responses (e.g., screenshots)
|
||||
max_buffer_size = kwargs.get('max_buffer_size', self._max_buffer_size)
|
||||
options_dict['max_buffer_size'] = max_buffer_size
|
||||
if self.verbose:
|
||||
logger.debug(f"Max buffer size: {max_buffer_size} bytes")
|
||||
|
||||
# Configure setting sources to inherit user's Claude Code configuration
|
||||
# This enables MCP servers, CLAUDE.md files, and other user settings
|
||||
inherit_user_settings = kwargs.get('inherit_user_settings', self._inherit_user_settings)
|
||||
if inherit_user_settings:
|
||||
# Load user, project, and local settings (includes MCP servers)
|
||||
options_dict['setting_sources'] = ['user', 'project', 'local']
|
||||
if self.verbose:
|
||||
logger.debug("Inheriting user's Claude Code settings (MCP servers, CLAUDE.md, etc.)")
|
||||
|
||||
# Optional: use user's installed Claude Code CLI instead of bundled
|
||||
cli_path = kwargs.get('cli_path', self._cli_path)
|
||||
if cli_path:
|
||||
options_dict['cli_path'] = cli_path
|
||||
if self.verbose:
|
||||
logger.debug(f"Using custom Claude CLI: {cli_path}")
|
||||
|
||||
# Set model - defaults to Opus 4.5
|
||||
model = kwargs.get('model', self._model)
|
||||
options_dict['model'] = model
|
||||
if self.verbose:
|
||||
logger.debug(f"Using model: {model}")
|
||||
|
||||
# Create options
|
||||
options = ClaudeAgentOptions(**options_dict)
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose:
|
||||
logger.debug("Claude SDK Request:")
|
||||
logger.debug(f" Prompt length: {len(prompt)} characters")
|
||||
logger.debug(f" System prompt: {system_prompt}")
|
||||
if allowed_tools:
|
||||
logger.debug(f" Allowed tools: {allowed_tools}")
|
||||
if disallowed_tools:
|
||||
logger.debug(f" Disallowed tools: {disallowed_tools}")
|
||||
|
||||
# Collect all response chunks
|
||||
output_chunks = []
|
||||
tokens_used = 0
|
||||
chunk_count = 0
|
||||
|
||||
# Use one-shot query for simpler execution
|
||||
if self.verbose:
|
||||
logger.debug("Starting Claude SDK query...")
|
||||
self._console.print_header("CLAUDE PROCESSING")
|
||||
|
||||
async for message in query(prompt=prompt, options=options):
|
||||
chunk_count += 1
|
||||
msg_type = type(message).__name__
|
||||
|
||||
if self.verbose:
|
||||
print(f"\n[DEBUG: Received {msg_type}]", flush=True)
|
||||
logger.debug(f"Received message type: {msg_type}")
|
||||
|
||||
# Handle different message types
|
||||
if msg_type == 'AssistantMessage':
|
||||
# Extract content from AssistantMessage
|
||||
if hasattr(message, 'content') and message.content:
|
||||
for content_block in message.content:
|
||||
block_type = type(content_block).__name__
|
||||
|
||||
if hasattr(content_block, 'text'):
|
||||
# TextBlock
|
||||
text = content_block.text
|
||||
output_chunks.append(text)
|
||||
|
||||
if self.verbose and text:
|
||||
self._console.print_message(text)
|
||||
logger.debug(f"Received assistant text: {len(text)} characters")
|
||||
|
||||
elif block_type == 'ToolUseBlock':
|
||||
if self.verbose:
|
||||
tool_name = getattr(content_block, 'name', 'unknown')
|
||||
tool_id = getattr(content_block, 'id', 'unknown')
|
||||
tool_input = getattr(content_block, 'input', {})
|
||||
|
||||
self._console.print_separator()
|
||||
self._console.print_status(f"TOOL USE: {tool_name}", style="cyan bold")
|
||||
self._console.print_info(f"ID: {tool_id[:12]}...")
|
||||
|
||||
if tool_input:
|
||||
self._console.print_info("Input Parameters:")
|
||||
for key, value in tool_input.items():
|
||||
value_str = str(value)
|
||||
if len(value_str) > 100:
|
||||
value_str = value_str[:97] + "..."
|
||||
self._console.print_info(f" - {key}: {value_str}")
|
||||
|
||||
logger.debug(f"Tool use detected: {tool_name} (id: {tool_id[:8]}...)")
|
||||
if hasattr(content_block, 'input'):
|
||||
logger.debug(f" Tool input: {content_block.input}")
|
||||
|
||||
else:
|
||||
if self.verbose:
|
||||
logger.debug(f"Unknown content block type: {block_type}")
|
||||
|
||||
elif msg_type == 'ResultMessage':
|
||||
# ResultMessage contains final result and usage stats
|
||||
if hasattr(message, 'result'):
|
||||
# Don't append result - it's usually a duplicate of assistant message
|
||||
if self.verbose:
|
||||
logger.debug(f"Result message received: {len(str(message.result))} characters")
|
||||
|
||||
# Extract token usage from ResultMessage
|
||||
if hasattr(message, 'usage'):
|
||||
usage = message.usage
|
||||
if isinstance(usage, dict):
|
||||
tokens_used = usage.get('input_tokens', 0) + usage.get('output_tokens', 0)
|
||||
else:
|
||||
tokens_used = getattr(usage, 'total_tokens', 0)
|
||||
if self.verbose:
|
||||
logger.debug(f"Token usage: {tokens_used} tokens")
|
||||
|
||||
elif msg_type == 'SystemMessage':
|
||||
# SystemMessage is initialization data, skip it
|
||||
if self.verbose:
|
||||
logger.debug("System initialization message received")
|
||||
|
||||
elif msg_type == 'UserMessage':
|
||||
if self.verbose:
|
||||
logger.debug("User message (tool result) received")
|
||||
|
||||
if hasattr(message, 'content'):
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
for content_item in content:
|
||||
if hasattr(content_item, '__class__'):
|
||||
item_type = content_item.__class__.__name__
|
||||
if item_type == 'ToolResultBlock':
|
||||
tool_use_id = getattr(content_item, 'tool_use_id', 'unknown')
|
||||
result_content = getattr(content_item, 'content', None)
|
||||
is_error = getattr(content_item, 'is_error', False)
|
||||
|
||||
self._console.print_separator()
|
||||
self._console.print_status("TOOL RESULT", style="yellow bold")
|
||||
self._console.print_info(f"For Tool ID: {tool_use_id[:12]}...")
|
||||
|
||||
if is_error:
|
||||
self._console.print_error("Status: ERROR")
|
||||
else:
|
||||
self._console.print_success("Status: Success")
|
||||
|
||||
if result_content:
|
||||
self._console.print_info("Output:")
|
||||
if isinstance(result_content, str):
|
||||
if len(result_content) > 500:
|
||||
self._console.print_message(f" {result_content[:497]}...")
|
||||
else:
|
||||
self._console.print_message(f" {result_content}")
|
||||
elif isinstance(result_content, list):
|
||||
for item in result_content[:3]:
|
||||
self._console.print_info(f" - {item}")
|
||||
if len(result_content) > 3:
|
||||
self._console.print_info(f" ... and {len(result_content) - 3} more items")
|
||||
|
||||
elif msg_type == 'ToolResultMessage':
|
||||
if self.verbose:
|
||||
logger.debug("Tool result message received")
|
||||
|
||||
self._console.print_separator()
|
||||
self._console.print_status("TOOL RESULT MESSAGE", style="yellow bold")
|
||||
|
||||
if hasattr(message, 'tool_use_id'):
|
||||
self._console.print_info(f"Tool ID: {message.tool_use_id[:12]}...")
|
||||
|
||||
if hasattr(message, 'content'):
|
||||
content = message.content
|
||||
if content:
|
||||
self._console.print_info("Content:")
|
||||
if isinstance(content, str):
|
||||
if len(content) > 500:
|
||||
self._console.print_message(f" {content[:497]}...")
|
||||
else:
|
||||
self._console.print_message(f" {content}")
|
||||
elif isinstance(content, list):
|
||||
for item in content[:3]:
|
||||
self._console.print_info(f" - {item}")
|
||||
if len(content) > 3:
|
||||
self._console.print_info(f" ... and {len(content) - 3} more items")
|
||||
|
||||
if hasattr(message, 'is_error') and message.is_error:
|
||||
self._console.print_error("Error: True")
|
||||
|
||||
elif hasattr(message, 'text'):
|
||||
chunk_text = message.text
|
||||
output_chunks.append(chunk_text)
|
||||
if self.verbose:
|
||||
self._console.print_message(chunk_text)
|
||||
logger.debug(f"Received text chunk {chunk_count}: {len(chunk_text)} characters")
|
||||
|
||||
elif isinstance(message, str):
|
||||
output_chunks.append(message)
|
||||
if self.verbose:
|
||||
self._console.print_message(message)
|
||||
logger.debug(f"Received string chunk {chunk_count}: {len(message)} characters")
|
||||
|
||||
else:
|
||||
if self.verbose:
|
||||
logger.debug(f"Unknown message type {msg_type}: {message}")
|
||||
|
||||
# Combine output
|
||||
output = ''.join(output_chunks)
|
||||
|
||||
# End streaming section if verbose
|
||||
if self.verbose:
|
||||
self._console.print_separator()
|
||||
|
||||
# Always log the output we're about to return
|
||||
logger.debug(f"Claude adapter returning {len(output)} characters of output")
|
||||
if output:
|
||||
logger.debug(f"Output preview: {output[:200]}...")
|
||||
|
||||
# Calculate cost if we have token count (using model-specific pricing)
|
||||
cost = self._calculate_cost(tokens_used, model) if tokens_used > 0 else None
|
||||
|
||||
# Log response details if verbose
|
||||
if self.verbose:
|
||||
logger.debug("Claude SDK Response:")
|
||||
logger.debug(f" Output length: {len(output)} characters")
|
||||
logger.debug(f" Chunks received: {chunk_count}")
|
||||
if tokens_used > 0:
|
||||
logger.debug(f" Tokens used: {tokens_used}")
|
||||
if cost:
|
||||
logger.debug(f" Estimated cost: ${cost:.4f}")
|
||||
logger.debug(f"Response preview: {output[:500]}..." if len(output) > 500 else f"Response: {output}")
|
||||
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=output,
|
||||
tokens_used=tokens_used if tokens_used > 0 else None,
|
||||
cost=cost,
|
||||
metadata={"model": model}
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError as e:
|
||||
# Use error formatter for user-friendly timeout message
|
||||
error_msg = ClaudeErrorFormatter.format_error_from_exception(
|
||||
iteration=kwargs.get('iteration', 0),
|
||||
exception=e
|
||||
)
|
||||
logger.warning(f"Claude SDK request timed out: {error_msg.message}")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(error_msg)
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if this is a user-initiated cancellation (SIGINT = exit code -2)
|
||||
error_str = str(e)
|
||||
if "exit code -2" in error_str or "exit code: -2" in error_str:
|
||||
logger.debug("Claude execution cancelled by user")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="Execution cancelled by user"
|
||||
)
|
||||
|
||||
# Use error formatter for user-friendly error messages
|
||||
error_msg = ClaudeErrorFormatter.format_error_from_exception(
|
||||
iteration=kwargs.get('iteration', 0),
|
||||
exception=e
|
||||
)
|
||||
|
||||
# Log additional debugging information for "Command failed" errors
|
||||
if "Command failed with exit code 1" in error_str:
|
||||
logger.error(f"Claude CLI command failed - this may indicate:")
|
||||
logger.error(" 1. Claude CLI not properly installed or configured")
|
||||
logger.error(" 2. Missing API key or authentication issues")
|
||||
logger.error(" 3. Network connectivity problems")
|
||||
logger.error(" 4. Insufficient permissions")
|
||||
logger.error(f" Full error: {error_str}")
|
||||
|
||||
# Try to provide more specific guidance
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run(['claude', '--version'], capture_output=True, text=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
logger.error(f" Claude CLI version: {result.stdout.strip()}")
|
||||
else:
|
||||
logger.error(f" Claude CLI check failed: {result.stderr}")
|
||||
except Exception as check_e:
|
||||
logger.error(f" Could not check Claude CLI: {check_e}")
|
||||
|
||||
logger.error(f"Claude SDK error: {error_msg.message}", exc_info=True)
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(error_msg)
|
||||
)
|
||||
|
||||
def _calculate_cost(self, tokens: Optional[int], model: str = None) -> Optional[float]:
|
||||
"""Calculate estimated cost based on tokens and model.
|
||||
|
||||
Args:
|
||||
tokens: Total tokens used (input + output combined)
|
||||
model: Model ID used for the request
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD, or None if tokens is None/0
|
||||
"""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
model = model or self._model
|
||||
|
||||
# Get model pricing or use default
|
||||
if model in self.MODEL_PRICING:
|
||||
pricing = self.MODEL_PRICING[model]
|
||||
else:
|
||||
# Fallback to Opus 4.5 pricing for unknown models
|
||||
pricing = self.MODEL_PRICING[self.DEFAULT_MODEL]
|
||||
|
||||
# Estimate input/output split (typically ~30% input, ~70% output for agent work)
|
||||
# This is an approximation since we don't always get separate counts
|
||||
input_tokens = int(tokens * 0.3)
|
||||
output_tokens = int(tokens * 0.7)
|
||||
|
||||
input_cost = (input_tokens / 1_000_000) * pricing["input"]
|
||||
output_cost = (output_tokens / 1_000_000) * pricing["output"]
|
||||
|
||||
return input_cost + output_cost
|
||||
|
||||
def estimate_cost(self, prompt: str, model: str = None) -> float:
|
||||
"""Estimate cost for the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt text to estimate cost for
|
||||
model: Model ID to use for pricing (defaults to configured model)
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD
|
||||
"""
|
||||
# Rough estimation: 1 token ≈ 4 characters
|
||||
estimated_tokens = len(prompt) / 4
|
||||
return self._calculate_cost(int(estimated_tokens), model) or 0.0
|
||||
|
||||
def kill_subprocess_sync(self) -> None:
|
||||
"""
|
||||
Kill subprocess synchronously (safe to call from signal handler).
|
||||
|
||||
This method uses os.kill() which is signal-safe and can be called
|
||||
from the signal handler context. It immediately terminates the subprocess,
|
||||
which unblocks any I/O operations waiting on it.
|
||||
"""
|
||||
if self._subprocess_pid:
|
||||
try:
|
||||
# Try SIGTERM first for graceful shutdown
|
||||
os.kill(self._subprocess_pid, signal.SIGTERM)
|
||||
# Small delay to allow graceful shutdown - keep minimal for signal handler
|
||||
import time
|
||||
try:
|
||||
time.sleep(0.01)
|
||||
except Exception:
|
||||
pass # Ignore errors during sleep in signal handler
|
||||
# Then SIGKILL if still alive (more forceful)
|
||||
try:
|
||||
os.kill(self._subprocess_pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass # Already dead from SIGTERM
|
||||
except ProcessLookupError:
|
||||
pass # Already dead
|
||||
except (PermissionError, OSError):
|
||||
pass # Best effort - process might be owned by another user
|
||||
finally:
|
||||
self._subprocess_pid = None
|
||||
|
||||
async def _cleanup_transport(self) -> None:
|
||||
"""Clean up transport and kill subprocess with timeout protection."""
|
||||
# Kill subprocess first (if not already killed by signal handler)
|
||||
if self._subprocess_pid:
|
||||
try:
|
||||
# Try SIGTERM first
|
||||
os.kill(self._subprocess_pid, signal.SIGTERM)
|
||||
# Wait with timeout to avoid hanging
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.sleep(0.01), timeout=0.05)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Continue even if sleep times out
|
||||
# Force kill if still alive
|
||||
try:
|
||||
os.kill(self._subprocess_pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass # Already dead
|
||||
except ProcessLookupError:
|
||||
pass # Already terminated
|
||||
except (PermissionError, OSError):
|
||||
pass # Best effort cleanup
|
||||
finally:
|
||||
self._subprocess_pid = None
|
||||
@@ -0,0 +1,120 @@
|
||||
# ABOUTME: Gemini CLI adapter implementation
|
||||
# ABOUTME: Provides fallback integration with Google's Gemini AI
|
||||
|
||||
"""Gemini CLI adapter for Ralph Orchestrator."""
|
||||
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
|
||||
|
||||
class GeminiAdapter(ToolAdapter):
|
||||
"""Adapter for Gemini CLI tool."""
|
||||
|
||||
def __init__(self):
|
||||
self.command = "gemini"
|
||||
super().__init__("gemini")
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if Gemini CLI is available."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self.command, "--version"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
text=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute Gemini with the given prompt."""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="Gemini CLI is not available"
|
||||
)
|
||||
|
||||
try:
|
||||
# Enhance prompt with orchestration instructions
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
|
||||
# Build command
|
||||
cmd = [self.command]
|
||||
|
||||
# Add model if specified
|
||||
if kwargs.get("model"):
|
||||
cmd.extend(["--model", kwargs["model"]])
|
||||
|
||||
# Add the enhanced prompt
|
||||
cmd.extend(["-p", enhanced_prompt])
|
||||
|
||||
# Add output format if specified
|
||||
if kwargs.get("output_format"):
|
||||
cmd.extend(["--output", kwargs["output_format"]])
|
||||
|
||||
# Execute command
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=kwargs.get("timeout", 300) # 5 minute default
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
# Extract token count if available
|
||||
tokens = self._extract_token_count(result.stderr)
|
||||
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=result.stdout,
|
||||
tokens_used=tokens,
|
||||
cost=self._calculate_cost(tokens),
|
||||
metadata={"model": kwargs.get("model", "gemini-2.5-pro")}
|
||||
)
|
||||
else:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=result.stdout,
|
||||
error=result.stderr or "Gemini command failed"
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="Gemini command timed out"
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def _extract_token_count(self, stderr: str) -> Optional[int]:
|
||||
"""Extract token count from Gemini output."""
|
||||
# Implementation depends on Gemini's output format
|
||||
return None
|
||||
|
||||
def _calculate_cost(self, tokens: Optional[int]) -> Optional[float]:
|
||||
"""Calculate estimated cost based on tokens."""
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
# Gemini has free tier up to 1M tokens
|
||||
if tokens < 1_000_000:
|
||||
return 0.0
|
||||
|
||||
# After free tier: $0.001 per 1K tokens (approximate)
|
||||
excess_tokens = tokens - 1_000_000
|
||||
cost_per_1k = 0.001
|
||||
return (excess_tokens / 1000) * cost_per_1k
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
|
||||
"""Estimate cost for the prompt."""
|
||||
# Rough estimation: 1 token ≈ 4 characters
|
||||
estimated_tokens = len(prompt) / 4
|
||||
return self._calculate_cost(estimated_tokens) or 0.0
|
||||
@@ -0,0 +1,562 @@
|
||||
# ABOUTME: Kiro CLI adapter implementation
|
||||
# ABOUTME: Provides integration with kiro-cli chat command for AI interactions
|
||||
|
||||
"""Kiro CLI adapter for Ralph Orchestrator."""
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
try:
|
||||
import fcntl # Unix-only
|
||||
except ModuleNotFoundError:
|
||||
fcntl = None
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from ..logging_config import RalphLogger
|
||||
|
||||
# Get logger for this module
|
||||
logger = RalphLogger.get_logger(RalphLogger.ADAPTER_KIRO)
|
||||
|
||||
|
||||
class KiroAdapter(ToolAdapter):
|
||||
"""Adapter for Kiro CLI tool."""
|
||||
|
||||
def __init__(self):
|
||||
# Get configuration from environment variables
|
||||
self.command = os.getenv("RALPH_KIRO_COMMAND", "kiro-cli")
|
||||
self.default_timeout = int(os.getenv("RALPH_KIRO_TIMEOUT", "600"))
|
||||
self.default_prompt_file = os.getenv("RALPH_KIRO_PROMPT_FILE", "PROMPT.md")
|
||||
self.trust_all_tools = os.getenv("RALPH_KIRO_TRUST_TOOLS", "true").lower() == "true"
|
||||
self.no_interactive = os.getenv("RALPH_KIRO_NO_INTERACTIVE", "true").lower() == "true"
|
||||
|
||||
# Initialize signal handler attributes before calling super()
|
||||
self._original_sigint = None
|
||||
self._original_sigterm = None
|
||||
|
||||
super().__init__("kiro")
|
||||
self.current_process = None
|
||||
self.shutdown_requested = False
|
||||
|
||||
# Thread synchronization
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Register signal handlers to propagate shutdown to subprocess
|
||||
self._register_signal_handlers()
|
||||
|
||||
logger.info(f"Kiro adapter initialized - Command: {self.command}, "
|
||||
f"Default timeout: {self.default_timeout}s, "
|
||||
f"Trust tools: {self.trust_all_tools}")
|
||||
|
||||
def _register_signal_handlers(self):
|
||||
"""Register signal handlers and store originals."""
|
||||
self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
|
||||
self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
logger.debug("Signal handlers registered for SIGINT and SIGTERM")
|
||||
|
||||
def _restore_signal_handlers(self):
|
||||
"""Restore original signal handlers."""
|
||||
if hasattr(self, "_original_sigint") and self._original_sigint is not None:
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if hasattr(self, "_original_sigterm") and self._original_sigterm is not None:
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
"""Handle shutdown signals and terminate running subprocess."""
|
||||
with self._lock:
|
||||
self.shutdown_requested = True
|
||||
process = self.current_process
|
||||
|
||||
if process and process.poll() is None:
|
||||
logger.warning(f"Received signal {signum}, terminating Kiro process...")
|
||||
try:
|
||||
process.terminate()
|
||||
process.wait(timeout=3)
|
||||
logger.debug("Process terminated gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Force killing Kiro process...")
|
||||
process.kill()
|
||||
try:
|
||||
process.wait(timeout=2)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Process may still be running after force kill")
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if kiro-cli (or q) is available."""
|
||||
# First check for configured command (defaults to kiro-cli)
|
||||
if self._check_command(self.command):
|
||||
logger.debug(f"Kiro command '{self.command}' available")
|
||||
return True
|
||||
|
||||
# If kiro-cli not found, try fallback to 'q'
|
||||
if self.command == "kiro-cli":
|
||||
if self._check_command("q"):
|
||||
logger.info("kiro-cli not found, falling back to 'q'")
|
||||
self.command = "q"
|
||||
return True
|
||||
|
||||
logger.warning(f"Kiro command '{self.command}' (and fallback) not found")
|
||||
return False
|
||||
|
||||
def _check_command(self, cmd: str) -> bool:
|
||||
"""Check if a specific command exists."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["which", cmd],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
text=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute kiro-cli chat with the given prompt."""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="Kiro CLI is not available"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get verbose flag from kwargs
|
||||
verbose = kwargs.get("verbose", True)
|
||||
|
||||
# Get the prompt file path from kwargs if available
|
||||
prompt_file = kwargs.get("prompt_file", self.default_prompt_file)
|
||||
|
||||
logger.info(f"Executing Kiro chat - Prompt file: {prompt_file}, Verbose: {verbose}")
|
||||
|
||||
# Enhance prompt with orchestration instructions
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
|
||||
# Construct a more effective prompt
|
||||
# Tell it explicitly to edit the prompt file
|
||||
effective_prompt = (
|
||||
f"Please read and complete the task described in the file '{prompt_file}'. "
|
||||
f"The current content is:\n\n{enhanced_prompt}\n\n"
|
||||
f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
|
||||
)
|
||||
|
||||
# Build command
|
||||
cmd = [self.command, "chat"]
|
||||
|
||||
if self.no_interactive:
|
||||
cmd.append("--no-interactive")
|
||||
|
||||
if self.trust_all_tools:
|
||||
cmd.append("--trust-all-tools")
|
||||
|
||||
cmd.append(effective_prompt)
|
||||
|
||||
logger.debug(f"Command constructed: {' '.join(cmd)}")
|
||||
|
||||
timeout = kwargs.get("timeout", self.default_timeout)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"Starting {self.command} chat command...")
|
||||
logger.info(f"Command: {' '.join(cmd)}")
|
||||
logger.info(f"Working directory: {os.getcwd()}")
|
||||
logger.info(f"Timeout: {timeout} seconds")
|
||||
print("-" * 60, file=sys.stderr)
|
||||
|
||||
# Use Popen for real-time output streaming
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=os.getcwd(),
|
||||
bufsize=0, # Unbuffered to prevent deadlock
|
||||
universal_newlines=True
|
||||
)
|
||||
|
||||
# Set process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = process
|
||||
|
||||
# Make pipes non-blocking to prevent deadlock
|
||||
self._make_non_blocking(process.stdout)
|
||||
self._make_non_blocking(process.stderr)
|
||||
|
||||
# Collect output while streaming
|
||||
stdout_lines = []
|
||||
stderr_lines = []
|
||||
|
||||
start_time = time.time()
|
||||
last_output_time = start_time
|
||||
|
||||
while True:
|
||||
# Check for shutdown signal first with lock
|
||||
with self._lock:
|
||||
shutdown = self.shutdown_requested
|
||||
|
||||
if shutdown:
|
||||
if verbose:
|
||||
print("Shutdown requested, terminating process...", file=sys.stderr)
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait(timeout=2)
|
||||
|
||||
# Clean up process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="".join(stdout_lines),
|
||||
error="Process terminated due to shutdown signal"
|
||||
)
|
||||
|
||||
# Check for timeout
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Log progress every 30 seconds
|
||||
if int(elapsed_time) % 30 == 0 and int(elapsed_time) > 0:
|
||||
logger.debug(f"Kiro still running... elapsed: {elapsed_time:.1f}s / {timeout}s")
|
||||
|
||||
# Check if the process seems stuck (no output for a while)
|
||||
time_since_output = time.time() - last_output_time
|
||||
if time_since_output > 60:
|
||||
logger.info(f"No output received for {time_since_output:.1f}s, Kiro might be stuck")
|
||||
|
||||
if verbose:
|
||||
print(f"Kiro still running... elapsed: {elapsed_time:.1f}s / {timeout}s", file=sys.stderr)
|
||||
|
||||
if elapsed_time > timeout:
|
||||
logger.warning(f"Command timed out after {elapsed_time:.2f} seconds")
|
||||
if verbose:
|
||||
print(f"Command timed out after {elapsed_time:.2f} seconds", file=sys.stderr)
|
||||
|
||||
# Try to terminate gracefully first
|
||||
process.terminate()
|
||||
try:
|
||||
# Wait a bit for graceful termination
|
||||
process.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Graceful termination failed, force killing process")
|
||||
if verbose:
|
||||
print("Graceful termination failed, force killing process", file=sys.stderr)
|
||||
process.kill()
|
||||
# Wait for force kill to complete
|
||||
try:
|
||||
process.wait(timeout=2)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Process may still be running after kill")
|
||||
if verbose:
|
||||
print("Warning: Process may still be running after kill", file=sys.stderr)
|
||||
|
||||
# Try to capture any remaining output after termination
|
||||
try:
|
||||
remaining_stdout = process.stdout.read()
|
||||
remaining_stderr = process.stderr.read()
|
||||
if remaining_stdout:
|
||||
stdout_lines.append(remaining_stdout)
|
||||
if remaining_stderr:
|
||||
stderr_lines.append(remaining_stderr)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read remaining output after timeout: {e}")
|
||||
if verbose:
|
||||
print(f"Warning: Could not read remaining output after timeout: {e}", file=sys.stderr)
|
||||
|
||||
# Clean up process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="".join(stdout_lines),
|
||||
error=f"Kiro command timed out after {elapsed_time:.2f} seconds"
|
||||
)
|
||||
|
||||
# Check if process is still running
|
||||
if process.poll() is not None:
|
||||
# Process finished, read remaining output
|
||||
remaining_stdout = process.stdout.read()
|
||||
remaining_stderr = process.stderr.read()
|
||||
|
||||
if remaining_stdout:
|
||||
stdout_lines.append(remaining_stdout)
|
||||
if verbose:
|
||||
print(f"{remaining_stdout}", end="", file=sys.stderr)
|
||||
|
||||
if remaining_stderr:
|
||||
stderr_lines.append(remaining_stderr)
|
||||
if verbose:
|
||||
print(f"{remaining_stderr}", end="", file=sys.stderr)
|
||||
|
||||
break
|
||||
|
||||
# Read available data without blocking
|
||||
try:
|
||||
# Read stdout
|
||||
stdout_data = self._read_available(process.stdout)
|
||||
if stdout_data:
|
||||
stdout_lines.append(stdout_data)
|
||||
last_output_time = time.time()
|
||||
if verbose:
|
||||
print(stdout_data, end="", file=sys.stderr)
|
||||
|
||||
# Read stderr
|
||||
stderr_data = self._read_available(process.stderr)
|
||||
if stderr_data:
|
||||
stderr_lines.append(stderr_data)
|
||||
last_output_time = time.time()
|
||||
if verbose:
|
||||
print(stderr_data, end="", file=sys.stderr)
|
||||
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.01)
|
||||
|
||||
except BlockingIOError:
|
||||
# Non-blocking read returns BlockingIOError when no data available
|
||||
pass
|
||||
|
||||
# Get final return code
|
||||
returncode = process.poll()
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"Process completed - Return code: {returncode}, Execution time: {execution_time:.2f}s")
|
||||
|
||||
if verbose:
|
||||
print("-" * 60, file=sys.stderr)
|
||||
print(f"Process completed with return code: {returncode}", file=sys.stderr)
|
||||
print(f"Total execution time: {execution_time:.2f} seconds", file=sys.stderr)
|
||||
|
||||
# Clean up process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
# Combine output
|
||||
full_stdout = "".join(stdout_lines)
|
||||
full_stderr = "".join(stderr_lines)
|
||||
|
||||
if returncode == 0:
|
||||
logger.debug(f"Kiro chat succeeded - Output length: {len(full_stdout)} chars")
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=full_stdout,
|
||||
metadata={
|
||||
"tool": "kiro chat",
|
||||
"execution_time": execution_time,
|
||||
"verbose": verbose,
|
||||
"return_code": returncode
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Kiro chat failed - Return code: {returncode}, Error: {full_stderr[:200]}")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=full_stdout,
|
||||
error=full_stderr or f"Kiro command failed with code {returncode}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during Kiro chat execution: {str(e)}")
|
||||
if verbose:
|
||||
print(f"Exception occurred: {str(e)}", file=sys.stderr)
|
||||
# Clean up process reference on exception
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def _make_non_blocking(self, pipe):
|
||||
"""Make a pipe non-blocking to prevent deadlock."""
|
||||
if not pipe or fcntl is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fd = pipe.fileno()
|
||||
if isinstance(fd, int) and fd >= 0:
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
except (AttributeError, ValueError, OSError):
|
||||
# In tests or when pipe doesn't support fileno()
|
||||
pass
|
||||
|
||||
|
||||
def _read_available(self, pipe):
|
||||
"""Read available data from a non-blocking pipe."""
|
||||
if not pipe:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Try to read up to 4KB at a time
|
||||
data = pipe.read(4096)
|
||||
# Ensure we always return a string, not None
|
||||
if data is None:
|
||||
return ""
|
||||
return data if data else ""
|
||||
except (IOError, OSError):
|
||||
# Would block or no data available
|
||||
return ""
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Native async execution using asyncio subprocess."""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="Kiro CLI is not available"
|
||||
)
|
||||
|
||||
try:
|
||||
verbose = kwargs.get("verbose", True)
|
||||
prompt_file = kwargs.get("prompt_file", self.default_prompt_file)
|
||||
timeout = kwargs.get("timeout", self.default_timeout)
|
||||
|
||||
logger.info(f"Executing Kiro chat async - Prompt file: {prompt_file}, Timeout: {timeout}s")
|
||||
|
||||
# Enhance prompt with orchestration instructions
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
|
||||
# Construct effective prompt
|
||||
effective_prompt = (
|
||||
f"Please read and complete the task described in the file '{prompt_file}'. "
|
||||
f"The current content is:\n\n{enhanced_prompt}\n\n"
|
||||
f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
|
||||
)
|
||||
|
||||
# Build command
|
||||
cmd = [
|
||||
self.command,
|
||||
"chat",
|
||||
"--no-interactive",
|
||||
"--trust-all-tools",
|
||||
effective_prompt
|
||||
]
|
||||
|
||||
logger.debug(f"Starting async Kiro chat command: {' '.join(cmd)}")
|
||||
if verbose:
|
||||
print(f"Starting {self.command} chat command (async)...", file=sys.stderr)
|
||||
print(f"Command: {' '.join(cmd)}", file=sys.stderr)
|
||||
print("-" * 60, file=sys.stderr)
|
||||
|
||||
# Create async subprocess
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=os.getcwd()
|
||||
)
|
||||
|
||||
# Set process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = process
|
||||
|
||||
try:
|
||||
# Wait for completion with timeout
|
||||
stdout_data, stderr_data = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Decode output
|
||||
stdout = stdout_data.decode("utf-8") if stdout_data else ""
|
||||
stderr = stderr_data.decode("utf-8") if stderr_data else ""
|
||||
|
||||
if verbose and stdout:
|
||||
print(stdout, file=sys.stderr)
|
||||
if verbose and stderr:
|
||||
print(stderr, file=sys.stderr)
|
||||
|
||||
# Check return code
|
||||
if process.returncode == 0:
|
||||
logger.debug(f"Async Kiro chat succeeded - Output length: {len(stdout)} chars")
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=stdout,
|
||||
metadata={
|
||||
"tool": "kiro chat",
|
||||
"verbose": verbose,
|
||||
"async": True,
|
||||
"return_code": process.returncode
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Async Kiro chat failed - Return code: {process.returncode}")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=stdout,
|
||||
error=stderr or f"Kiro chat failed with code {process.returncode}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout occurred
|
||||
logger.warning(f"Async Kiro chat timed out after {timeout} seconds")
|
||||
if verbose:
|
||||
print(f"Async Kiro chat timed out after {timeout} seconds", file=sys.stderr)
|
||||
|
||||
# Try to terminate process
|
||||
try:
|
||||
process.terminate()
|
||||
await asyncio.wait_for(process.wait(), timeout=3)
|
||||
except (asyncio.TimeoutError, ProcessLookupError):
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"Kiro command timed out after {timeout} seconds"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up process reference
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Async execution error: {str(e)}")
|
||||
if kwargs.get("verbose"):
|
||||
print(f"Async execution error: {str(e)}", file=sys.stderr)
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
|
||||
"""Kiro chat cost estimation (if applicable)."""
|
||||
# Return 0 for now
|
||||
return 0.0
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on deletion."""
|
||||
# Restore original signal handlers
|
||||
self._restore_signal_handlers()
|
||||
|
||||
# Ensure any running process is terminated
|
||||
if hasattr(self, "_lock"):
|
||||
with self._lock:
|
||||
process = self.current_process if hasattr(self, "current_process") else None
|
||||
else:
|
||||
process = getattr(self, "current_process", None)
|
||||
|
||||
if process:
|
||||
try:
|
||||
if hasattr(process, "poll"):
|
||||
# Sync process
|
||||
if process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait(timeout=1)
|
||||
else:
|
||||
# Async process - can't do much in __del__
|
||||
pass
|
||||
except Exception as e:
|
||||
# Best-effort cleanup during interpreter shutdown
|
||||
logger.debug(f"Cleanup warning in __del__: {type(e).__name__}: {e}")
|
||||
@@ -0,0 +1,579 @@
|
||||
# ABOUTME: Q Chat adapter implementation for q CLI tool (DEPRECATED)
|
||||
# ABOUTME: Provides integration with q chat command for AI interactions
|
||||
# ABOUTME: NOTE: Consider using KiroAdapter instead - Q CLI rebranded to Kiro CLI
|
||||
|
||||
"""Q Chat adapter for Ralph Orchestrator.
|
||||
|
||||
.. deprecated::
|
||||
The QChatAdapter is deprecated in favor of KiroAdapter.
|
||||
Amazon Q Developer CLI has been rebranded to Kiro CLI (v1.20+).
|
||||
Use `-a kiro` instead of `-a qchat` or `-a q` for new projects.
|
||||
|
||||
Migration guide:
|
||||
- Config paths changed: ~/.aws/amazonq/ -> ~/.kiro/
|
||||
- MCP servers: ~/.kiro/settings/mcp.json
|
||||
- Project files: .kiro/ folder
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
import warnings
|
||||
try:
|
||||
import fcntl # Unix-only
|
||||
except ModuleNotFoundError:
|
||||
fcntl = None
|
||||
from .base import ToolAdapter, ToolResponse
|
||||
from ..logging_config import RalphLogger
|
||||
|
||||
# Get logger for this module
|
||||
logger = RalphLogger.get_logger(RalphLogger.ADAPTER_QCHAT)
|
||||
|
||||
|
||||
class QChatAdapter(ToolAdapter):
|
||||
"""Adapter for Q Chat CLI tool.
|
||||
|
||||
.. deprecated::
|
||||
Use :class:`KiroAdapter` instead. The Amazon Q Developer CLI has been
|
||||
rebranded to Kiro CLI. This adapter is maintained for backwards
|
||||
compatibility only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Emit deprecation warning
|
||||
warnings.warn(
|
||||
"QChatAdapter is deprecated. Use KiroAdapter instead. "
|
||||
"Amazon Q Developer CLI has been rebranded to Kiro CLI (v1.20+). "
|
||||
"Run with '-a kiro' instead of '-a qchat'.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Get configuration from environment variables
|
||||
self.command = os.getenv("RALPH_QCHAT_COMMAND", "q")
|
||||
self.default_timeout = int(os.getenv("RALPH_QCHAT_TIMEOUT", "600"))
|
||||
self.default_prompt_file = os.getenv("RALPH_QCHAT_PROMPT_FILE", "PROMPT.md")
|
||||
self.trust_all_tools = os.getenv("RALPH_QCHAT_TRUST_TOOLS", "true").lower() == "true"
|
||||
self.no_interactive = os.getenv("RALPH_QCHAT_NO_INTERACTIVE", "true").lower() == "true"
|
||||
|
||||
# Initialize signal handler attributes before calling super()
|
||||
self._original_sigint = None
|
||||
self._original_sigterm = None
|
||||
|
||||
super().__init__("qchat")
|
||||
self.current_process = None
|
||||
self.shutdown_requested = False
|
||||
|
||||
# Thread synchronization
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Register signal handlers to propagate shutdown to subprocess
|
||||
self._register_signal_handlers()
|
||||
|
||||
logger.info(f"Q Chat adapter initialized - Command: {self.command}, "
|
||||
f"Default timeout: {self.default_timeout}s, "
|
||||
f"Trust tools: {self.trust_all_tools}")
|
||||
|
||||
def _register_signal_handlers(self):
|
||||
"""Register signal handlers and store originals."""
|
||||
self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
|
||||
self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
|
||||
logger.debug("Signal handlers registered for SIGINT and SIGTERM")
|
||||
|
||||
def _restore_signal_handlers(self):
|
||||
"""Restore original signal handlers."""
|
||||
if hasattr(self, '_original_sigint') and self._original_sigint is not None:
|
||||
signal.signal(signal.SIGINT, self._original_sigint)
|
||||
if hasattr(self, '_original_sigterm') and self._original_sigterm is not None:
|
||||
signal.signal(signal.SIGTERM, self._original_sigterm)
|
||||
|
||||
def _signal_handler(self, signum, frame):
|
||||
"""Handle shutdown signals and terminate running subprocess."""
|
||||
with self._lock:
|
||||
self.shutdown_requested = True
|
||||
process = self.current_process
|
||||
|
||||
if process and process.poll() is None:
|
||||
logger.warning(f"Received signal {signum}, terminating q chat process...")
|
||||
try:
|
||||
process.terminate()
|
||||
process.wait(timeout=3)
|
||||
logger.debug("Process terminated gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Force killing q chat process...")
|
||||
process.kill()
|
||||
try:
|
||||
process.wait(timeout=2)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Process may still be running after force kill")
|
||||
|
||||
def check_availability(self) -> bool:
|
||||
"""Check if q CLI is available."""
|
||||
try:
|
||||
# Try to check if q command exists
|
||||
result = subprocess.run(
|
||||
["which", self.command],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
text=True
|
||||
)
|
||||
available = result.returncode == 0
|
||||
logger.debug(f"Q command '{self.command}' availability check: {available}")
|
||||
return available
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError) as e:
|
||||
logger.warning(f"Q command availability check failed: {e}")
|
||||
return False
|
||||
|
||||
def execute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Execute q chat with the given prompt."""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="q CLI is not available"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get verbose flag from kwargs
|
||||
verbose = kwargs.get('verbose', True)
|
||||
|
||||
# Get the prompt file path from kwargs if available
|
||||
prompt_file = kwargs.get('prompt_file', self.default_prompt_file)
|
||||
|
||||
logger.info(f"Executing Q chat - Prompt file: {prompt_file}, Verbose: {verbose}")
|
||||
|
||||
# Enhance prompt with orchestration instructions
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
|
||||
# Construct a more effective prompt for q chat
|
||||
# Tell it explicitly to edit the prompt file
|
||||
effective_prompt = (
|
||||
f"Please read and complete the task described in the file '{prompt_file}'. "
|
||||
f"The current content is:\n\n{enhanced_prompt}\n\n"
|
||||
f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
|
||||
)
|
||||
|
||||
# Build command - q chat works with files by adding them to context
|
||||
# We pass the prompt through stdin and tell it to trust file operations
|
||||
cmd = [self.command, "chat"]
|
||||
|
||||
if self.no_interactive:
|
||||
cmd.append("--no-interactive")
|
||||
|
||||
if self.trust_all_tools:
|
||||
cmd.append("--trust-all-tools")
|
||||
|
||||
cmd.append(effective_prompt)
|
||||
|
||||
logger.debug(f"Command constructed: {' '.join(cmd)}")
|
||||
|
||||
timeout = kwargs.get("timeout", self.default_timeout)
|
||||
|
||||
if verbose:
|
||||
logger.info("Starting q chat command...")
|
||||
logger.info(f"Command: {' '.join(cmd)}")
|
||||
logger.info(f"Working directory: {os.getcwd()}")
|
||||
logger.info(f"Timeout: {timeout} seconds")
|
||||
print("-" * 60, file=sys.stderr)
|
||||
|
||||
# Use Popen for real-time output streaming
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
cwd=os.getcwd(),
|
||||
bufsize=0, # Unbuffered to prevent deadlock
|
||||
universal_newlines=True
|
||||
)
|
||||
|
||||
# Set process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = process
|
||||
|
||||
# Make pipes non-blocking to prevent deadlock
|
||||
self._make_non_blocking(process.stdout)
|
||||
self._make_non_blocking(process.stderr)
|
||||
|
||||
# Collect output while streaming
|
||||
stdout_lines = []
|
||||
stderr_lines = []
|
||||
|
||||
start_time = time.time()
|
||||
last_output_time = start_time
|
||||
|
||||
while True:
|
||||
# Check for shutdown signal first with lock
|
||||
with self._lock:
|
||||
shutdown = self.shutdown_requested
|
||||
|
||||
if shutdown:
|
||||
if verbose:
|
||||
print("Shutdown requested, terminating q chat process...", file=sys.stderr)
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
process.wait(timeout=2)
|
||||
|
||||
# Clean up process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="".join(stdout_lines),
|
||||
error="Process terminated due to shutdown signal"
|
||||
)
|
||||
|
||||
# Check for timeout
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Log progress every 30 seconds
|
||||
if int(elapsed_time) % 30 == 0 and int(elapsed_time) > 0:
|
||||
logger.debug(f"Q chat still running... elapsed: {elapsed_time:.1f}s / {timeout}s")
|
||||
|
||||
# Check if the process seems stuck (no output for a while)
|
||||
time_since_output = time.time() - last_output_time
|
||||
if time_since_output > 60:
|
||||
logger.info(f"No output received for {time_since_output:.1f}s, Q might be stuck")
|
||||
|
||||
if verbose:
|
||||
print(f"Q chat still running... elapsed: {elapsed_time:.1f}s / {timeout}s", file=sys.stderr)
|
||||
|
||||
if elapsed_time > timeout:
|
||||
logger.warning(f"Command timed out after {elapsed_time:.2f} seconds")
|
||||
if verbose:
|
||||
print(f"Command timed out after {elapsed_time:.2f} seconds", file=sys.stderr)
|
||||
|
||||
# Try to terminate gracefully first
|
||||
process.terminate()
|
||||
try:
|
||||
# Wait a bit for graceful termination
|
||||
process.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Graceful termination failed, force killing process")
|
||||
if verbose:
|
||||
print("Graceful termination failed, force killing process", file=sys.stderr)
|
||||
process.kill()
|
||||
# Wait for force kill to complete
|
||||
try:
|
||||
process.wait(timeout=2)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Process may still be running after kill")
|
||||
if verbose:
|
||||
print("Warning: Process may still be running after kill", file=sys.stderr)
|
||||
|
||||
# Try to capture any remaining output after termination
|
||||
try:
|
||||
remaining_stdout = process.stdout.read()
|
||||
remaining_stderr = process.stderr.read()
|
||||
if remaining_stdout:
|
||||
stdout_lines.append(remaining_stdout)
|
||||
if remaining_stderr:
|
||||
stderr_lines.append(remaining_stderr)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read remaining output after timeout: {e}")
|
||||
if verbose:
|
||||
print(f"Warning: Could not read remaining output after timeout: {e}", file=sys.stderr)
|
||||
|
||||
# Clean up process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="".join(stdout_lines),
|
||||
error=f"q chat command timed out after {elapsed_time:.2f} seconds"
|
||||
)
|
||||
|
||||
# Check if process is still running
|
||||
if process.poll() is not None:
|
||||
# Process finished, read remaining output
|
||||
remaining_stdout = process.stdout.read()
|
||||
remaining_stderr = process.stderr.read()
|
||||
|
||||
if remaining_stdout:
|
||||
stdout_lines.append(remaining_stdout)
|
||||
if verbose:
|
||||
print(f"{remaining_stdout}", end='', file=sys.stderr)
|
||||
|
||||
if remaining_stderr:
|
||||
stderr_lines.append(remaining_stderr)
|
||||
if verbose:
|
||||
print(f"{remaining_stderr}", end='', file=sys.stderr)
|
||||
|
||||
break
|
||||
|
||||
# Read available data without blocking
|
||||
try:
|
||||
# Read stdout
|
||||
stdout_data = self._read_available(process.stdout)
|
||||
if stdout_data:
|
||||
stdout_lines.append(stdout_data)
|
||||
last_output_time = time.time()
|
||||
if verbose:
|
||||
print(stdout_data, end='', file=sys.stderr)
|
||||
|
||||
# Read stderr
|
||||
stderr_data = self._read_available(process.stderr)
|
||||
if stderr_data:
|
||||
stderr_lines.append(stderr_data)
|
||||
last_output_time = time.time()
|
||||
if verbose:
|
||||
print(stderr_data, end='', file=sys.stderr)
|
||||
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.01)
|
||||
|
||||
except BlockingIOError:
|
||||
# Non-blocking read returns BlockingIOError when no data available
|
||||
pass
|
||||
|
||||
# Get final return code
|
||||
returncode = process.poll()
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"Process completed - Return code: {returncode}, Execution time: {execution_time:.2f}s")
|
||||
|
||||
if verbose:
|
||||
print("-" * 60, file=sys.stderr)
|
||||
print(f"Process completed with return code: {returncode}", file=sys.stderr)
|
||||
print(f"Total execution time: {execution_time:.2f} seconds", file=sys.stderr)
|
||||
|
||||
# Clean up process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
# Combine output
|
||||
full_stdout = "".join(stdout_lines)
|
||||
full_stderr = "".join(stderr_lines)
|
||||
|
||||
if returncode == 0:
|
||||
logger.debug(f"Q chat succeeded - Output length: {len(full_stdout)} chars")
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=full_stdout,
|
||||
metadata={
|
||||
"tool": "q chat",
|
||||
"execution_time": execution_time,
|
||||
"verbose": verbose,
|
||||
"return_code": returncode
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Q chat failed - Return code: {returncode}, Error: {full_stderr[:200]}")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=full_stdout,
|
||||
error=full_stderr or f"q chat command failed with code {returncode}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception during Q chat execution: {str(e)}")
|
||||
if verbose:
|
||||
print(f"Exception occurred: {str(e)}", file=sys.stderr)
|
||||
# Clean up process reference on exception (matches async version's finally block)
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def _make_non_blocking(self, pipe):
|
||||
"""Make a pipe non-blocking to prevent deadlock."""
|
||||
if not pipe or fcntl is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fd = pipe.fileno()
|
||||
if isinstance(fd, int) and fd >= 0:
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
except (AttributeError, ValueError, OSError):
|
||||
# In tests or when pipe doesn't support fileno()
|
||||
pass
|
||||
|
||||
|
||||
def _read_available(self, pipe):
|
||||
"""Read available data from a non-blocking pipe."""
|
||||
if not pipe:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Try to read up to 4KB at a time
|
||||
data = pipe.read(4096)
|
||||
# Ensure we always return a string, not None
|
||||
if data is None:
|
||||
return ""
|
||||
return data if data else ""
|
||||
except (IOError, OSError):
|
||||
# Would block or no data available
|
||||
return ""
|
||||
|
||||
async def aexecute(self, prompt: str, **kwargs) -> ToolResponse:
|
||||
"""Native async execution using asyncio subprocess."""
|
||||
if not self.available:
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error="q CLI is not available"
|
||||
)
|
||||
|
||||
try:
|
||||
verbose = kwargs.get('verbose', True)
|
||||
prompt_file = kwargs.get('prompt_file', self.default_prompt_file)
|
||||
timeout = kwargs.get('timeout', self.default_timeout)
|
||||
|
||||
logger.info(f"Executing Q chat async - Prompt file: {prompt_file}, Timeout: {timeout}s")
|
||||
|
||||
# Enhance prompt with orchestration instructions
|
||||
enhanced_prompt = self._enhance_prompt_with_instructions(prompt)
|
||||
|
||||
# Construct effective prompt
|
||||
effective_prompt = (
|
||||
f"Please read and complete the task described in the file '{prompt_file}'. "
|
||||
f"The current content is:\n\n{enhanced_prompt}\n\n"
|
||||
f"Edit the file '{prompt_file}' directly to add your solution and progress updates."
|
||||
)
|
||||
|
||||
# Build command
|
||||
cmd = [
|
||||
self.command,
|
||||
"chat",
|
||||
"--no-interactive",
|
||||
"--trust-all-tools",
|
||||
effective_prompt
|
||||
]
|
||||
|
||||
logger.debug(f"Starting async Q chat command: {' '.join(cmd)}")
|
||||
if verbose:
|
||||
print("Starting q chat command (async)...", file=sys.stderr)
|
||||
print(f"Command: {' '.join(cmd)}", file=sys.stderr)
|
||||
print("-" * 60, file=sys.stderr)
|
||||
|
||||
# Create async subprocess
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=os.getcwd()
|
||||
)
|
||||
|
||||
# Set process reference with lock
|
||||
with self._lock:
|
||||
self.current_process = process
|
||||
|
||||
try:
|
||||
# Wait for completion with timeout
|
||||
stdout_data, stderr_data = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# Decode output
|
||||
stdout = stdout_data.decode('utf-8') if stdout_data else ""
|
||||
stderr = stderr_data.decode('utf-8') if stderr_data else ""
|
||||
|
||||
if verbose and stdout:
|
||||
print(stdout, file=sys.stderr)
|
||||
if verbose and stderr:
|
||||
print(stderr, file=sys.stderr)
|
||||
|
||||
# Check return code
|
||||
if process.returncode == 0:
|
||||
logger.debug(f"Async Q chat succeeded - Output length: {len(stdout)} chars")
|
||||
return ToolResponse(
|
||||
success=True,
|
||||
output=stdout,
|
||||
metadata={
|
||||
"tool": "q chat",
|
||||
"verbose": verbose,
|
||||
"async": True,
|
||||
"return_code": process.returncode
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Async Q chat failed - Return code: {process.returncode}")
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output=stdout,
|
||||
error=stderr or f"q chat failed with code {process.returncode}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout occurred
|
||||
logger.warning(f"Async q chat timed out after {timeout} seconds")
|
||||
if verbose:
|
||||
print(f"Async q chat timed out after {timeout} seconds", file=sys.stderr)
|
||||
|
||||
# Try to terminate process
|
||||
try:
|
||||
process.terminate()
|
||||
await asyncio.wait_for(process.wait(), timeout=3)
|
||||
except (asyncio.TimeoutError, ProcessLookupError):
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=f"q chat command timed out after {timeout} seconds"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up process reference
|
||||
with self._lock:
|
||||
self.current_process = None
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Async execution error: {str(e)}")
|
||||
if kwargs.get('verbose'):
|
||||
print(f"Async execution error: {str(e)}", file=sys.stderr)
|
||||
return ToolResponse(
|
||||
success=False,
|
||||
output="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def estimate_cost(self, prompt: str) -> float:
|
||||
"""Q chat cost estimation (if applicable)."""
|
||||
# Q chat might be free or have different pricing
|
||||
# Return 0 for now, can be updated based on actual pricing
|
||||
return 0.0
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup on deletion."""
|
||||
# Restore original signal handlers
|
||||
self._restore_signal_handlers()
|
||||
|
||||
# Ensure any running process is terminated
|
||||
if hasattr(self, '_lock'):
|
||||
with self._lock:
|
||||
process = self.current_process if hasattr(self, 'current_process') else None
|
||||
else:
|
||||
process = getattr(self, 'current_process', None)
|
||||
|
||||
if process:
|
||||
try:
|
||||
if hasattr(process, 'poll'):
|
||||
# Sync process
|
||||
if process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait(timeout=1)
|
||||
else:
|
||||
# Async process - can't do much in __del__
|
||||
pass
|
||||
except Exception as e:
|
||||
# Best-effort cleanup during interpreter shutdown
|
||||
# Log at debug level since __del__ is unreliable
|
||||
logger.debug(f"Cleanup warning in __del__: {type(e).__name__}: {e}")
|
||||
Reference in New Issue
Block a user