Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1 @@
"""Internal implementation details."""

View File

@@ -0,0 +1,124 @@
"""Internal client implementation."""
from collections.abc import AsyncIterable, AsyncIterator
from dataclasses import replace
from typing import Any
from ..types import (
ClaudeAgentOptions,
HookEvent,
HookMatcher,
Message,
)
from .message_parser import parse_message
from .query import Query
from .transport import Transport
from .transport.subprocess_cli import SubprocessCLITransport
class InternalClient:
"""Internal client implementation."""
def __init__(self) -> None:
"""Initialize the internal client."""
def _convert_hooks_to_internal_format(
self, hooks: dict[HookEvent, list[HookMatcher]]
) -> dict[str, list[dict[str, Any]]]:
"""Convert HookMatcher format to internal Query format."""
internal_hooks: dict[str, list[dict[str, Any]]] = {}
for event, matchers in hooks.items():
internal_hooks[event] = []
for matcher in matchers:
# Convert HookMatcher to internal dict format
internal_matcher: dict[str, Any] = {
"matcher": matcher.matcher if hasattr(matcher, "matcher") else None,
"hooks": matcher.hooks if hasattr(matcher, "hooks") else [],
}
if hasattr(matcher, "timeout") and matcher.timeout is not None:
internal_matcher["timeout"] = matcher.timeout
internal_hooks[event].append(internal_matcher)
return internal_hooks
async def process_query(
self,
prompt: str | AsyncIterable[dict[str, Any]],
options: ClaudeAgentOptions,
transport: Transport | None = None,
) -> AsyncIterator[Message]:
"""Process a query through transport and Query."""
# Validate and configure permission settings (matching TypeScript SDK logic)
configured_options = options
if options.can_use_tool:
# canUseTool callback requires streaming mode (AsyncIterable prompt)
if isinstance(prompt, str):
raise ValueError(
"can_use_tool callback requires streaming mode. "
"Please provide prompt as an AsyncIterable instead of a string."
)
# canUseTool and permission_prompt_tool_name are mutually exclusive
if options.permission_prompt_tool_name:
raise ValueError(
"can_use_tool callback cannot be used with permission_prompt_tool_name. "
"Please use one or the other."
)
# Automatically set permission_prompt_tool_name to "stdio" for control protocol
configured_options = replace(options, permission_prompt_tool_name="stdio")
# Use provided transport or create subprocess transport
if transport is not None:
chosen_transport = transport
else:
chosen_transport = SubprocessCLITransport(
prompt=prompt,
options=configured_options,
)
# Connect transport
await chosen_transport.connect()
# Extract SDK MCP servers from configured options
sdk_mcp_servers = {}
if configured_options.mcp_servers and isinstance(
configured_options.mcp_servers, dict
):
for name, config in configured_options.mcp_servers.items():
if isinstance(config, dict) and config.get("type") == "sdk":
sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item]
# Create Query to handle control protocol
is_streaming = not isinstance(prompt, str)
query = Query(
transport=chosen_transport,
is_streaming_mode=is_streaming,
can_use_tool=configured_options.can_use_tool,
hooks=self._convert_hooks_to_internal_format(configured_options.hooks)
if configured_options.hooks
else None,
sdk_mcp_servers=sdk_mcp_servers,
)
try:
# Start reading messages
await query.start()
# Initialize if streaming
if is_streaming:
await query.initialize()
# Stream input if it's an AsyncIterable
if isinstance(prompt, AsyncIterable) and query._tg:
# Start streaming in background
# Create a task that will run in the background
query._tg.start_soon(query.stream_input, prompt)
# For string prompts, the prompt is already passed via CLI args
# Yield parsed messages
async for data in query.receive_messages():
yield parse_message(data)
finally:
await query.close()

View File

@@ -0,0 +1,177 @@
"""Message parser for Claude Code SDK responses."""
import logging
from typing import Any
from .._errors import MessageParseError
from ..types import (
AssistantMessage,
ContentBlock,
Message,
ResultMessage,
StreamEvent,
SystemMessage,
TextBlock,
ThinkingBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
logger = logging.getLogger(__name__)
def parse_message(data: dict[str, Any]) -> Message:
"""
Parse message from CLI output into typed Message objects.
Args:
data: Raw message dictionary from CLI output
Returns:
Parsed Message object
Raises:
MessageParseError: If parsing fails or message type is unrecognized
"""
if not isinstance(data, dict):
raise MessageParseError(
f"Invalid message data type (expected dict, got {type(data).__name__})",
data,
)
message_type = data.get("type")
if not message_type:
raise MessageParseError("Message missing 'type' field", data)
match message_type:
case "user":
try:
parent_tool_use_id = data.get("parent_tool_use_id")
uuid = data.get("uuid")
if isinstance(data["message"]["content"], list):
user_content_blocks: list[ContentBlock] = []
for block in data["message"]["content"]:
match block["type"]:
case "text":
user_content_blocks.append(
TextBlock(text=block["text"])
)
case "tool_use":
user_content_blocks.append(
ToolUseBlock(
id=block["id"],
name=block["name"],
input=block["input"],
)
)
case "tool_result":
user_content_blocks.append(
ToolResultBlock(
tool_use_id=block["tool_use_id"],
content=block.get("content"),
is_error=block.get("is_error"),
)
)
return UserMessage(
content=user_content_blocks,
uuid=uuid,
parent_tool_use_id=parent_tool_use_id,
)
return UserMessage(
content=data["message"]["content"],
uuid=uuid,
parent_tool_use_id=parent_tool_use_id,
)
except KeyError as e:
raise MessageParseError(
f"Missing required field in user message: {e}", data
) from e
case "assistant":
try:
content_blocks: list[ContentBlock] = []
for block in data["message"]["content"]:
match block["type"]:
case "text":
content_blocks.append(TextBlock(text=block["text"]))
case "thinking":
content_blocks.append(
ThinkingBlock(
thinking=block["thinking"],
signature=block["signature"],
)
)
case "tool_use":
content_blocks.append(
ToolUseBlock(
id=block["id"],
name=block["name"],
input=block["input"],
)
)
case "tool_result":
content_blocks.append(
ToolResultBlock(
tool_use_id=block["tool_use_id"],
content=block.get("content"),
is_error=block.get("is_error"),
)
)
return AssistantMessage(
content=content_blocks,
model=data["message"]["model"],
parent_tool_use_id=data.get("parent_tool_use_id"),
error=data["message"].get("error"),
)
except KeyError as e:
raise MessageParseError(
f"Missing required field in assistant message: {e}", data
) from e
case "system":
try:
return SystemMessage(
subtype=data["subtype"],
data=data,
)
except KeyError as e:
raise MessageParseError(
f"Missing required field in system message: {e}", data
) from e
case "result":
try:
return ResultMessage(
subtype=data["subtype"],
duration_ms=data["duration_ms"],
duration_api_ms=data["duration_api_ms"],
is_error=data["is_error"],
num_turns=data["num_turns"],
session_id=data["session_id"],
total_cost_usd=data.get("total_cost_usd"),
usage=data.get("usage"),
result=data.get("result"),
structured_output=data.get("structured_output"),
)
except KeyError as e:
raise MessageParseError(
f"Missing required field in result message: {e}", data
) from e
case "stream_event":
try:
return StreamEvent(
uuid=data["uuid"],
session_id=data["session_id"],
event=data["event"],
parent_tool_use_id=data.get("parent_tool_use_id"),
)
except KeyError as e:
raise MessageParseError(
f"Missing required field in stream_event message: {e}", data
) from e
case _:
raise MessageParseError(f"Unknown message type: {message_type}", data)

View File

@@ -0,0 +1,621 @@
"""Query class for handling bidirectional control protocol."""
import json
import logging
import os
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from contextlib import suppress
from typing import TYPE_CHECKING, Any
import anyio
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
ListToolsRequest,
)
from ..types import (
PermissionResultAllow,
PermissionResultDeny,
SDKControlPermissionRequest,
SDKControlRequest,
SDKControlResponse,
SDKHookCallbackRequest,
ToolPermissionContext,
)
from .transport import Transport
if TYPE_CHECKING:
from mcp.server import Server as McpServer
logger = logging.getLogger(__name__)
def _convert_hook_output_for_cli(hook_output: dict[str, Any]) -> dict[str, Any]:
"""Convert Python-safe field names to CLI-expected field names.
The Python SDK uses `async_` and `continue_` to avoid keyword conflicts,
but the CLI expects `async` and `continue`. This function performs the
necessary conversion.
"""
converted = {}
for key, value in hook_output.items():
# Convert Python-safe names to JavaScript names
if key == "async_":
converted["async"] = value
elif key == "continue_":
converted["continue"] = value
else:
converted[key] = value
return converted
class Query:
"""Handles bidirectional control protocol on top of Transport.
This class manages:
- Control request/response routing
- Hook callbacks
- Tool permission callbacks
- Message streaming
- Initialization handshake
"""
def __init__(
self,
transport: Transport,
is_streaming_mode: bool,
can_use_tool: Callable[
[str, dict[str, Any], ToolPermissionContext],
Awaitable[PermissionResultAllow | PermissionResultDeny],
]
| None = None,
hooks: dict[str, list[dict[str, Any]]] | None = None,
sdk_mcp_servers: dict[str, "McpServer"] | None = None,
initialize_timeout: float = 60.0,
):
"""Initialize Query with transport and callbacks.
Args:
transport: Low-level transport for I/O
is_streaming_mode: Whether using streaming (bidirectional) mode
can_use_tool: Optional callback for tool permission requests
hooks: Optional hook configurations
sdk_mcp_servers: Optional SDK MCP server instances
initialize_timeout: Timeout in seconds for the initialize request
"""
self._initialize_timeout = initialize_timeout
self.transport = transport
self.is_streaming_mode = is_streaming_mode
self.can_use_tool = can_use_tool
self.hooks = hooks or {}
self.sdk_mcp_servers = sdk_mcp_servers or {}
# Control protocol state
self.pending_control_responses: dict[str, anyio.Event] = {}
self.pending_control_results: dict[str, dict[str, Any] | Exception] = {}
self.hook_callbacks: dict[str, Callable[..., Any]] = {}
self.next_callback_id = 0
self._request_counter = 0
# Message stream
self._message_send, self._message_receive = anyio.create_memory_object_stream[
dict[str, Any]
](max_buffer_size=100)
self._tg: anyio.abc.TaskGroup | None = None
self._initialized = False
self._closed = False
self._initialization_result: dict[str, Any] | None = None
# Track first result for proper stream closure with SDK MCP servers
self._first_result_event = anyio.Event()
self._stream_close_timeout = (
float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
) # Convert ms to seconds
async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.
Returns:
Initialize response with supported commands, or None if not streaming
"""
if not self.is_streaming_mode:
return None
# Build hooks configuration for initialization
hooks_config: dict[str, Any] = {}
if self.hooks:
for event, matchers in self.hooks.items():
if matchers:
hooks_config[event] = []
for matcher in matchers:
callback_ids = []
for callback in matcher.get("hooks", []):
callback_id = f"hook_{self.next_callback_id}"
self.next_callback_id += 1
self.hook_callbacks[callback_id] = callback
callback_ids.append(callback_id)
hook_matcher_config: dict[str, Any] = {
"matcher": matcher.get("matcher"),
"hookCallbackIds": callback_ids,
}
if matcher.get("timeout") is not None:
hook_matcher_config["timeout"] = matcher.get("timeout")
hooks_config[event].append(hook_matcher_config)
# Send initialize request
request = {
"subtype": "initialize",
"hooks": hooks_config if hooks_config else None,
}
# Use longer timeout for initialize since MCP servers may take time to start
response = await self._send_control_request(
request, timeout=self._initialize_timeout
)
self._initialized = True
self._initialization_result = response # Store for later access
return response
async def start(self) -> None:
"""Start reading messages from transport."""
if self._tg is None:
self._tg = anyio.create_task_group()
await self._tg.__aenter__()
self._tg.start_soon(self._read_messages)
async def _read_messages(self) -> None:
"""Read messages from transport and route them."""
try:
async for message in self.transport.read_messages():
if self._closed:
break
msg_type = message.get("type")
# Route control messages
if msg_type == "control_response":
response = message.get("response", {})
request_id = response.get("request_id")
if request_id in self.pending_control_responses:
event = self.pending_control_responses[request_id]
if response.get("subtype") == "error":
self.pending_control_results[request_id] = Exception(
response.get("error", "Unknown error")
)
else:
self.pending_control_results[request_id] = response
event.set()
continue
elif msg_type == "control_request":
# Handle incoming control requests from CLI
# Cast message to SDKControlRequest for type safety
request: SDKControlRequest = message # type: ignore[assignment]
if self._tg:
self._tg.start_soon(self._handle_control_request, request)
continue
elif msg_type == "control_cancel_request":
# Handle cancel requests
# TODO: Implement cancellation support
continue
# Track results for proper stream closure
if msg_type == "result":
self._first_result_event.set()
# Regular SDK messages go to the stream
await self._message_send.send(message)
except anyio.get_cancelled_exc_class():
# Task was cancelled - this is expected behavior
logger.debug("Read task cancelled")
raise # Re-raise to properly handle cancellation
except Exception as e:
logger.error(f"Fatal error in message reader: {e}")
# Signal all pending control requests so they fail fast instead of timing out
for request_id, event in list(self.pending_control_responses.items()):
if request_id not in self.pending_control_results:
self.pending_control_results[request_id] = e
event.set()
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
finally:
# Always signal end of stream
await self._message_send.send({"type": "end"})
async def _handle_control_request(self, request: SDKControlRequest) -> None:
"""Handle incoming control request from CLI."""
request_id = request["request_id"]
request_data = request["request"]
subtype = request_data["subtype"]
try:
response_data: dict[str, Any] = {}
if subtype == "can_use_tool":
permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment]
original_input = permission_request["input"]
# Handle tool permission request
if not self.can_use_tool:
raise Exception("canUseTool callback is not provided")
context = ToolPermissionContext(
signal=None, # TODO: Add abort signal support
suggestions=permission_request.get("permission_suggestions", [])
or [],
)
response = await self.can_use_tool(
permission_request["tool_name"],
permission_request["input"],
context,
)
# Convert PermissionResult to expected dict format
if isinstance(response, PermissionResultAllow):
response_data = {
"behavior": "allow",
"updatedInput": (
response.updated_input
if response.updated_input is not None
else original_input
),
}
if response.updated_permissions is not None:
response_data["updatedPermissions"] = [
permission.to_dict()
for permission in response.updated_permissions
]
elif isinstance(response, PermissionResultDeny):
response_data = {"behavior": "deny", "message": response.message}
if response.interrupt:
response_data["interrupt"] = response.interrupt
else:
raise TypeError(
f"Tool permission callback must return PermissionResult (PermissionResultAllow or PermissionResultDeny), got {type(response)}"
)
elif subtype == "hook_callback":
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]
# Handle hook callback
callback_id = hook_callback_request["callback_id"]
callback = self.hook_callbacks.get(callback_id)
if not callback:
raise Exception(f"No hook callback found for ID: {callback_id}")
hook_output = await callback(
request_data.get("input"),
request_data.get("tool_use_id"),
{"signal": None}, # TODO: Add abort signal support
)
# Convert Python-safe field names (async_, continue_) to CLI-expected names (async, continue)
response_data = _convert_hook_output_for_cli(hook_output)
elif subtype == "mcp_message":
# Handle SDK MCP request
server_name = request_data.get("server_name")
mcp_message = request_data.get("message")
if not server_name or not mcp_message:
raise Exception("Missing server_name or message for MCP request")
# Type narrowing - we've verified these are not None above
assert isinstance(server_name, str)
assert isinstance(mcp_message, dict)
mcp_response = await self._handle_sdk_mcp_request(
server_name, mcp_message
)
# Wrap the MCP response as expected by the control protocol
response_data = {"mcp_response": mcp_response}
else:
raise Exception(f"Unsupported control request subtype: {subtype}")
# Send success response
success_response: SDKControlResponse = {
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": response_data,
},
}
await self.transport.write(json.dumps(success_response) + "\n")
except Exception as e:
# Send error response
error_response: SDKControlResponse = {
"type": "control_response",
"response": {
"subtype": "error",
"request_id": request_id,
"error": str(e),
},
}
await self.transport.write(json.dumps(error_response) + "\n")
async def _send_control_request(
self, request: dict[str, Any], timeout: float = 60.0
) -> dict[str, Any]:
"""Send control request to CLI and wait for response.
Args:
request: The control request to send
timeout: Timeout in seconds to wait for response (default 60s)
"""
if not self.is_streaming_mode:
raise Exception("Control requests require streaming mode")
# Generate unique request ID
self._request_counter += 1
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
# Create event for response
event = anyio.Event()
self.pending_control_responses[request_id] = event
# Build and send request
control_request = {
"type": "control_request",
"request_id": request_id,
"request": request,
}
await self.transport.write(json.dumps(control_request) + "\n")
# Wait for response
try:
with anyio.fail_after(timeout):
await event.wait()
result = self.pending_control_results.pop(request_id)
self.pending_control_responses.pop(request_id, None)
if isinstance(result, Exception):
raise result
response_data = result.get("response", {})
return response_data if isinstance(response_data, dict) else {}
except TimeoutError as e:
self.pending_control_responses.pop(request_id, None)
self.pending_control_results.pop(request_id, None)
raise Exception(f"Control request timeout: {request.get('subtype')}") from e
async def _handle_sdk_mcp_request(
self, server_name: str, message: dict[str, Any]
) -> dict[str, Any]:
"""Handle an MCP request for an SDK server.
This acts as a bridge between JSONRPC messages from the CLI
and the in-process MCP server. Ideally the MCP SDK would provide
a method to handle raw JSONRPC, but for now we route manually.
Args:
server_name: Name of the SDK MCP server
message: The JSONRPC message
Returns:
The response message
"""
if server_name not in self.sdk_mcp_servers:
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {
"code": -32601,
"message": f"Server '{server_name}' not found",
},
}
server = self.sdk_mcp_servers[server_name]
method = message.get("method")
params = message.get("params", {})
try:
# TODO: Python MCP SDK lacks the Transport abstraction that TypeScript has.
# TypeScript: server.connect(transport) allows custom transports
# Python: server.run(read_stream, write_stream) requires actual streams
#
# This forces us to manually route methods. When Python MCP adds Transport
# support, we can refactor to match the TypeScript approach.
if method == "initialize":
# Handle MCP initialization - hardcoded for tools only, no listChanged
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {} # Tools capability without listChanged
},
"serverInfo": {
"name": server.name,
"version": server.version or "1.0.0",
},
},
}
elif method == "tools/list":
request = ListToolsRequest(method=method)
handler = server.request_handlers.get(ListToolsRequest)
if handler:
result = await handler(request)
# Convert MCP result to JSONRPC response
tools_data = [
{
"name": tool.name,
"description": tool.description,
"inputSchema": (
tool.inputSchema.model_dump()
if hasattr(tool.inputSchema, "model_dump")
else tool.inputSchema
)
if tool.inputSchema
else {},
}
for tool in result.root.tools # type: ignore[union-attr]
]
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": {"tools": tools_data},
}
elif method == "tools/call":
call_request = CallToolRequest(
method=method,
params=CallToolRequestParams(
name=params.get("name"), arguments=params.get("arguments", {})
),
)
handler = server.request_handlers.get(CallToolRequest)
if handler:
result = await handler(call_request)
# Convert MCP result to JSONRPC response
content = []
for item in result.root.content: # type: ignore[union-attr]
if hasattr(item, "text"):
content.append({"type": "text", "text": item.text})
elif hasattr(item, "data") and hasattr(item, "mimeType"):
content.append(
{
"type": "image",
"data": item.data,
"mimeType": item.mimeType,
}
)
response_data = {"content": content}
if hasattr(result.root, "is_error") and result.root.is_error:
response_data["is_error"] = True # type: ignore[assignment]
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"result": response_data,
}
elif method == "notifications/initialized":
# Handle initialized notification - just acknowledge it
return {"jsonrpc": "2.0", "result": {}}
# Add more methods here as MCP SDK adds them (resources, prompts, etc.)
# This is the limitation Ashwin pointed out - we have to manually update
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {"code": -32601, "message": f"Method '{method}' not found"},
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": message.get("id"),
"error": {"code": -32603, "message": str(e)},
}
async def interrupt(self) -> None:
"""Send interrupt control request."""
await self._send_control_request({"subtype": "interrupt"})
async def set_permission_mode(self, mode: str) -> None:
"""Change permission mode."""
await self._send_control_request(
{
"subtype": "set_permission_mode",
"mode": mode,
}
)
async def set_model(self, model: str | None) -> None:
"""Change the AI model."""
await self._send_control_request(
{
"subtype": "set_model",
"model": model,
}
)
async def rewind_files(self, user_message_id: str) -> None:
"""Rewind tracked files to their state at a specific user message.
Requires file checkpointing to be enabled via the `enable_file_checkpointing` option.
Args:
user_message_id: UUID of the user message to rewind to
"""
await self._send_control_request(
{
"subtype": "rewind_files",
"user_message_id": user_message_id,
}
)
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
"""Stream input messages to transport.
If SDK MCP servers or hooks are present, waits for the first result
before closing stdin to allow bidirectional control protocol communication.
"""
try:
async for message in stream:
if self._closed:
break
await self.transport.write(json.dumps(message) + "\n")
# If we have SDK MCP servers or hooks that need bidirectional communication,
# wait for first result before closing the channel
has_hooks = bool(self.hooks)
if self.sdk_mcp_servers or has_hooks:
logger.debug(
f"Waiting for first result before closing stdin "
f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, has_hooks={has_hooks})"
)
try:
with anyio.move_on_after(self._stream_close_timeout):
await self._first_result_event.wait()
logger.debug("Received first result, closing input stream")
except Exception:
logger.debug(
"Timed out waiting for first result, closing input stream"
)
# After all messages sent (and result received if needed), end input
await self.transport.end_input()
except Exception as e:
logger.debug(f"Error streaming input: {e}")
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Receive SDK messages (not control messages)."""
async for message in self._message_receive:
# Check for special messages
if message.get("type") == "end":
break
elif message.get("type") == "error":
raise Exception(message.get("error", "Unknown error"))
yield message
async def close(self) -> None:
"""Close the query and transport."""
self._closed = True
if self._tg:
self._tg.cancel_scope.cancel()
# Wait for task group to complete cancellation
with suppress(anyio.get_cancelled_exc_class()):
await self._tg.__aexit__(None, None, None)
await self.transport.close()
# Make Query an async iterator
def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
"""Return async iterator for messages."""
return self.receive_messages()
async def __anext__(self) -> dict[str, Any]:
"""Get next message."""
async for message in self.receive_messages():
return message
raise StopAsyncIteration

View File

@@ -0,0 +1,68 @@
"""Transport implementations for Claude SDK."""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
class Transport(ABC):
"""Abstract transport for Claude communication.
WARNING: This internal API is exposed for custom transport implementations
(e.g., remote Claude Code connections). The Claude Code team may change or
or remove this abstract class in any future release. Custom implementations
must be updated to match interface changes.
This is a low-level transport interface that handles raw I/O with the Claude
process or service. The Query class builds on top of this to implement the
control protocol and message routing.
"""
@abstractmethod
async def connect(self) -> None:
"""Connect the transport and prepare for communication.
For subprocess transports, this starts the process.
For network transports, this establishes the connection.
"""
pass
@abstractmethod
async def write(self, data: str) -> None:
"""Write raw data to the transport.
Args:
data: Raw string data to write (typically JSON + newline)
"""
pass
@abstractmethod
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport.
Yields:
Parsed JSON messages from the transport
"""
pass
@abstractmethod
async def close(self) -> None:
"""Close the transport connection and clean up resources."""
pass
@abstractmethod
def is_ready(self) -> bool:
"""Check if transport is ready for communication.
Returns:
True if transport is ready to send/receive messages
"""
pass
@abstractmethod
async def end_input(self) -> None:
"""End the input stream (close stdin for process transports)."""
pass
__all__ = ["Transport"]

View File

@@ -0,0 +1,672 @@
"""Subprocess transport implementation using Claude Code CLI."""
import json
import logging
import os
import platform
import re
import shutil
import sys
import tempfile
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import suppress
from dataclasses import asdict
from pathlib import Path
from subprocess import PIPE
from typing import Any
import anyio
import anyio.abc
from anyio.abc import Process
from anyio.streams.text import TextReceiveStream, TextSendStream
from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError
from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
from ..._version import __version__
from ...types import ClaudeAgentOptions
from . import Transport
logger = logging.getLogger(__name__)
_DEFAULT_MAX_BUFFER_SIZE = 1024 * 1024 # 1MB buffer limit
MINIMUM_CLAUDE_CODE_VERSION = "2.0.0"
# Platform-specific command line length limits
# Windows cmd.exe has a limit of 8191 characters, use 8000 for safety
# Other platforms have much higher limits
_CMD_LENGTH_LIMIT = 8000 if platform.system() == "Windows" else 100000
class SubprocessCLITransport(Transport):
"""Subprocess transport using Claude Code CLI."""
def __init__(
self,
prompt: str | AsyncIterable[dict[str, Any]],
options: ClaudeAgentOptions,
):
self._prompt = prompt
self._is_streaming = not isinstance(prompt, str)
self._options = options
self._cli_path = (
str(options.cli_path) if options.cli_path is not None else self._find_cli()
)
self._cwd = str(options.cwd) if options.cwd else None
self._process: Process | None = None
self._stdout_stream: TextReceiveStream | None = None
self._stdin_stream: TextSendStream | None = None
self._stderr_stream: TextReceiveStream | None = None
self._stderr_task_group: anyio.abc.TaskGroup | None = None
self._ready = False
self._exit_error: Exception | None = None # Track process exit errors
self._max_buffer_size = (
options.max_buffer_size
if options.max_buffer_size is not None
else _DEFAULT_MAX_BUFFER_SIZE
)
self._temp_files: list[str] = [] # Track temporary files for cleanup
self._write_lock: anyio.Lock = anyio.Lock()
def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
# First, check for bundled CLI
bundled_cli = self._find_bundled_cli()
if bundled_cli:
return bundled_cli
# Fall back to system-wide search
if cli := shutil.which("claude"):
return cli
locations = [
Path.home() / ".npm-global/bin/claude",
Path("/usr/local/bin/claude"),
Path.home() / ".local/bin/claude",
Path.home() / "node_modules/.bin/claude",
Path.home() / ".yarn/bin/claude",
Path.home() / ".claude/local/claude",
]
for path in locations:
if path.exists() and path.is_file():
return str(path)
raise CLINotFoundError(
"Claude Code not found. Install with:\n"
" npm install -g @anthropic-ai/claude-code\n"
"\nIf already installed locally, try:\n"
' export PATH="$HOME/node_modules/.bin:$PATH"\n'
"\nOr provide the path via ClaudeAgentOptions:\n"
" ClaudeAgentOptions(cli_path='/path/to/claude')"
)
def _find_bundled_cli(self) -> str | None:
"""Find bundled CLI binary if it exists."""
# Determine the CLI binary name based on platform
cli_name = "claude.exe" if platform.system() == "Windows" else "claude"
# Get the path to the bundled CLI
# The _bundled directory is in the same package as this module
bundled_path = Path(__file__).parent.parent.parent / "_bundled" / cli_name
if bundled_path.exists() and bundled_path.is_file():
logger.info(f"Using bundled Claude Code CLI: {bundled_path}")
return str(bundled_path)
return None
def _build_settings_value(self) -> str | None:
"""Build settings value, merging sandbox settings if provided.
Returns the settings value as either:
- A JSON string (if sandbox is provided or settings is JSON)
- A file path (if only settings path is provided without sandbox)
- None if neither settings nor sandbox is provided
"""
has_settings = self._options.settings is not None
has_sandbox = self._options.sandbox is not None
if not has_settings and not has_sandbox:
return None
# If only settings path and no sandbox, pass through as-is
if has_settings and not has_sandbox:
return self._options.settings
# If we have sandbox settings, we need to merge into a JSON object
settings_obj: dict[str, Any] = {}
if has_settings:
assert self._options.settings is not None
settings_str = self._options.settings.strip()
# Check if settings is a JSON string or a file path
if settings_str.startswith("{") and settings_str.endswith("}"):
# Parse JSON string
try:
settings_obj = json.loads(settings_str)
except json.JSONDecodeError:
# If parsing fails, treat as file path
logger.warning(
f"Failed to parse settings as JSON, treating as file path: {settings_str}"
)
# Read the file
settings_path = Path(settings_str)
if settings_path.exists():
with settings_path.open(encoding="utf-8") as f:
settings_obj = json.load(f)
else:
# It's a file path - read and parse
settings_path = Path(settings_str)
if settings_path.exists():
with settings_path.open(encoding="utf-8") as f:
settings_obj = json.load(f)
else:
logger.warning(f"Settings file not found: {settings_path}")
# Merge sandbox settings
if has_sandbox:
settings_obj["sandbox"] = self._options.sandbox
return json.dumps(settings_obj)
def _build_command(self) -> list[str]:
"""Build CLI command with arguments."""
cmd = [self._cli_path, "--output-format", "stream-json", "--verbose"]
if self._options.system_prompt is None:
cmd.extend(["--system-prompt", ""])
elif isinstance(self._options.system_prompt, str):
cmd.extend(["--system-prompt", self._options.system_prompt])
else:
if (
self._options.system_prompt.get("type") == "preset"
and "append" in self._options.system_prompt
):
cmd.extend(
["--append-system-prompt", self._options.system_prompt["append"]]
)
# Handle tools option (base set of tools)
if self._options.tools is not None:
tools = self._options.tools
if isinstance(tools, list):
if len(tools) == 0:
cmd.extend(["--tools", ""])
else:
cmd.extend(["--tools", ",".join(tools)])
else:
# Preset object - 'claude_code' preset maps to 'default'
cmd.extend(["--tools", "default"])
if self._options.allowed_tools:
cmd.extend(["--allowedTools", ",".join(self._options.allowed_tools)])
if self._options.max_turns:
cmd.extend(["--max-turns", str(self._options.max_turns)])
if self._options.max_budget_usd is not None:
cmd.extend(["--max-budget-usd", str(self._options.max_budget_usd)])
if self._options.disallowed_tools:
cmd.extend(["--disallowedTools", ",".join(self._options.disallowed_tools)])
if self._options.model:
cmd.extend(["--model", self._options.model])
if self._options.fallback_model:
cmd.extend(["--fallback-model", self._options.fallback_model])
if self._options.betas:
cmd.extend(["--betas", ",".join(self._options.betas)])
if self._options.permission_prompt_tool_name:
cmd.extend(
["--permission-prompt-tool", self._options.permission_prompt_tool_name]
)
if self._options.permission_mode:
cmd.extend(["--permission-mode", self._options.permission_mode])
if self._options.continue_conversation:
cmd.append("--continue")
if self._options.resume:
cmd.extend(["--resume", self._options.resume])
# Handle settings and sandbox: merge sandbox into settings if both are provided
settings_value = self._build_settings_value()
if settings_value:
cmd.extend(["--settings", settings_value])
if self._options.add_dirs:
# Convert all paths to strings and add each directory
for directory in self._options.add_dirs:
cmd.extend(["--add-dir", str(directory)])
if self._options.mcp_servers:
if isinstance(self._options.mcp_servers, dict):
# Process all servers, stripping instance field from SDK servers
servers_for_cli: dict[str, Any] = {}
for name, config in self._options.mcp_servers.items():
if isinstance(config, dict) and config.get("type") == "sdk":
# For SDK servers, pass everything except the instance field
sdk_config: dict[str, object] = {
k: v for k, v in config.items() if k != "instance"
}
servers_for_cli[name] = sdk_config
else:
# For external servers, pass as-is
servers_for_cli[name] = config
# Pass all servers to CLI
if servers_for_cli:
cmd.extend(
[
"--mcp-config",
json.dumps({"mcpServers": servers_for_cli}),
]
)
else:
# String or Path format: pass directly as file path or JSON string
cmd.extend(["--mcp-config", str(self._options.mcp_servers)])
if self._options.include_partial_messages:
cmd.append("--include-partial-messages")
if self._options.fork_session:
cmd.append("--fork-session")
if self._options.agents:
agents_dict = {
name: {k: v for k, v in asdict(agent_def).items() if v is not None}
for name, agent_def in self._options.agents.items()
}
agents_json = json.dumps(agents_dict)
cmd.extend(["--agents", agents_json])
sources_value = (
",".join(self._options.setting_sources)
if self._options.setting_sources is not None
else ""
)
cmd.extend(["--setting-sources", sources_value])
# Add plugin directories
if self._options.plugins:
for plugin in self._options.plugins:
if plugin["type"] == "local":
cmd.extend(["--plugin-dir", plugin["path"]])
else:
raise ValueError(f"Unsupported plugin type: {plugin['type']}")
# Add extra args for future CLI flags
for flag, value in self._options.extra_args.items():
if value is None:
# Boolean flag without value
cmd.append(f"--{flag}")
else:
# Flag with value
cmd.extend([f"--{flag}", str(value)])
if self._options.max_thinking_tokens is not None:
cmd.extend(
["--max-thinking-tokens", str(self._options.max_thinking_tokens)]
)
# Extract schema from output_format structure if provided
# Expected: {"type": "json_schema", "schema": {...}}
if (
self._options.output_format is not None
and isinstance(self._options.output_format, dict)
and self._options.output_format.get("type") == "json_schema"
):
schema = self._options.output_format.get("schema")
if schema is not None:
cmd.extend(["--json-schema", json.dumps(schema)])
# Add prompt handling based on mode
# IMPORTANT: This must come AFTER all flags because everything after "--" is treated as arguments
if self._is_streaming:
# Streaming mode: use --input-format stream-json
cmd.extend(["--input-format", "stream-json"])
else:
# String mode: use --print with the prompt
cmd.extend(["--print", "--", str(self._prompt)])
# Check if command line is too long (Windows limitation)
cmd_str = " ".join(cmd)
if len(cmd_str) > _CMD_LENGTH_LIMIT and self._options.agents:
# Command is too long - use temp file for agents
# Find the --agents argument and replace its value with @filepath
try:
agents_idx = cmd.index("--agents")
agents_json_value = cmd[agents_idx + 1]
# Create a temporary file
# ruff: noqa: SIM115
temp_file = tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False, encoding="utf-8"
)
temp_file.write(agents_json_value)
temp_file.close()
# Track for cleanup
self._temp_files.append(temp_file.name)
# Replace agents JSON with @filepath reference
cmd[agents_idx + 1] = f"@{temp_file.name}"
logger.info(
f"Command line length ({len(cmd_str)}) exceeds limit ({_CMD_LENGTH_LIMIT}). "
f"Using temp file for --agents: {temp_file.name}"
)
except (ValueError, IndexError) as e:
# This shouldn't happen, but log it just in case
logger.warning(f"Failed to optimize command line length: {e}")
return cmd
async def connect(self) -> None:
"""Start subprocess."""
if self._process:
return
if not os.environ.get("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK"):
await self._check_claude_version()
cmd = self._build_command()
try:
# Merge environment variables: system -> user -> SDK required
process_env = {
**os.environ,
**self._options.env, # User-provided env vars
"CLAUDE_CODE_ENTRYPOINT": "sdk-py",
"CLAUDE_AGENT_SDK_VERSION": __version__,
}
# Enable file checkpointing if requested
if self._options.enable_file_checkpointing:
process_env["CLAUDE_CODE_ENABLE_SDK_FILE_CHECKPOINTING"] = "true"
if self._cwd:
process_env["PWD"] = self._cwd
# Pipe stderr if we have a callback OR debug mode is enabled
should_pipe_stderr = (
self._options.stderr is not None
or "debug-to-stderr" in self._options.extra_args
)
# For backward compat: use debug_stderr file object if no callback and debug is on
stderr_dest = PIPE if should_pipe_stderr else None
self._process = await anyio.open_process(
cmd,
stdin=PIPE,
stdout=PIPE,
stderr=stderr_dest,
cwd=self._cwd,
env=process_env,
user=self._options.user,
)
if self._process.stdout:
self._stdout_stream = TextReceiveStream(self._process.stdout)
# Setup stderr stream if piped
if should_pipe_stderr and self._process.stderr:
self._stderr_stream = TextReceiveStream(self._process.stderr)
# Start async task to read stderr
self._stderr_task_group = anyio.create_task_group()
await self._stderr_task_group.__aenter__()
self._stderr_task_group.start_soon(self._handle_stderr)
# Setup stdin for streaming mode
if self._is_streaming and self._process.stdin:
self._stdin_stream = TextSendStream(self._process.stdin)
elif not self._is_streaming and self._process.stdin:
# String mode: close stdin immediately
await self._process.stdin.aclose()
self._ready = True
except FileNotFoundError as e:
# Check if the error comes from the working directory or the CLI
if self._cwd and not Path(self._cwd).exists():
error = CLIConnectionError(
f"Working directory does not exist: {self._cwd}"
)
self._exit_error = error
raise error from e
error = CLINotFoundError(f"Claude Code not found at: {self._cli_path}")
self._exit_error = error
raise error from e
except Exception as e:
error = CLIConnectionError(f"Failed to start Claude Code: {e}")
self._exit_error = error
raise error from e
async def _handle_stderr(self) -> None:
"""Handle stderr stream - read and invoke callbacks."""
if not self._stderr_stream:
return
try:
async for line in self._stderr_stream:
line_str = line.rstrip()
if not line_str:
continue
# Call the stderr callback if provided
if self._options.stderr:
self._options.stderr(line_str)
# For backward compatibility: write to debug_stderr if in debug mode
elif (
"debug-to-stderr" in self._options.extra_args
and self._options.debug_stderr
):
self._options.debug_stderr.write(line_str + "\n")
if hasattr(self._options.debug_stderr, "flush"):
self._options.debug_stderr.flush()
except anyio.ClosedResourceError:
pass # Stream closed, exit normally
except Exception:
pass # Ignore other errors during stderr reading
async def close(self) -> None:
"""Close the transport and clean up resources."""
# Clean up temporary files first (before early return)
for temp_file in self._temp_files:
with suppress(Exception):
Path(temp_file).unlink(missing_ok=True)
self._temp_files.clear()
if not self._process:
self._ready = False
return
# Close stderr task group if active
if self._stderr_task_group:
with suppress(Exception):
self._stderr_task_group.cancel_scope.cancel()
await self._stderr_task_group.__aexit__(None, None, None)
self._stderr_task_group = None
# Close stdin stream (acquire lock to prevent race with concurrent writes)
async with self._write_lock:
self._ready = False # Set inside lock to prevent TOCTOU with write()
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
if self._stderr_stream:
with suppress(Exception):
await self._stderr_stream.aclose()
self._stderr_stream = None
# Terminate and wait for process
if self._process.returncode is None:
with suppress(ProcessLookupError):
self._process.terminate()
# Wait for process to finish with timeout
with suppress(Exception):
# Just try to wait, but don't block if it fails
await self._process.wait()
self._process = None
self._stdout_stream = None
self._stdin_stream = None
self._stderr_stream = None
self._exit_error = None
async def write(self, data: str) -> None:
"""Write raw data to the transport."""
async with self._write_lock:
# All checks inside lock to prevent TOCTOU races with close()/end_input()
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")
if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error
try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
async def end_input(self) -> None:
"""End the input stream (close stdin)."""
async with self._write_lock:
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport."""
return self._read_messages_impl()
async def _read_messages_impl(self) -> AsyncIterator[dict[str, Any]]:
"""Internal implementation of read_messages."""
if not self._process or not self._stdout_stream:
raise CLIConnectionError("Not connected")
json_buffer = ""
# Process stdout messages
try:
async for line in self._stdout_stream:
line_str = line.strip()
if not line_str:
continue
# Accumulate partial JSON until we can parse it
# Note: TextReceiveStream can truncate long lines, so we need to buffer
# and speculatively parse until we get a complete JSON object
json_lines = line_str.split("\n")
for json_line in json_lines:
json_line = json_line.strip()
if not json_line:
continue
# Keep accumulating partial JSON until we can parse it
json_buffer += json_line
if len(json_buffer) > self._max_buffer_size:
buffer_length = len(json_buffer)
json_buffer = ""
raise SDKJSONDecodeError(
f"JSON message exceeded maximum buffer size of {self._max_buffer_size} bytes",
ValueError(
f"Buffer size {buffer_length} exceeds limit {self._max_buffer_size}"
),
)
try:
data = json.loads(json_buffer)
json_buffer = ""
yield data
except json.JSONDecodeError:
# We are speculatively decoding the buffer until we get
# a full JSON object. If there is an actual issue, we
# raise an error after exceeding the configured limit.
continue
except anyio.ClosedResourceError:
pass
except GeneratorExit:
# Client disconnected
pass
# Check process completion and handle errors
try:
returncode = await self._process.wait()
except Exception:
returncode = -1
# Use exit code for error detection
if returncode is not None and returncode != 0:
self._exit_error = ProcessError(
f"Command failed with exit code {returncode}",
exit_code=returncode,
stderr="Check stderr output for details",
)
raise self._exit_error
async def _check_claude_version(self) -> None:
"""Check Claude Code version and warn if below minimum."""
version_process = None
try:
with anyio.fail_after(2): # 2 second timeout
version_process = await anyio.open_process(
[self._cli_path, "-v"],
stdout=PIPE,
stderr=PIPE,
)
if version_process.stdout:
stdout_bytes = await version_process.stdout.receive()
version_output = stdout_bytes.decode().strip()
match = re.match(r"([0-9]+\.[0-9]+\.[0-9]+)", version_output)
if match:
version = match.group(1)
version_parts = [int(x) for x in version.split(".")]
min_parts = [
int(x) for x in MINIMUM_CLAUDE_CODE_VERSION.split(".")
]
if version_parts < min_parts:
warning = (
f"Warning: Claude Code version {version} is unsupported in the Agent SDK. "
f"Minimum required version is {MINIMUM_CLAUDE_CODE_VERSION}. "
"Some features may not work correctly."
)
logger.warning(warning)
print(warning, file=sys.stderr)
except Exception:
pass
finally:
if version_process:
with suppress(Exception):
version_process.terminate()
with suppress(Exception):
await version_process.wait()
def is_ready(self) -> bool:
"""Check if transport is ready for communication."""
return self._ready