Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for an active project before fetching all sessions
- When an active project exists, use project.sessions instead of fetching from the API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Internal implementation details."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,124 @@
|
||||
"""Internal client implementation."""
|
||||
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from ..types import (
|
||||
ClaudeAgentOptions,
|
||||
HookEvent,
|
||||
HookMatcher,
|
||||
Message,
|
||||
)
|
||||
from .message_parser import parse_message
|
||||
from .query import Query
|
||||
from .transport import Transport
|
||||
from .transport.subprocess_cli import SubprocessCLITransport
|
||||
|
||||
|
||||
class InternalClient:
    """Internal client implementation.

    Bridges the public SDK surface to the transport/Query layers: it validates
    permission-related options, creates (or accepts) a transport, wires up the
    control protocol via Query, and streams parsed messages back to the caller.
    """

    def __init__(self) -> None:
        """Initialize the internal client."""

    def _convert_hooks_to_internal_format(
        self, hooks: dict[HookEvent, list[HookMatcher]]
    ) -> dict[str, list[dict[str, Any]]]:
        """Convert HookMatcher format to internal Query format.

        Args:
            hooks: Mapping of hook event to the HookMatcher objects registered
                for that event.

        Returns:
            Mapping of event to plain dicts with "matcher", "hooks" and
            (optionally) "timeout" keys, as consumed by Query.
        """
        internal_hooks: dict[str, list[dict[str, Any]]] = {}
        for event, matchers in hooks.items():
            internal_hooks[event] = []
            for matcher in matchers:
                # Convert HookMatcher to internal dict format. The hasattr()
                # guards keep this tolerant of duck-typed matcher objects that
                # omit optional attributes.
                internal_matcher: dict[str, Any] = {
                    "matcher": matcher.matcher if hasattr(matcher, "matcher") else None,
                    "hooks": matcher.hooks if hasattr(matcher, "hooks") else [],
                }
                if hasattr(matcher, "timeout") and matcher.timeout is not None:
                    internal_matcher["timeout"] = matcher.timeout
                internal_hooks[event].append(internal_matcher)
        return internal_hooks

    async def process_query(
        self,
        prompt: str | AsyncIterable[dict[str, Any]],
        options: ClaudeAgentOptions,
        transport: Transport | None = None,
    ) -> AsyncIterator[Message]:
        """Process a query through transport and Query.

        Args:
            prompt: Complete prompt string (one-shot mode) or an async iterable
                of message dicts (streaming mode).
            options: Query configuration.
            transport: Optional pre-built transport; when None a
                SubprocessCLITransport is created from the prompt and options.

        Yields:
            Parsed Message objects read from the CLI.

        Raises:
            ValueError: If can_use_tool is combined with a string prompt, or
                with permission_prompt_tool_name.
        """

        # Validate and configure permission settings (matching TypeScript SDK logic)
        configured_options = options
        if options.can_use_tool:
            # canUseTool callback requires streaming mode (AsyncIterable prompt)
            if isinstance(prompt, str):
                raise ValueError(
                    "can_use_tool callback requires streaming mode. "
                    "Please provide prompt as an AsyncIterable instead of a string."
                )

            # canUseTool and permission_prompt_tool_name are mutually exclusive
            if options.permission_prompt_tool_name:
                raise ValueError(
                    "can_use_tool callback cannot be used with permission_prompt_tool_name. "
                    "Please use one or the other."
                )

            # Automatically set permission_prompt_tool_name to "stdio" for control protocol
            configured_options = replace(options, permission_prompt_tool_name="stdio")

        # Use provided transport or create subprocess transport
        if transport is not None:
            chosen_transport = transport
        else:
            chosen_transport = SubprocessCLITransport(
                prompt=prompt,
                options=configured_options,
            )

        # Connect transport
        await chosen_transport.connect()

        # Extract SDK MCP servers from configured options
        sdk_mcp_servers = {}
        if configured_options.mcp_servers and isinstance(
            configured_options.mcp_servers, dict
        ):
            for name, config in configured_options.mcp_servers.items():
                if isinstance(config, dict) and config.get("type") == "sdk":
                    sdk_mcp_servers[name] = config["instance"]  # type: ignore[typeddict-item]

        # Create Query to handle control protocol
        is_streaming = not isinstance(prompt, str)
        query = Query(
            transport=chosen_transport,
            is_streaming_mode=is_streaming,
            can_use_tool=configured_options.can_use_tool,
            hooks=self._convert_hooks_to_internal_format(configured_options.hooks)
            if configured_options.hooks
            else None,
            sdk_mcp_servers=sdk_mcp_servers,
        )

        try:
            # Start reading messages
            await query.start()

            # Initialize if streaming
            if is_streaming:
                await query.initialize()

            # Stream input if it's an AsyncIterable
            if isinstance(prompt, AsyncIterable) and query._tg:
                # Start streaming in background
                # Create a task that will run in the background
                query._tg.start_soon(query.stream_input, prompt)
            # For string prompts, the prompt is already passed via CLI args

            # Yield parsed messages
            async for data in query.receive_messages():
                yield parse_message(data)

        finally:
            await query.close()
|
||||
@@ -0,0 +1,177 @@
|
||||
"""Message parser for Claude Code SDK responses."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .._errors import MessageParseError
|
||||
from ..types import (
|
||||
AssistantMessage,
|
||||
ContentBlock,
|
||||
Message,
|
||||
ResultMessage,
|
||||
StreamEvent,
|
||||
SystemMessage,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_message(data: dict[str, Any]) -> Message:
|
||||
"""
|
||||
Parse message from CLI output into typed Message objects.
|
||||
|
||||
Args:
|
||||
data: Raw message dictionary from CLI output
|
||||
|
||||
Returns:
|
||||
Parsed Message object
|
||||
|
||||
Raises:
|
||||
MessageParseError: If parsing fails or message type is unrecognized
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise MessageParseError(
|
||||
f"Invalid message data type (expected dict, got {type(data).__name__})",
|
||||
data,
|
||||
)
|
||||
|
||||
message_type = data.get("type")
|
||||
if not message_type:
|
||||
raise MessageParseError("Message missing 'type' field", data)
|
||||
|
||||
match message_type:
|
||||
case "user":
|
||||
try:
|
||||
parent_tool_use_id = data.get("parent_tool_use_id")
|
||||
uuid = data.get("uuid")
|
||||
if isinstance(data["message"]["content"], list):
|
||||
user_content_blocks: list[ContentBlock] = []
|
||||
for block in data["message"]["content"]:
|
||||
match block["type"]:
|
||||
case "text":
|
||||
user_content_blocks.append(
|
||||
TextBlock(text=block["text"])
|
||||
)
|
||||
case "tool_use":
|
||||
user_content_blocks.append(
|
||||
ToolUseBlock(
|
||||
id=block["id"],
|
||||
name=block["name"],
|
||||
input=block["input"],
|
||||
)
|
||||
)
|
||||
case "tool_result":
|
||||
user_content_blocks.append(
|
||||
ToolResultBlock(
|
||||
tool_use_id=block["tool_use_id"],
|
||||
content=block.get("content"),
|
||||
is_error=block.get("is_error"),
|
||||
)
|
||||
)
|
||||
return UserMessage(
|
||||
content=user_content_blocks,
|
||||
uuid=uuid,
|
||||
parent_tool_use_id=parent_tool_use_id,
|
||||
)
|
||||
return UserMessage(
|
||||
content=data["message"]["content"],
|
||||
uuid=uuid,
|
||||
parent_tool_use_id=parent_tool_use_id,
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in user message: {e}", data
|
||||
) from e
|
||||
|
||||
case "assistant":
|
||||
try:
|
||||
content_blocks: list[ContentBlock] = []
|
||||
for block in data["message"]["content"]:
|
||||
match block["type"]:
|
||||
case "text":
|
||||
content_blocks.append(TextBlock(text=block["text"]))
|
||||
case "thinking":
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
thinking=block["thinking"],
|
||||
signature=block["signature"],
|
||||
)
|
||||
)
|
||||
case "tool_use":
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
id=block["id"],
|
||||
name=block["name"],
|
||||
input=block["input"],
|
||||
)
|
||||
)
|
||||
case "tool_result":
|
||||
content_blocks.append(
|
||||
ToolResultBlock(
|
||||
tool_use_id=block["tool_use_id"],
|
||||
content=block.get("content"),
|
||||
is_error=block.get("is_error"),
|
||||
)
|
||||
)
|
||||
|
||||
return AssistantMessage(
|
||||
content=content_blocks,
|
||||
model=data["message"]["model"],
|
||||
parent_tool_use_id=data.get("parent_tool_use_id"),
|
||||
error=data["message"].get("error"),
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in assistant message: {e}", data
|
||||
) from e
|
||||
|
||||
case "system":
|
||||
try:
|
||||
return SystemMessage(
|
||||
subtype=data["subtype"],
|
||||
data=data,
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in system message: {e}", data
|
||||
) from e
|
||||
|
||||
case "result":
|
||||
try:
|
||||
return ResultMessage(
|
||||
subtype=data["subtype"],
|
||||
duration_ms=data["duration_ms"],
|
||||
duration_api_ms=data["duration_api_ms"],
|
||||
is_error=data["is_error"],
|
||||
num_turns=data["num_turns"],
|
||||
session_id=data["session_id"],
|
||||
total_cost_usd=data.get("total_cost_usd"),
|
||||
usage=data.get("usage"),
|
||||
result=data.get("result"),
|
||||
structured_output=data.get("structured_output"),
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in result message: {e}", data
|
||||
) from e
|
||||
|
||||
case "stream_event":
|
||||
try:
|
||||
return StreamEvent(
|
||||
uuid=data["uuid"],
|
||||
session_id=data["session_id"],
|
||||
event=data["event"],
|
||||
parent_tool_use_id=data.get("parent_tool_use_id"),
|
||||
)
|
||||
except KeyError as e:
|
||||
raise MessageParseError(
|
||||
f"Missing required field in stream_event message: {e}", data
|
||||
) from e
|
||||
|
||||
case _:
|
||||
raise MessageParseError(f"Unknown message type: {message_type}", data)
|
||||
@@ -0,0 +1,621 @@
|
||||
"""Query class for handling bidirectional control protocol."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import anyio
|
||||
from mcp.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
ListToolsRequest,
|
||||
)
|
||||
|
||||
from ..types import (
|
||||
PermissionResultAllow,
|
||||
PermissionResultDeny,
|
||||
SDKControlPermissionRequest,
|
||||
SDKControlRequest,
|
||||
SDKControlResponse,
|
||||
SDKHookCallbackRequest,
|
||||
ToolPermissionContext,
|
||||
)
|
||||
from .transport import Transport
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server import Server as McpServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_hook_output_for_cli(hook_output: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert Python-safe field names to CLI-expected field names.
|
||||
|
||||
The Python SDK uses `async_` and `continue_` to avoid keyword conflicts,
|
||||
but the CLI expects `async` and `continue`. This function performs the
|
||||
necessary conversion.
|
||||
"""
|
||||
converted = {}
|
||||
for key, value in hook_output.items():
|
||||
# Convert Python-safe names to JavaScript names
|
||||
if key == "async_":
|
||||
converted["async"] = value
|
||||
elif key == "continue_":
|
||||
converted["continue"] = value
|
||||
else:
|
||||
converted[key] = value
|
||||
return converted
|
||||
|
||||
|
||||
class Query:
    """Handles bidirectional control protocol on top of Transport.

    This class manages:
    - Control request/response routing
    - Hook callbacks
    - Tool permission callbacks
    - Message streaming
    - Initialization handshake
    """

    def __init__(
        self,
        transport: Transport,
        is_streaming_mode: bool,
        can_use_tool: Callable[
            [str, dict[str, Any], ToolPermissionContext],
            Awaitable[PermissionResultAllow | PermissionResultDeny],
        ]
        | None = None,
        hooks: dict[str, list[dict[str, Any]]] | None = None,
        sdk_mcp_servers: dict[str, "McpServer"] | None = None,
        initialize_timeout: float = 60.0,
    ):
        """Initialize Query with transport and callbacks.

        Args:
            transport: Low-level transport for I/O
            is_streaming_mode: Whether using streaming (bidirectional) mode
            can_use_tool: Optional callback for tool permission requests
            hooks: Optional hook configurations
            sdk_mcp_servers: Optional SDK MCP server instances
            initialize_timeout: Timeout in seconds for the initialize request
        """
        self._initialize_timeout = initialize_timeout
        self.transport = transport
        self.is_streaming_mode = is_streaming_mode
        self.can_use_tool = can_use_tool
        self.hooks = hooks or {}
        self.sdk_mcp_servers = sdk_mcp_servers or {}

        # Control protocol state: each outgoing request gets an Event that is
        # set when its response (or an error) arrives, plus a result slot.
        self.pending_control_responses: dict[str, anyio.Event] = {}
        self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
        self.hook_callbacks: dict[str, Callable[..., Any]] = {}
        self.next_callback_id = 0
        self._request_counter = 0

        # Message stream: reader task pushes SDK messages in, consumers pull
        # them out via receive_messages().
        self._message_send, self._message_receive = anyio.create_memory_object_stream[
            dict[str, Any]
        ](max_buffer_size=100)
        self._tg: anyio.abc.TaskGroup | None = None
        self._initialized = False
        self._closed = False
        self._initialization_result: dict[str, Any] | None = None

        # Track first result for proper stream closure with SDK MCP servers
        self._first_result_event = anyio.Event()
        self._stream_close_timeout = (
            float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
        )  # Convert ms to seconds

    async def initialize(self) -> dict[str, Any] | None:
        """Initialize control protocol if in streaming mode.

        Registers hook callbacks under generated IDs and advertises them to
        the CLI in the initialize request.

        Returns:
            Initialize response with supported commands, or None if not streaming
        """
        if not self.is_streaming_mode:
            return None

        # Build hooks configuration for initialization
        hooks_config: dict[str, Any] = {}
        if self.hooks:
            for event, matchers in self.hooks.items():
                if matchers:
                    hooks_config[event] = []
                    for matcher in matchers:
                        callback_ids = []
                        for callback in matcher.get("hooks", []):
                            # Each callable is registered under a unique ID so
                            # the CLI can invoke it via hook_callback requests.
                            callback_id = f"hook_{self.next_callback_id}"
                            self.next_callback_id += 1
                            self.hook_callbacks[callback_id] = callback
                            callback_ids.append(callback_id)
                        hook_matcher_config: dict[str, Any] = {
                            "matcher": matcher.get("matcher"),
                            "hookCallbackIds": callback_ids,
                        }
                        if matcher.get("timeout") is not None:
                            hook_matcher_config["timeout"] = matcher.get("timeout")
                        hooks_config[event].append(hook_matcher_config)

        # Send initialize request
        request = {
            "subtype": "initialize",
            "hooks": hooks_config if hooks_config else None,
        }

        # Use longer timeout for initialize since MCP servers may take time to start
        response = await self._send_control_request(
            request, timeout=self._initialize_timeout
        )
        self._initialized = True
        self._initialization_result = response  # Store for later access
        return response

    async def start(self) -> None:
        """Start reading messages from transport.

        Opens the task group (closed again in close()) and spawns the
        background reader task. Idempotent: a second call is a no-op.
        """
        if self._tg is None:
            self._tg = anyio.create_task_group()
            await self._tg.__aenter__()
            self._tg.start_soon(self._read_messages)

    async def _read_messages(self) -> None:
        """Read messages from transport and route them.

        Control responses resolve pending requests; control requests are
        dispatched to handler tasks; everything else is forwarded to the SDK
        message stream. Always sends a final {"type": "end"} sentinel.
        """
        try:
            async for message in self.transport.read_messages():
                if self._closed:
                    break

                msg_type = message.get("type")

                # Route control messages
                if msg_type == "control_response":
                    response = message.get("response", {})
                    request_id = response.get("request_id")
                    if request_id in self.pending_control_responses:
                        event = self.pending_control_responses[request_id]
                        if response.get("subtype") == "error":
                            self.pending_control_results[request_id] = Exception(
                                response.get("error", "Unknown error")
                            )
                        else:
                            self.pending_control_results[request_id] = response
                        event.set()
                    continue

                elif msg_type == "control_request":
                    # Handle incoming control requests from CLI in a separate
                    # task so the reader loop is never blocked by a callback.
                    # Cast message to SDKControlRequest for type safety
                    request: SDKControlRequest = message  # type: ignore[assignment]
                    if self._tg:
                        self._tg.start_soon(self._handle_control_request, request)
                    continue

                elif msg_type == "control_cancel_request":
                    # Handle cancel requests
                    # TODO: Implement cancellation support
                    continue

                # Track results for proper stream closure
                if msg_type == "result":
                    self._first_result_event.set()

                # Regular SDK messages go to the stream
                await self._message_send.send(message)

        except anyio.get_cancelled_exc_class():
            # Task was cancelled - this is expected behavior
            logger.debug("Read task cancelled")
            raise  # Re-raise to properly handle cancellation
        except Exception as e:
            logger.error(f"Fatal error in message reader: {e}")
            # Signal all pending control requests so they fail fast instead of timing out
            for request_id, event in list(self.pending_control_responses.items()):
                if request_id not in self.pending_control_results:
                    self.pending_control_results[request_id] = e
                    event.set()
            # Put error in stream so iterators can handle it
            await self._message_send.send({"type": "error", "error": str(e)})
        finally:
            # Always signal end of stream
            await self._message_send.send({"type": "end"})

    async def _handle_control_request(self, request: SDKControlRequest) -> None:
        """Handle incoming control request from CLI.

        Dispatches on subtype (can_use_tool / hook_callback / mcp_message) and
        always writes either a success or an error control_response back.
        """
        request_id = request["request_id"]
        request_data = request["request"]
        subtype = request_data["subtype"]

        try:
            response_data: dict[str, Any] = {}

            if subtype == "can_use_tool":
                permission_request: SDKControlPermissionRequest = request_data  # type: ignore[assignment]
                original_input = permission_request["input"]
                # Handle tool permission request
                if not self.can_use_tool:
                    raise Exception("canUseTool callback is not provided")

                context = ToolPermissionContext(
                    signal=None,  # TODO: Add abort signal support
                    suggestions=permission_request.get("permission_suggestions", [])
                    or [],
                )

                response = await self.can_use_tool(
                    permission_request["tool_name"],
                    permission_request["input"],
                    context,
                )

                # Convert PermissionResult to expected dict format
                if isinstance(response, PermissionResultAllow):
                    response_data = {
                        "behavior": "allow",
                        # Fall back to the original input when the callback did
                        # not rewrite it.
                        "updatedInput": (
                            response.updated_input
                            if response.updated_input is not None
                            else original_input
                        ),
                    }
                    if response.updated_permissions is not None:
                        response_data["updatedPermissions"] = [
                            permission.to_dict()
                            for permission in response.updated_permissions
                        ]
                elif isinstance(response, PermissionResultDeny):
                    response_data = {"behavior": "deny", "message": response.message}
                    if response.interrupt:
                        response_data["interrupt"] = response.interrupt
                else:
                    raise TypeError(
                        f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got {type(response)}"
                    )

            elif subtype == "hook_callback":
                hook_callback_request: SDKHookCallbackRequest = request_data  # type: ignore[assignment]
                # Handle hook callback
                callback_id = hook_callback_request["callback_id"]
                callback = self.hook_callbacks.get(callback_id)
                if not callback:
                    raise Exception(f"No hook callback found for ID: {callback_id}")

                hook_output = await callback(
                    request_data.get("input"),
                    request_data.get("tool_use_id"),
                    {"signal": None},  # TODO: Add abort signal support
                )
                # Convert Python-safe field names (async_, continue_) to CLI-expected names (async, continue)
                response_data = _convert_hook_output_for_cli(hook_output)

            elif subtype == "mcp_message":
                # Handle SDK MCP request
                server_name = request_data.get("server_name")
                mcp_message = request_data.get("message")

                if not server_name or not mcp_message:
                    raise Exception("Missing server_name or message for MCP request")

                # Type narrowing - we've verified these are not None above
                assert isinstance(server_name, str)
                assert isinstance(mcp_message, dict)
                mcp_response = await self._handle_sdk_mcp_request(
                    server_name, mcp_message
                )
                # Wrap the MCP response as expected by the control protocol
                response_data = {"mcp_response": mcp_response}

            else:
                raise Exception(f"Unsupported control request subtype: {subtype}")

            # Send success response
            success_response: SDKControlResponse = {
                "type": "control_response",
                "response": {
                    "subtype": "success",
                    "request_id": request_id,
                    "response": response_data,
                },
            }
            await self.transport.write(json.dumps(success_response) + "\n")

        except Exception as e:
            # Send error response
            error_response: SDKControlResponse = {
                "type": "control_response",
                "response": {
                    "subtype": "error",
                    "request_id": request_id,
                    "error": str(e),
                },
            }
            await self.transport.write(json.dumps(error_response) + "\n")

    async def _send_control_request(
        self, request: dict[str, Any], timeout: float = 60.0
    ) -> dict[str, Any]:
        """Send control request to CLI and wait for response.

        Args:
            request: The control request to send
            timeout: Timeout in seconds to wait for response (default 60s)

        Returns:
            The "response" payload of the matching control_response.

        Raises:
            Exception: On an error response, a reader failure, or timeout.
        """
        if not self.is_streaming_mode:
            raise Exception("Control requests require streaming mode")

        # Generate unique request ID (counter + random suffix avoids collisions
        # across Query instances sharing one CLI process)
        self._request_counter += 1
        request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"

        # Create event for response; _read_messages sets it when the matching
        # control_response arrives.
        event = anyio.Event()
        self.pending_control_responses[request_id] = event

        # Build and send request
        control_request = {
            "type": "control_request",
            "request_id": request_id,
            "request": request,
        }

        await self.transport.write(json.dumps(control_request) + "\n")

        # Wait for response
        try:
            with anyio.fail_after(timeout):
                await event.wait()

            result = self.pending_control_results.pop(request_id)
            self.pending_control_responses.pop(request_id, None)

            if isinstance(result, Exception):
                raise result

            response_data = result.get("response", {})
            return response_data if isinstance(response_data, dict) else {}
        except TimeoutError as e:
            # Clean up the pending slots so the dicts don't grow unboundedly.
            self.pending_control_responses.pop(request_id, None)
            self.pending_control_results.pop(request_id, None)
            raise Exception(f"Control request timeout: {request.get('subtype')}") from e

    async def _handle_sdk_mcp_request(
        self, server_name: str, message: dict[str, Any]
    ) -> dict[str, Any]:
        """Handle an MCP request for an SDK server.

        This acts as a bridge between JSONRPC messages from the CLI
        and the in-process MCP server. Ideally the MCP SDK would provide
        a method to handle raw JSONRPC, but for now we route manually.

        Args:
            server_name: Name of the SDK MCP server
            message: The JSONRPC message

        Returns:
            The response message
        """
        if server_name not in self.sdk_mcp_servers:
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {
                    "code": -32601,
                    "message": f"Server '{server_name}' not found",
                },
            }

        server = self.sdk_mcp_servers[server_name]
        method = message.get("method")
        params = message.get("params", {})

        try:
            # TODO: Python MCP SDK lacks the Transport abstraction that TypeScript has.
            # TypeScript: server.connect(transport) allows custom transports
            # Python: server.run(read_stream, write_stream) requires actual streams
            #
            # This forces us to manually route methods. When Python MCP adds Transport
            # support, we can refactor to match the TypeScript approach.
            if method == "initialize":
                # Handle MCP initialization - hardcoded for tools only, no listChanged
                return {
                    "jsonrpc": "2.0",
                    "id": message.get("id"),
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {
                            "tools": {}  # Tools capability without listChanged
                        },
                        "serverInfo": {
                            "name": server.name,
                            "version": server.version or "1.0.0",
                        },
                    },
                }

            elif method == "tools/list":
                request = ListToolsRequest(method=method)
                handler = server.request_handlers.get(ListToolsRequest)
                if handler:
                    result = await handler(request)
                    # Convert MCP result to JSONRPC response
                    tools_data = [
                        {
                            "name": tool.name,
                            "description": tool.description,
                            "inputSchema": (
                                tool.inputSchema.model_dump()
                                if hasattr(tool.inputSchema, "model_dump")
                                else tool.inputSchema
                            )
                            if tool.inputSchema
                            else {},
                        }
                        for tool in result.root.tools  # type: ignore[union-attr]
                    ]
                    return {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": {"tools": tools_data},
                    }

            elif method == "tools/call":
                call_request = CallToolRequest(
                    method=method,
                    params=CallToolRequestParams(
                        name=params.get("name"), arguments=params.get("arguments", {})
                    ),
                )
                handler = server.request_handlers.get(CallToolRequest)
                if handler:
                    result = await handler(call_request)
                    # Convert MCP result to JSONRPC response. Only text and
                    # image content items are forwarded; other item kinds are
                    # dropped here.
                    content = []
                    for item in result.root.content:  # type: ignore[union-attr]
                        if hasattr(item, "text"):
                            content.append({"type": "text", "text": item.text})
                        elif hasattr(item, "data") and hasattr(item, "mimeType"):
                            content.append(
                                {
                                    "type": "image",
                                    "data": item.data,
                                    "mimeType": item.mimeType,
                                }
                            )

                    response_data = {"content": content}
                    if hasattr(result.root, "is_error") and result.root.is_error:
                        response_data["is_error"] = True  # type: ignore[assignment]

                    return {
                        "jsonrpc": "2.0",
                        "id": message.get("id"),
                        "result": response_data,
                    }

            elif method == "notifications/initialized":
                # Handle initialized notification - just acknowledge it
                return {"jsonrpc": "2.0", "result": {}}

            # Add more methods here as MCP SDK adds them (resources, prompts, etc.)
            # NOTE: a known limitation of the manual routing above - each new
            # MCP method must be added to this dispatch by hand.

            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {"code": -32601, "message": f"Method '{method}' not found"},
            }

        except Exception as e:
            return {
                "jsonrpc": "2.0",
                "id": message.get("id"),
                "error": {"code": -32603, "message": str(e)},
            }

    async def interrupt(self) -> None:
        """Send interrupt control request."""
        await self._send_control_request({"subtype": "interrupt"})

    async def set_permission_mode(self, mode: str) -> None:
        """Change permission mode."""
        await self._send_control_request(
            {
                "subtype": "set_permission_mode",
                "mode": mode,
            }
        )

    async def set_model(self, model: str | None) -> None:
        """Change the AI model."""
        await self._send_control_request(
            {
                "subtype": "set_model",
                "model": model,
            }
        )

    async def rewind_files(self, user_message_id: str) -> None:
        """Rewind tracked files to their state at a specific user message.

        Requires file checkpointing to be enabled via the `enable_file_checkpointing` option.

        Args:
            user_message_id: UUID of the user message to rewind to
        """
        await self._send_control_request(
            {
                "subtype": "rewind_files",
                "user_message_id": user_message_id,
            }
        )

    async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
        """Stream input messages to transport.

        If SDK MCP servers or hooks are present, waits for the first result
        before closing stdin to allow bidirectional control protocol communication.

        Args:
            stream: Async iterable of message dicts written to the CLI as
                newline-delimited JSON.
        """
        try:
            async for message in stream:
                if self._closed:
                    break
                await self.transport.write(json.dumps(message) + "\n")

            # If we have SDK MCP servers or hooks that need bidirectional communication,
            # wait for first result before closing the channel
            has_hooks = bool(self.hooks)
            if self.sdk_mcp_servers or has_hooks:
                logger.debug(
                    f"Waiting for first result before closing stdin "
                    f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, has_hooks={has_hooks})"
                )
                try:
                    # NOTE(review): anyio.move_on_after exits silently on
                    # timeout rather than raising, so the except branch below
                    # is not reached by a timeout - it only catches other
                    # errors from wait(). Confirm whether the timeout log
                    # message is expected to fire.
                    with anyio.move_on_after(self._stream_close_timeout):
                        await self._first_result_event.wait()
                        logger.debug("Received first result, closing input stream")
                except Exception:
                    logger.debug(
                        "Timed out waiting for first result, closing input stream"
                    )

            # After all messages sent (and result received if needed), end input
            await self.transport.end_input()
        except Exception as e:
            # Best-effort: streaming failures are logged, not propagated.
            logger.debug(f"Error streaming input: {e}")

    async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
        """Receive SDK messages (not control messages).

        Yields:
            Raw SDK message dicts until the reader signals end-of-stream.

        Raises:
            Exception: If the reader task reported a fatal error.
        """
        async for message in self._message_receive:
            # Check for special messages injected by _read_messages
            if message.get("type") == "end":
                break
            elif message.get("type") == "error":
                raise Exception(message.get("error", "Unknown error"))

            yield message

    async def close(self) -> None:
        """Close the query and transport.

        Cancels the background task group, waits for it to unwind, then
        closes the underlying transport.
        """
        self._closed = True
        if self._tg:
            self._tg.cancel_scope.cancel()
            # Wait for task group to complete cancellation
            with suppress(anyio.get_cancelled_exc_class()):
                await self._tg.__aexit__(None, None, None)
        await self.transport.close()

    # Make Query an async iterator
    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        """Return async iterator for messages."""
        return self.receive_messages()

    async def __anext__(self) -> dict[str, Any]:
        """Get next message.

        NOTE(review): each call constructs a fresh receive_messages()
        generator and takes its first item; this works because all iterators
        share the one memory stream, but verify the intent vs. reusing a
        single iterator.
        """
        async for message in self.receive_messages():
            return message
        raise StopAsyncIteration
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Transport implementations for Claude SDK."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Transport(ABC):
    """Abstract transport for Claude communication.

    WARNING: This internal API is exposed for custom transport implementations
    (e.g., remote Claude Code connections). The Claude Code team may change or
    remove this abstract class in any future release. Custom implementations
    must be updated to match interface changes.

    This is a low-level transport interface that handles raw I/O with the Claude
    process or service. The Query class builds on top of this to implement the
    control protocol and message routing.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Connect the transport and prepare for communication.

        For subprocess transports, this starts the process.
        For network transports, this establishes the connection.
        """
        pass

    @abstractmethod
    async def write(self, data: str) -> None:
        """Write raw data to the transport.

        Args:
            data: Raw string data to write (typically JSON + newline)
        """
        pass

    @abstractmethod
    def read_messages(self) -> AsyncIterator[dict[str, Any]]:
        """Read and parse messages from the transport.

        Yields:
            Parsed JSON messages from the transport
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the transport connection and clean up resources."""
        pass

    @abstractmethod
    def is_ready(self) -> bool:
        """Check if transport is ready for communication.

        Returns:
            True if transport is ready to send/receive messages
        """
        pass

    @abstractmethod
    async def end_input(self) -> None:
        """End the input stream (close stdin for process transports)."""
        pass
|
||||
|
||||
|
||||
__all__ = ["Transport"]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,672 @@
|
||||
"""Subprocess transport implementation using Claude Code CLI."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
from contextlib import suppress
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from subprocess import PIPE
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import anyio.abc
|
||||
from anyio.abc import Process
|
||||
from anyio.streams.text import TextReceiveStream, TextSendStream
|
||||
|
||||
from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError
|
||||
from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
|
||||
from ..._version import __version__
|
||||
from ...types import ClaudeAgentOptions
|
||||
from . import Transport
|
||||
|
||||
logger = logging.getLogger(__name__)

# Cap on accumulated partial-JSON bytes before read_messages raises.
_DEFAULT_MAX_BUFFER_SIZE = 1024 * 1024  # 1MB buffer limit
# Oldest CLI version the SDK supports; older versions trigger a warning.
MINIMUM_CLAUDE_CODE_VERSION = "2.0.0"

# Platform-specific command line length limits
# Windows cmd.exe has a limit of 8191 characters, use 8000 for safety
# Other platforms have much higher limits
_CMD_LENGTH_LIMIT = 8000 if platform.system() == "Windows" else 100000
|
||||
|
||||
|
||||
class SubprocessCLITransport(Transport):
|
||||
"""Subprocess transport using Claude Code CLI."""
|
||||
|
||||
def __init__(
    self,
    prompt: str | AsyncIterable[dict[str, Any]],
    options: ClaudeAgentOptions,
):
    """Initialize the subprocess transport.

    Args:
        prompt: Either a single prompt string (one-shot mode) or an
            async iterable of message dicts (streaming mode).
        options: Configuration controlling CLI discovery, environment,
            buffering and process behavior.
    """
    self._prompt = prompt
    # Streaming mode is inferred from the prompt type: any non-string
    # prompt is treated as an async iterable of messages.
    self._is_streaming = not isinstance(prompt, str)
    self._options = options
    # An explicit cli_path wins; otherwise search bundled/system locations.
    self._cli_path = (
        str(options.cli_path) if options.cli_path is not None else self._find_cli()
    )
    self._cwd = str(options.cwd) if options.cwd else None
    self._process: Process | None = None
    self._stdout_stream: TextReceiveStream | None = None
    self._stdin_stream: TextSendStream | None = None
    self._stderr_stream: TextReceiveStream | None = None
    self._stderr_task_group: anyio.abc.TaskGroup | None = None
    self._ready = False
    self._exit_error: Exception | None = None  # Track process exit errors
    self._max_buffer_size = (
        options.max_buffer_size
        if options.max_buffer_size is not None
        else _DEFAULT_MAX_BUFFER_SIZE
    )
    self._temp_files: list[str] = []  # Track temporary files for cleanup
    # Serializes stdin writes and guards ready/close state transitions.
    self._write_lock: anyio.Lock = anyio.Lock()
|
||||
|
||||
def _find_cli(self) -> str:
    """Locate the Claude Code CLI binary.

    Search order: bundled CLI shipped with the package, then the PATH,
    then a fixed list of well-known install locations.

    Raises:
        CLINotFoundError: If no CLI binary could be found anywhere.
    """
    # Bundled CLI takes priority over anything installed system-wide.
    bundled = self._find_bundled_cli()
    if bundled:
        return bundled

    path_hit = shutil.which("claude")
    if path_hit:
        return path_hit

    home = Path.home()
    candidates = (
        home / ".npm-global/bin/claude",
        Path("/usr/local/bin/claude"),
        home / ".local/bin/claude",
        home / "node_modules/.bin/claude",
        home / ".yarn/bin/claude",
        home / ".claude/local/claude",
    )
    for candidate in candidates:
        if candidate.exists() and candidate.is_file():
            return str(candidate)

    raise CLINotFoundError(
        "Claude Code not found. Install with:\n"
        "  npm install -g @anthropic-ai/claude-code\n"
        "\nIf already installed locally, try:\n"
        '  export PATH="$HOME/node_modules/.bin:$PATH"\n'
        "\nOr provide the path via ClaudeAgentOptions:\n"
        "  ClaudeAgentOptions(cli_path='/path/to/claude')"
    )
|
||||
|
||||
def _find_bundled_cli(self) -> str | None:
|
||||
"""Find bundled CLI binary if it exists."""
|
||||
# Determine the CLI binary name based on platform
|
||||
cli_name = "claude.exe" if platform.system() == "Windows" else "claude"
|
||||
|
||||
# Get the path to the bundled CLI
|
||||
# The _bundled directory is in the same package as this module
|
||||
bundled_path = Path(__file__).parent.parent.parent / "_bundled" / cli_name
|
||||
|
||||
if bundled_path.exists() and bundled_path.is_file():
|
||||
logger.info(f"Using bundled Claude Code CLI: {bundled_path}")
|
||||
return str(bundled_path)
|
||||
|
||||
return None
|
||||
|
||||
def _build_settings_value(self) -> str | None:
    """Build settings value, merging sandbox settings if provided.

    Returns the settings value as either:
    - A JSON string (if sandbox is provided or settings is JSON)
    - A file path (if only settings path is provided without sandbox)
    - None if neither settings nor sandbox is provided
    """
    has_settings = self._options.settings is not None
    has_sandbox = self._options.sandbox is not None

    if not has_settings and not has_sandbox:
        return None

    # If only settings path and no sandbox, pass through as-is
    if has_settings and not has_sandbox:
        return self._options.settings

    # If we have sandbox settings, we need to merge into a JSON object
    settings_obj: dict[str, Any] = {}

    if has_settings:
        assert self._options.settings is not None
        settings_str = self._options.settings.strip()
        # Check if settings is a JSON string or a file path
        # (heuristic: a braced string is treated as inline JSON first)
        if settings_str.startswith("{") and settings_str.endswith("}"):
            # Parse JSON string
            try:
                settings_obj = json.loads(settings_str)
            except json.JSONDecodeError:
                # If parsing fails, treat as file path
                logger.warning(
                    f"Failed to parse settings as JSON, treating as file path: {settings_str}"
                )
                # Read the file
                # NOTE(review): on this fallback path a missing file is
                # silently ignored (unlike the branch below, which warns)
                # — TODO confirm this asymmetry is intended.
                settings_path = Path(settings_str)
                if settings_path.exists():
                    with settings_path.open(encoding="utf-8") as f:
                        settings_obj = json.load(f)
        else:
            # It's a file path - read and parse
            settings_path = Path(settings_str)
            if settings_path.exists():
                with settings_path.open(encoding="utf-8") as f:
                    settings_obj = json.load(f)
            else:
                logger.warning(f"Settings file not found: {settings_path}")

    # Merge sandbox settings
    if has_sandbox:
        settings_obj["sandbox"] = self._options.sandbox

    return json.dumps(settings_obj)
|
||||
|
||||
def _build_command(self) -> list[str]:
    """Build CLI command with arguments.

    Translates ClaudeAgentOptions into CLI flags. Prompt handling is
    appended last because everything after "--" is treated as
    arguments. If the resulting command line exceeds the platform
    limit and --agents is present, the agents JSON is spilled to a
    temp file referenced as @path.
    """
    cmd = [self._cli_path, "--output-format", "stream-json", "--verbose"]

    # system_prompt: None -> explicitly empty; str -> passed verbatim;
    # preset dict -> only the "append" text is forwarded.
    if self._options.system_prompt is None:
        cmd.extend(["--system-prompt", ""])
    elif isinstance(self._options.system_prompt, str):
        cmd.extend(["--system-prompt", self._options.system_prompt])
    else:
        if (
            self._options.system_prompt.get("type") == "preset"
            and "append" in self._options.system_prompt
        ):
            cmd.extend(
                ["--append-system-prompt", self._options.system_prompt["append"]]
            )

    # Handle tools option (base set of tools)
    if self._options.tools is not None:
        tools = self._options.tools
        if isinstance(tools, list):
            if len(tools) == 0:
                cmd.extend(["--tools", ""])
            else:
                cmd.extend(["--tools", ",".join(tools)])
        else:
            # Preset object - 'claude_code' preset maps to 'default'
            cmd.extend(["--tools", "default"])

    if self._options.allowed_tools:
        cmd.extend(["--allowedTools", ",".join(self._options.allowed_tools)])

    if self._options.max_turns:
        cmd.extend(["--max-turns", str(self._options.max_turns)])

    if self._options.max_budget_usd is not None:
        cmd.extend(["--max-budget-usd", str(self._options.max_budget_usd)])

    if self._options.disallowed_tools:
        cmd.extend(["--disallowedTools", ",".join(self._options.disallowed_tools)])

    if self._options.model:
        cmd.extend(["--model", self._options.model])

    if self._options.fallback_model:
        cmd.extend(["--fallback-model", self._options.fallback_model])

    if self._options.betas:
        cmd.extend(["--betas", ",".join(self._options.betas)])

    if self._options.permission_prompt_tool_name:
        cmd.extend(
            ["--permission-prompt-tool", self._options.permission_prompt_tool_name]
        )

    if self._options.permission_mode:
        cmd.extend(["--permission-mode", self._options.permission_mode])

    if self._options.continue_conversation:
        cmd.append("--continue")

    if self._options.resume:
        cmd.extend(["--resume", self._options.resume])

    # Handle settings and sandbox: merge sandbox into settings if both are provided
    settings_value = self._build_settings_value()
    if settings_value:
        cmd.extend(["--settings", settings_value])

    if self._options.add_dirs:
        # Convert all paths to strings and add each directory
        for directory in self._options.add_dirs:
            cmd.extend(["--add-dir", str(directory)])

    if self._options.mcp_servers:
        if isinstance(self._options.mcp_servers, dict):
            # Process all servers, stripping instance field from SDK servers
            servers_for_cli: dict[str, Any] = {}
            for name, config in self._options.mcp_servers.items():
                if isinstance(config, dict) and config.get("type") == "sdk":
                    # For SDK servers, pass everything except the instance field
                    # (the live Python object cannot be serialized to JSON)
                    sdk_config: dict[str, object] = {
                        k: v for k, v in config.items() if k != "instance"
                    }
                    servers_for_cli[name] = sdk_config
                else:
                    # For external servers, pass as-is
                    servers_for_cli[name] = config

            # Pass all servers to CLI
            if servers_for_cli:
                cmd.extend(
                    [
                        "--mcp-config",
                        json.dumps({"mcpServers": servers_for_cli}),
                    ]
                )
        else:
            # String or Path format: pass directly as file path or JSON string
            cmd.extend(["--mcp-config", str(self._options.mcp_servers)])

    if self._options.include_partial_messages:
        cmd.append("--include-partial-messages")

    if self._options.fork_session:
        cmd.append("--fork-session")

    if self._options.agents:
        # Drop None-valued fields so the CLI only sees explicit settings.
        agents_dict = {
            name: {k: v for k, v in asdict(agent_def).items() if v is not None}
            for name, agent_def in self._options.agents.items()
        }
        agents_json = json.dumps(agents_dict)
        cmd.extend(["--agents", agents_json])

    # --setting-sources is always passed; empty string means "no sources".
    sources_value = (
        ",".join(self._options.setting_sources)
        if self._options.setting_sources is not None
        else ""
    )
    cmd.extend(["--setting-sources", sources_value])

    # Add plugin directories
    if self._options.plugins:
        for plugin in self._options.plugins:
            if plugin["type"] == "local":
                cmd.extend(["--plugin-dir", plugin["path"]])
            else:
                raise ValueError(f"Unsupported plugin type: {plugin['type']}")

    # Add extra args for future CLI flags
    for flag, value in self._options.extra_args.items():
        if value is None:
            # Boolean flag without value
            cmd.append(f"--{flag}")
        else:
            # Flag with value
            cmd.extend([f"--{flag}", str(value)])

    if self._options.max_thinking_tokens is not None:
        cmd.extend(
            ["--max-thinking-tokens", str(self._options.max_thinking_tokens)]
        )

    # Extract schema from output_format structure if provided
    # Expected: {"type": "json_schema", "schema": {...}}
    if (
        self._options.output_format is not None
        and isinstance(self._options.output_format, dict)
        and self._options.output_format.get("type") == "json_schema"
    ):
        schema = self._options.output_format.get("schema")
        if schema is not None:
            cmd.extend(["--json-schema", json.dumps(schema)])

    # Add prompt handling based on mode
    # IMPORTANT: This must come AFTER all flags because everything after "--" is treated as arguments
    if self._is_streaming:
        # Streaming mode: use --input-format stream-json
        cmd.extend(["--input-format", "stream-json"])
    else:
        # String mode: use --print with the prompt
        cmd.extend(["--print", "--", str(self._prompt)])

    # Check if command line is too long (Windows limitation)
    cmd_str = " ".join(cmd)
    if len(cmd_str) > _CMD_LENGTH_LIMIT and self._options.agents:
        # Command is too long - use temp file for agents
        # Find the --agents argument and replace its value with @filepath
        try:
            agents_idx = cmd.index("--agents")
            agents_json_value = cmd[agents_idx + 1]

            # Create a temporary file
            # ruff: noqa: SIM115
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", delete=False, encoding="utf-8"
            )
            temp_file.write(agents_json_value)
            temp_file.close()

            # Track for cleanup
            self._temp_files.append(temp_file.name)

            # Replace agents JSON with @filepath reference
            cmd[agents_idx + 1] = f"@{temp_file.name}"

            logger.info(
                f"Command line length ({len(cmd_str)}) exceeds limit ({_CMD_LENGTH_LIMIT}). "
                f"Using temp file for --agents: {temp_file.name}"
            )
        except (ValueError, IndexError) as e:
            # This shouldn't happen, but log it just in case
            logger.warning(f"Failed to optimize command line length: {e}")

    return cmd
|
||||
|
||||
async def connect(self) -> None:
    """Start subprocess.

    Idempotent: returns immediately if a process already exists.
    Launches the CLI with a merged environment, wires up stdout /
    stderr / stdin streams, and marks the transport ready. Startup
    failures are recorded in _exit_error and re-raised as
    CLIConnectionError / CLINotFoundError.
    """
    if self._process:
        return

    if not os.environ.get("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK"):
        await self._check_claude_version()

    cmd = self._build_command()
    try:
        # Merge environment variables: system -> user -> SDK required
        process_env = {
            **os.environ,
            **self._options.env,  # User-provided env vars
            "CLAUDE_CODE_ENTRYPOINT": "sdk-py",
            "CLAUDE_AGENT_SDK_VERSION": __version__,
        }

        # Enable file checkpointing if requested
        if self._options.enable_file_checkpointing:
            process_env["CLAUDE_CODE_ENABLE_SDK_FILE_CHECKPOINTING"] = "true"

        if self._cwd:
            process_env["PWD"] = self._cwd

        # Pipe stderr if we have a callback OR debug mode is enabled
        should_pipe_stderr = (
            self._options.stderr is not None
            or "debug-to-stderr" in self._options.extra_args
        )

        # For backward compat: use debug_stderr file object if no callback and debug is on
        stderr_dest = PIPE if should_pipe_stderr else None

        self._process = await anyio.open_process(
            cmd,
            stdin=PIPE,
            stdout=PIPE,
            stderr=stderr_dest,
            cwd=self._cwd,
            env=process_env,
            user=self._options.user,
        )

        if self._process.stdout:
            self._stdout_stream = TextReceiveStream(self._process.stdout)

        # Setup stderr stream if piped
        if should_pipe_stderr and self._process.stderr:
            self._stderr_stream = TextReceiveStream(self._process.stderr)
            # Start async task to read stderr
            self._stderr_task_group = anyio.create_task_group()
            await self._stderr_task_group.__aenter__()
            self._stderr_task_group.start_soon(self._handle_stderr)

        # Setup stdin for streaming mode
        if self._is_streaming and self._process.stdin:
            self._stdin_stream = TextSendStream(self._process.stdin)
        elif not self._is_streaming and self._process.stdin:
            # String mode: close stdin immediately
            await self._process.stdin.aclose()

        self._ready = True

    except FileNotFoundError as e:
        # Check if the error comes from the working directory or the CLI
        if self._cwd and not Path(self._cwd).exists():
            error = CLIConnectionError(
                f"Working directory does not exist: {self._cwd}"
            )
            self._exit_error = error
            raise error from e
        error = CLINotFoundError(f"Claude Code not found at: {self._cli_path}")
        self._exit_error = error
        raise error from e
    except Exception as e:
        error = CLIConnectionError(f"Failed to start Claude Code: {e}")
        self._exit_error = error
        raise error from e
|
||||
|
||||
async def _handle_stderr(self) -> None:
|
||||
"""Handle stderr stream - read and invoke callbacks."""
|
||||
if not self._stderr_stream:
|
||||
return
|
||||
|
||||
try:
|
||||
async for line in self._stderr_stream:
|
||||
line_str = line.rstrip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
# Call the stderr callback if provided
|
||||
if self._options.stderr:
|
||||
self._options.stderr(line_str)
|
||||
|
||||
# For backward compatibility: write to debug_stderr if in debug mode
|
||||
elif (
|
||||
"debug-to-stderr" in self._options.extra_args
|
||||
and self._options.debug_stderr
|
||||
):
|
||||
self._options.debug_stderr.write(line_str + "\n")
|
||||
if hasattr(self._options.debug_stderr, "flush"):
|
||||
self._options.debug_stderr.flush()
|
||||
except anyio.ClosedResourceError:
|
||||
pass # Stream closed, exit normally
|
||||
except Exception:
|
||||
pass # Ignore other errors during stderr reading
|
||||
|
||||
async def close(self) -> None:
    """Close the transport and clean up resources.

    Safe to call when not connected. Teardown order matters: temp
    files first (before the early return), then the stderr reader
    task, then stdin under the write lock (so no write() can race the
    close), then stderr, and finally process termination.
    """
    # Clean up temporary files first (before early return)
    for temp_file in self._temp_files:
        with suppress(Exception):
            Path(temp_file).unlink(missing_ok=True)
    self._temp_files.clear()

    if not self._process:
        self._ready = False
        return

    # Close stderr task group if active
    if self._stderr_task_group:
        with suppress(Exception):
            self._stderr_task_group.cancel_scope.cancel()
            await self._stderr_task_group.__aexit__(None, None, None)
        self._stderr_task_group = None

    # Close stdin stream (acquire lock to prevent race with concurrent writes)
    async with self._write_lock:
        self._ready = False  # Set inside lock to prevent TOCTOU with write()
        if self._stdin_stream:
            with suppress(Exception):
                await self._stdin_stream.aclose()
            self._stdin_stream = None

    if self._stderr_stream:
        with suppress(Exception):
            await self._stderr_stream.aclose()
        self._stderr_stream = None

    # Terminate and wait for process
    if self._process.returncode is None:
        with suppress(ProcessLookupError):
            self._process.terminate()
            # Wait for process to finish with timeout
            with suppress(Exception):
                # Just try to wait, but don't block if it fails
                await self._process.wait()

    self._process = None
    self._stdout_stream = None
    self._stdin_stream = None
    self._stderr_stream = None
    self._exit_error = None
|
||||
|
||||
async def write(self, data: str) -> None:
    """Write raw data to the transport.

    Serialized by _write_lock so a write cannot race close() or
    end_input(); a failed send marks the transport unusable.

    Args:
        data: Raw string data to write (typically JSON + newline).

    Raises:
        CLIConnectionError: If the transport is not ready, the process
            has terminated, a prior error was recorded, or the
            underlying stdin write fails.
    """
    async with self._write_lock:
        # All checks inside lock to prevent TOCTOU races with close()/end_input()
        if not self._ready or not self._stdin_stream:
            raise CLIConnectionError("ProcessTransport is not ready for writing")

        if self._process and self._process.returncode is not None:
            raise CLIConnectionError(
                f"Cannot write to terminated process (exit code: {self._process.returncode})"
            )

        if self._exit_error:
            raise CLIConnectionError(
                f"Cannot write to process that exited with error: {self._exit_error}"
            ) from self._exit_error

        try:
            await self._stdin_stream.send(data)
        except Exception as e:
            # Poison the transport so later writes fail fast.
            self._ready = False
            self._exit_error = CLIConnectionError(
                f"Failed to write to process stdin: {e}"
            )
            raise self._exit_error from e
|
||||
|
||||
async def end_input(self) -> None:
    """End the input stream (close stdin).

    Runs under the write lock so it cannot race an in-flight write().
    Closing errors are deliberately ignored.
    """
    async with self._write_lock:
        stdin = self._stdin_stream
        if not stdin:
            return
        with suppress(Exception):
            await stdin.aclose()
        self._stdin_stream = None
|
||||
|
||||
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
    """Read and parse messages from the transport.

    Returns the async generator produced by _read_messages_impl; kept
    as a plain (non-async) method so callers get the iterator without
    awaiting.
    """
    return self._read_messages_impl()
|
||||
|
||||
async def _read_messages_impl(self) -> AsyncIterator[dict[str, Any]]:
    """Internal implementation of read_messages.

    Accumulates stdout fragments and speculatively json-parses the
    buffer, because TextReceiveStream can split long lines. After the
    stream ends, a non-zero process exit code is surfaced as
    ProcessError.
    """
    if not self._process or not self._stdout_stream:
        raise CLIConnectionError("Not connected")

    json_buffer = ""

    # Process stdout messages
    try:
        async for line in self._stdout_stream:
            line_str = line.strip()
            if not line_str:
                continue

            # Accumulate partial JSON until we can parse it
            # Note: TextReceiveStream can truncate long lines, so we need to buffer
            # and speculatively parse until we get a complete JSON object
            json_lines = line_str.split("\n")

            for json_line in json_lines:
                json_line = json_line.strip()
                if not json_line:
                    continue

                # Keep accumulating partial JSON until we can parse it
                json_buffer += json_line

                if len(json_buffer) > self._max_buffer_size:
                    buffer_length = len(json_buffer)
                    json_buffer = ""
                    raise SDKJSONDecodeError(
                        f"JSON message exceeded maximum buffer size of {self._max_buffer_size} bytes",
                        ValueError(
                            f"Buffer size {buffer_length} exceeds limit {self._max_buffer_size}"
                        ),
                    )

                try:
                    data = json.loads(json_buffer)
                    json_buffer = ""
                    yield data
                except json.JSONDecodeError:
                    # We are speculatively decoding the buffer until we get
                    # a full JSON object. If there is an actual issue, we
                    # raise an error after exceeding the configured limit.
                    continue

    except anyio.ClosedResourceError:
        pass
    except GeneratorExit:
        # Client disconnected
        pass

    # Check process completion and handle errors
    try:
        returncode = await self._process.wait()
    except Exception:
        # Treat an unobservable exit as a failure below.
        returncode = -1

    # Use exit code for error detection
    if returncode is not None and returncode != 0:
        self._exit_error = ProcessError(
            f"Command failed with exit code {returncode}",
            exit_code=returncode,
            stderr="Check stderr output for details",
        )
        raise self._exit_error
|
||||
|
||||
async def _check_claude_version(self) -> None:
    """Check Claude Code version and warn if below minimum.

    Best-effort: runs ``claude -v`` with a 2-second budget and compares
    the reported semver against MINIMUM_CLAUDE_CODE_VERSION. Any
    failure (timeout, spawn error, unparsable output) is deliberately
    ignored — the check must never block connecting.
    """
    version_process = None
    try:
        with anyio.fail_after(2):  # 2 second timeout
            version_process = await anyio.open_process(
                [self._cli_path, "-v"],
                stdout=PIPE,
                stderr=PIPE,
            )

            if version_process.stdout:
                stdout_bytes = await version_process.stdout.receive()
                version_output = stdout_bytes.decode().strip()

                match = re.match(r"([0-9]+\.[0-9]+\.[0-9]+)", version_output)
                if match:
                    version = match.group(1)
                    # Lexicographic list-of-ints compare implements semver
                    # ordering for plain x.y.z versions.
                    version_parts = [int(x) for x in version.split(".")]
                    min_parts = [
                        int(x) for x in MINIMUM_CLAUDE_CODE_VERSION.split(".")
                    ]

                    if version_parts < min_parts:
                        warning = (
                            f"Warning: Claude Code version {version} is unsupported in the Agent SDK. "
                            f"Minimum required version is {MINIMUM_CLAUDE_CODE_VERSION}. "
                            "Some features may not work correctly."
                        )
                        logger.warning(warning)
                        print(warning, file=sys.stderr)
    except Exception:
        pass
    finally:
        if version_process:
            with suppress(Exception):
                version_process.terminate()
            with suppress(Exception):
                await version_process.wait()
|
||||
|
||||
def is_ready(self) -> bool:
    """Check if transport is ready for communication.

    Returns:
        True once connect() succeeded, until close() or a failed
        write() marks the transport unusable.
    """
    return self._ready
|
||||
Reference in New Issue
Block a user