Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Experimental client features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from mcp.client.experimental.tasks import ExperimentalClientFeatures
|
||||
|
||||
__all__ = ["ExperimentalClientFeatures"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Experimental task handler protocols for server -> client requests.
|
||||
|
||||
This module provides Protocol types and default handlers for when servers
|
||||
send task-related requests to clients (the reverse of normal client -> server flow).
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Use cases:
|
||||
- Server sends task-augmented sampling/elicitation request to client
|
||||
- Client creates a local task, spawns background work, returns CreateTaskResult
|
||||
- Server polls client's task status via tasks/get, tasks/result, etc.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
|
||||
class GetTaskHandlerFnT(Protocol):
|
||||
"""Handler for tasks/get requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskRequestParams,
|
||||
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class GetTaskResultHandlerFnT(Protocol):
|
||||
"""Handler for tasks/result requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskPayloadRequestParams,
|
||||
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ListTasksHandlerFnT(Protocol):
|
||||
"""Handler for tasks/list requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.PaginatedRequestParams | None,
|
||||
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class CancelTaskHandlerFnT(Protocol):
|
||||
"""Handler for tasks/cancel requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CancelTaskRequestParams,
|
||||
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class TaskAugmentedSamplingFnT(Protocol):
|
||||
"""Handler for task-augmented sampling/createMessage requests from server.
|
||||
|
||||
When server sends a CreateMessageRequest with task field, this callback
|
||||
is invoked. The callback should create a task, spawn background work,
|
||||
and return CreateTaskResult immediately.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class TaskAugmentedElicitationFnT(Protocol):
|
||||
"""Handler for task-augmented elicitation/create requests from server.
|
||||
|
||||
When server sends an ElicitRequest with task field, this callback
|
||||
is invoked. The callback should create a task, spawn background work,
|
||||
and return CreateTaskResult immediately.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
async def default_get_task_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskRequestParams,
|
||||
) -> types.GetTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/get not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_get_task_result_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskPayloadRequestParams,
|
||||
) -> types.GetTaskPayloadResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/result not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_list_tasks_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.PaginatedRequestParams | None,
|
||||
) -> types.ListTasksResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/list not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_cancel_task_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CancelTaskRequestParams,
|
||||
) -> types.CancelTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/cancel not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_task_augmented_sampling(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Task-augmented sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_task_augmented_elicitation(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Task-augmented elicitation not supported",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentalTaskHandlers:
|
||||
"""Container for experimental task handlers.
|
||||
|
||||
Groups all task-related handlers that handle server -> client requests.
|
||||
This includes both pure task requests (get, list, cancel, result) and
|
||||
task-augmented request handlers (sampling, elicitation with task field).
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
handlers = ExperimentalTaskHandlers(
|
||||
get_task=my_get_task_handler,
|
||||
list_tasks=my_list_tasks_handler,
|
||||
)
|
||||
session = ClientSession(..., experimental_task_handlers=handlers)
|
||||
"""
|
||||
|
||||
# Pure task request handlers
|
||||
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
|
||||
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
|
||||
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
|
||||
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
|
||||
|
||||
# Task-augmented request handlers
|
||||
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
|
||||
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
|
||||
|
||||
def build_capability(self) -> types.ClientTasksCapability | None:
|
||||
"""Build ClientTasksCapability from the configured handlers.
|
||||
|
||||
Returns a capability object that reflects which handlers are configured
|
||||
(i.e., not using the default "not supported" handlers).
|
||||
|
||||
Returns:
|
||||
ClientTasksCapability if any handlers are provided, None otherwise
|
||||
"""
|
||||
has_list = self.list_tasks is not default_list_tasks_handler
|
||||
has_cancel = self.cancel_task is not default_cancel_task_handler
|
||||
has_sampling = self.augmented_sampling is not default_task_augmented_sampling
|
||||
has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
|
||||
|
||||
# If no handlers are provided, return None
|
||||
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
|
||||
return None
|
||||
|
||||
# Build requests capability if any request handlers are provided
|
||||
requests_capability: types.ClientTasksRequestsCapability | None = None
|
||||
if has_sampling or has_elicitation:
|
||||
requests_capability = types.ClientTasksRequestsCapability(
|
||||
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
|
||||
if has_sampling
|
||||
else None,
|
||||
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
|
||||
if has_elicitation
|
||||
else None,
|
||||
)
|
||||
|
||||
return types.ClientTasksCapability(
|
||||
list=types.TasksListCapability() if has_list else None,
|
||||
cancel=types.TasksCancelCapability() if has_cancel else None,
|
||||
requests=requests_capability,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handles_request(request: types.ServerRequest) -> bool:
|
||||
"""Check if this handler handles the given request type."""
|
||||
return isinstance(
|
||||
request.root,
|
||||
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
|
||||
)
|
||||
|
||||
async def handle_request(
|
||||
self,
|
||||
ctx: RequestContext["ClientSession", Any],
|
||||
responder: RequestResponder[types.ServerRequest, types.ClientResult],
|
||||
) -> None:
|
||||
"""Handle a task-related request from the server.
|
||||
|
||||
Call handles_request() first to check if this handler can handle the request.
|
||||
"""
|
||||
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
|
||||
types.ClientResult | types.ErrorData
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.GetTaskRequest(params=params):
|
||||
response = await self.get_task(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.GetTaskPayloadRequest(params=params):
|
||||
response = await self.get_task_result(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListTasksRequest(params=params):
|
||||
response = await self.list_tasks(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.CancelTaskRequest(params=params):
|
||||
response = await self.cancel_task(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case _: # pragma: no cover
|
||||
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
|
||||
|
||||
|
||||
# Backwards compatibility aliases
|
||||
default_task_augmented_sampling_callback = default_task_augmented_sampling
|
||||
default_task_augmented_elicitation_callback = default_task_augmented_elicitation
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Experimental client-side task support.
|
||||
|
||||
This module provides client methods for interacting with MCP tasks.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
# Call a tool as a task
|
||||
result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"})
|
||||
task_id = result.task.taskId
|
||||
|
||||
# Get task status
|
||||
status = await session.experimental.get_task(task_id)
|
||||
|
||||
# Get task result when complete
|
||||
if status.status == "completed":
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
|
||||
# List all tasks
|
||||
tasks = await session.experimental.list_tasks()
|
||||
|
||||
# Cancel a task
|
||||
await session.experimental.cancel_task(task_id)
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.experimental.tasks.polling import poll_until_terminal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
ResultT = TypeVar("ResultT", bound=types.Result)
|
||||
|
||||
|
||||
class ExperimentalClientFeatures:
|
||||
"""
|
||||
Experimental client features for tasks and other experimental APIs.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Access via session.experimental:
|
||||
status = await session.experimental.get_task(task_id)
|
||||
"""
|
||||
|
||||
def __init__(self, session: "ClientSession") -> None:
|
||||
self._session = session
|
||||
|
||||
async def call_tool_as_task(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
*,
|
||||
ttl: int = 60000,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> types.CreateTaskResult:
|
||||
"""Call a tool as a task, returning a CreateTaskResult for polling.
|
||||
|
||||
This is a convenience method for calling tools that support task execution.
|
||||
The server will return a task reference instead of the immediate result,
|
||||
which can then be polled via `get_task()` and retrieved via `get_task_result()`.
|
||||
|
||||
Args:
|
||||
name: The tool name
|
||||
arguments: Tool arguments
|
||||
ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute)
|
||||
meta: Optional metadata to include in the request
|
||||
|
||||
Returns:
|
||||
CreateTaskResult containing the task reference
|
||||
|
||||
Example:
|
||||
# Create task
|
||||
result = await session.experimental.call_tool_as_task(
|
||||
"long_running_tool", {"input": "data"}
|
||||
)
|
||||
task_id = result.task.taskId
|
||||
|
||||
# Poll for completion
|
||||
while True:
|
||||
status = await session.experimental.get_task(task_id)
|
||||
if status.status == "completed":
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get result
|
||||
final = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
"""
|
||||
_meta: types.RequestParams.Meta | None = None
|
||||
if meta is not None:
|
||||
_meta = types.RequestParams.Meta(**meta)
|
||||
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
params=types.CallToolRequestParams(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
task=types.TaskMetadata(ttl=ttl),
|
||||
_meta=_meta,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CreateTaskResult,
|
||||
)
|
||||
|
||||
async def get_task(self, task_id: str) -> types.GetTaskResult:
|
||||
"""
|
||||
Get the current status of a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
GetTaskResult containing the task status and metadata
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetTaskRequest(
|
||||
params=types.GetTaskRequestParams(taskId=task_id),
|
||||
)
|
||||
),
|
||||
types.GetTaskResult,
|
||||
)
|
||||
|
||||
async def get_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
result_type: type[ResultT],
|
||||
) -> ResultT:
|
||||
"""
|
||||
Get the result of a completed task.
|
||||
|
||||
The result type depends on the original request type:
|
||||
- tools/call tasks return CallToolResult
|
||||
- Other request types return their corresponding result type
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
result_type: The expected result type (e.g., CallToolResult)
|
||||
|
||||
Returns:
|
||||
The task result, validated against result_type
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetTaskPayloadRequest(
|
||||
params=types.GetTaskPayloadRequestParams(taskId=task_id),
|
||||
)
|
||||
),
|
||||
result_type,
|
||||
)
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
) -> types.ListTasksResult:
|
||||
"""
|
||||
List all tasks.
|
||||
|
||||
Args:
|
||||
cursor: Optional pagination cursor
|
||||
|
||||
Returns:
|
||||
ListTasksResult containing tasks and optional next cursor
|
||||
"""
|
||||
params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListTasksRequest(params=params),
|
||||
),
|
||||
types.ListTasksResult,
|
||||
)
|
||||
|
||||
async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
|
||||
"""
|
||||
Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
CancelTaskResult with the updated task state
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.CancelTaskRequest(
|
||||
params=types.CancelTaskRequestParams(taskId=task_id),
|
||||
)
|
||||
),
|
||||
types.CancelTaskResult,
|
||||
)
|
||||
|
||||
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
|
||||
"""
|
||||
Poll a task until it reaches a terminal status.
|
||||
|
||||
Yields GetTaskResult for each poll, allowing the caller to react to
|
||||
status changes (e.g., handle input_required). Exits when task reaches
|
||||
a terminal status (completed, failed, cancelled).
|
||||
|
||||
Respects the pollInterval hint from the server.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Yields:
|
||||
GetTaskResult for each poll
|
||||
|
||||
Example:
|
||||
async for status in session.experimental.poll_task(task_id):
|
||||
print(f"Status: {status.status}")
|
||||
if status.status == "input_required":
|
||||
# Handle elicitation request via tasks/result
|
||||
pass
|
||||
|
||||
# Task is now terminal, get the result
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
"""
|
||||
async for status in poll_until_terminal(self.get_task, task_id):
|
||||
yield status
|
||||
Reference in New Issue
Block a user