Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .server import NotificationOptions, Server
|
||||
|
||||
__all__ = ["Server", "NotificationOptions"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,288 @@
|
||||
"""Experimental handlers for the low-level MCP server.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from mcp.server.experimental.task_support import TaskSupport
|
||||
from mcp.server.lowlevel.func_inspection import create_call_wrapper
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.helpers import cancel_task
|
||||
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
|
||||
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import (
|
||||
INVALID_PARAMS,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResult,
|
||||
ErrorData,
|
||||
GetTaskPayloadRequest,
|
||||
GetTaskPayloadResult,
|
||||
GetTaskRequest,
|
||||
GetTaskResult,
|
||||
ListTasksRequest,
|
||||
ListTasksResult,
|
||||
ServerCapabilities,
|
||||
ServerResult,
|
||||
ServerTasksCapability,
|
||||
ServerTasksRequestsCapability,
|
||||
TasksCancelCapability,
|
||||
TasksListCapability,
|
||||
TasksToolsCapability,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.lowlevel.server import Server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExperimentalHandlers:
|
||||
"""Experimental request/notification handlers.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: Server,
|
||||
request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]],
|
||||
notification_handlers: dict[type, Callable[..., Awaitable[None]]],
|
||||
):
|
||||
self._server = server
|
||||
self._request_handlers = request_handlers
|
||||
self._notification_handlers = notification_handlers
|
||||
self._task_support: TaskSupport | None = None
|
||||
|
||||
@property
|
||||
def task_support(self) -> TaskSupport | None:
|
||||
"""Get the task support configuration, if enabled."""
|
||||
return self._task_support
|
||||
|
||||
def update_capabilities(self, capabilities: ServerCapabilities) -> None:
|
||||
# Only add tasks capability if handlers are registered
|
||||
if not any(
|
||||
req_type in self._request_handlers
|
||||
for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest]
|
||||
):
|
||||
return
|
||||
|
||||
capabilities.tasks = ServerTasksCapability()
|
||||
if ListTasksRequest in self._request_handlers:
|
||||
capabilities.tasks.list = TasksListCapability()
|
||||
if CancelTaskRequest in self._request_handlers:
|
||||
capabilities.tasks.cancel = TasksCancelCapability()
|
||||
|
||||
capabilities.tasks.requests = ServerTasksRequestsCapability(
|
||||
tools=TasksToolsCapability()
|
||||
) # assuming always supported for now
|
||||
|
||||
def enable_tasks(
|
||||
self,
|
||||
store: TaskStore | None = None,
|
||||
queue: TaskMessageQueue | None = None,
|
||||
) -> TaskSupport:
|
||||
"""
|
||||
Enable experimental task support.
|
||||
|
||||
This sets up the task infrastructure and auto-registers default handlers
|
||||
for tasks/get, tasks/result, tasks/list, and tasks/cancel.
|
||||
|
||||
Args:
|
||||
store: Custom TaskStore implementation (defaults to InMemoryTaskStore)
|
||||
queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue)
|
||||
|
||||
Returns:
|
||||
The TaskSupport configuration object
|
||||
|
||||
Example:
|
||||
# Simple in-memory setup
|
||||
server.experimental.enable_tasks()
|
||||
|
||||
# Custom store/queue for distributed systems
|
||||
server.experimental.enable_tasks(
|
||||
store=RedisTaskStore(redis_url),
|
||||
queue=RedisTaskMessageQueue(redis_url),
|
||||
)
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
if store is None:
|
||||
store = InMemoryTaskStore()
|
||||
if queue is None:
|
||||
queue = InMemoryTaskMessageQueue()
|
||||
|
||||
self._task_support = TaskSupport(store=store, queue=queue)
|
||||
|
||||
# Auto-register default handlers
|
||||
self._register_default_task_handlers()
|
||||
|
||||
return self._task_support
|
||||
|
||||
def _register_default_task_handlers(self) -> None:
|
||||
"""Register default handlers for task operations."""
|
||||
assert self._task_support is not None
|
||||
support = self._task_support
|
||||
|
||||
# Register get_task handler if not already registered
|
||||
if GetTaskRequest not in self._request_handlers:
|
||||
|
||||
async def _default_get_task(req: GetTaskRequest) -> ServerResult:
|
||||
task = await support.store.get_task(req.params.taskId)
|
||||
if task is None:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message=f"Task not found: {req.params.taskId}",
|
||||
)
|
||||
)
|
||||
return ServerResult(
|
||||
GetTaskResult(
|
||||
taskId=task.taskId,
|
||||
status=task.status,
|
||||
statusMessage=task.statusMessage,
|
||||
createdAt=task.createdAt,
|
||||
lastUpdatedAt=task.lastUpdatedAt,
|
||||
ttl=task.ttl,
|
||||
pollInterval=task.pollInterval,
|
||||
)
|
||||
)
|
||||
|
||||
self._request_handlers[GetTaskRequest] = _default_get_task
|
||||
|
||||
# Register get_task_result handler if not already registered
|
||||
if GetTaskPayloadRequest not in self._request_handlers:
|
||||
|
||||
async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult:
|
||||
ctx = self._server.request_context
|
||||
result = await support.handler.handle(req, ctx.session, ctx.request_id)
|
||||
return ServerResult(result)
|
||||
|
||||
self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result
|
||||
|
||||
# Register list_tasks handler if not already registered
|
||||
if ListTasksRequest not in self._request_handlers:
|
||||
|
||||
async def _default_list_tasks(req: ListTasksRequest) -> ServerResult:
|
||||
cursor = req.params.cursor if req.params else None
|
||||
tasks, next_cursor = await support.store.list_tasks(cursor)
|
||||
return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor))
|
||||
|
||||
self._request_handlers[ListTasksRequest] = _default_list_tasks
|
||||
|
||||
# Register cancel_task handler if not already registered
|
||||
if CancelTaskRequest not in self._request_handlers:
|
||||
|
||||
async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult:
|
||||
result = await cancel_task(support.store, req.params.taskId)
|
||||
return ServerResult(result)
|
||||
|
||||
self._request_handlers[CancelTaskRequest] = _default_cancel_task
|
||||
|
||||
def list_tasks(
|
||||
self,
|
||||
) -> Callable[
|
||||
[Callable[[ListTasksRequest], Awaitable[ListTasksResult]]],
|
||||
Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
|
||||
]:
|
||||
"""Register a handler for listing tasks.
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
|
||||
) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]:
|
||||
logger.debug("Registering handler for ListTasksRequest")
|
||||
wrapper = create_call_wrapper(func, ListTasksRequest)
|
||||
|
||||
async def handler(req: ListTasksRequest) -> ServerResult:
|
||||
result = await wrapper(req)
|
||||
return ServerResult(result)
|
||||
|
||||
self._request_handlers[ListTasksRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_task(
|
||||
self,
|
||||
) -> Callable[
|
||||
[Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]]
|
||||
]:
|
||||
"""Register a handler for getting task status.
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]],
|
||||
) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]:
|
||||
logger.debug("Registering handler for GetTaskRequest")
|
||||
wrapper = create_call_wrapper(func, GetTaskRequest)
|
||||
|
||||
async def handler(req: GetTaskRequest) -> ServerResult:
|
||||
result = await wrapper(req)
|
||||
return ServerResult(result)
|
||||
|
||||
self._request_handlers[GetTaskRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_task_result(
|
||||
self,
|
||||
) -> Callable[
|
||||
[Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]],
|
||||
Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
|
||||
]:
|
||||
"""Register a handler for getting task results/payload.
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
|
||||
) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]:
|
||||
logger.debug("Registering handler for GetTaskPayloadRequest")
|
||||
wrapper = create_call_wrapper(func, GetTaskPayloadRequest)
|
||||
|
||||
async def handler(req: GetTaskPayloadRequest) -> ServerResult:
|
||||
result = await wrapper(req)
|
||||
return ServerResult(result)
|
||||
|
||||
self._request_handlers[GetTaskPayloadRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def cancel_task(
|
||||
self,
|
||||
) -> Callable[
|
||||
[Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]],
|
||||
Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
|
||||
]:
|
||||
"""Register a handler for cancelling tasks.
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
|
||||
) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]:
|
||||
logger.debug("Registering handler for CancelTaskRequest")
|
||||
wrapper = create_call_wrapper(func, CancelTaskRequest)
|
||||
|
||||
async def handler(req: CancelTaskRequest) -> ServerResult:
|
||||
result = await wrapper(req)
|
||||
return ServerResult(result)
|
||||
|
||||
self._request_handlers[CancelTaskRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,54 @@
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar, get_type_hints
|
||||
|
||||
T = TypeVar("T")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]:
|
||||
"""
|
||||
Create a wrapper function that knows how to call func with the request object.
|
||||
|
||||
Returns a wrapper function that takes the request and calls func appropriately.
|
||||
|
||||
The wrapper handles three calling patterns:
|
||||
1. Positional-only parameter typed as request_type (no default): func(req)
|
||||
2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req})
|
||||
3. No request parameter or parameter with default: func()
|
||||
"""
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
except (ValueError, TypeError, NameError): # pragma: no cover
|
||||
return lambda _: func()
|
||||
|
||||
# Check for positional-only parameter typed as request_type
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
|
||||
param_type = type_hints.get(param_name)
|
||||
if param_type == request_type: # pragma: no branch
|
||||
# Check if it has a default - if so, treat as old style
|
||||
if param.default is not inspect.Parameter.empty: # pragma: no cover
|
||||
return lambda _: func()
|
||||
# Found positional-only parameter with correct type and no default
|
||||
return lambda req: func(req)
|
||||
|
||||
# Check for any positional/keyword parameter typed as request_type
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch
|
||||
param_type = type_hints.get(param_name)
|
||||
if param_type == request_type:
|
||||
# Check if it has a default - if so, treat as old style
|
||||
if param.default is not inspect.Parameter.empty: # pragma: no cover
|
||||
return lambda _: func()
|
||||
|
||||
# Found keyword parameter with correct type and no default
|
||||
# Need to capture param_name in closure properly
|
||||
def make_keyword_wrapper(name: str) -> Callable[[Any], Any]:
|
||||
return lambda req: func(**{name: req})
|
||||
|
||||
return make_keyword_wrapper(param_name)
|
||||
|
||||
# No request parameter found - use old style
|
||||
return lambda _: func()
|
||||
@@ -0,0 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReadResourceContents:
|
||||
"""Contents returned from a read_resource call."""
|
||||
|
||||
content: str | bytes
|
||||
mime_type: str | None = None
|
||||
796
.venv/lib/python3.11/site-packages/mcp/server/lowlevel/server.py
Normal file
796
.venv/lib/python3.11/site-packages/mcp/server/lowlevel/server.py
Normal file
@@ -0,0 +1,796 @@
|
||||
"""
|
||||
MCP Server Module
|
||||
|
||||
This module provides a framework for creating an MCP (Model Context Protocol) server.
|
||||
It allows you to easily define and handle various types of requests and notifications
|
||||
in an asynchronous manner.
|
||||
|
||||
Usage:
|
||||
1. Create a Server instance:
|
||||
server = Server("your_server_name")
|
||||
|
||||
2. Define request handlers using decorators:
|
||||
@server.list_prompts()
|
||||
async def handle_list_prompts(request: types.ListPromptsRequest) -> types.ListPromptsResult:
|
||||
# Implementation
|
||||
|
||||
@server.get_prompt()
|
||||
async def handle_get_prompt(
|
||||
name: str, arguments: dict[str, str] | None
|
||||
) -> types.GetPromptResult:
|
||||
# Implementation
|
||||
|
||||
@server.list_tools()
|
||||
async def handle_list_tools(request: types.ListToolsRequest) -> types.ListToolsResult:
|
||||
# Implementation
|
||||
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict | None
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
# Implementation
|
||||
|
||||
@server.list_resource_templates()
|
||||
async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
|
||||
# Implementation
|
||||
|
||||
3. Define notification handlers if needed:
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int, progress: float, total: float | None,
|
||||
message: str | None
|
||||
) -> None:
|
||||
# Implementation
|
||||
|
||||
4. Run the server:
|
||||
async def main():
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="your_server_name",
|
||||
server_version="your_version",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
The Server class provides methods to register handlers for various MCP requests and
|
||||
notifications. It automatically manages the request context and handles incoming
|
||||
messages from the client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import base64
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
||||
from typing import Any, Generic, TypeAlias, cast
|
||||
|
||||
import anyio
|
||||
import jsonschema
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.experimental.request_context import Experimental
|
||||
from mcp.server.lowlevel.experimental import ExperimentalHandlers
|
||||
from mcp.server.lowlevel.func_inspection import create_call_wrapper
|
||||
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError, UrlElicitationRequiredError
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LifespanResultT = TypeVar("LifespanResultT", default=Any)
|
||||
RequestT = TypeVar("RequestT", default=Any)
|
||||
|
||||
# type aliases for tool call results
|
||||
StructuredContent: TypeAlias = dict[str, Any]
|
||||
UnstructuredContent: TypeAlias = Iterable[types.ContentBlock]
|
||||
CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent]
|
||||
|
||||
# This will be properly typed in each Server instance's context
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
|
||||
|
||||
|
||||
class NotificationOptions:
|
||||
def __init__(
|
||||
self,
|
||||
prompts_changed: bool = False,
|
||||
resources_changed: bool = False,
|
||||
tools_changed: bool = False,
|
||||
):
|
||||
self.prompts_changed = prompts_changed
|
||||
self.resources_changed = resources_changed
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Default lifespan context manager that does nothing.
|
||||
|
||||
Args:
|
||||
server: The server instance this lifespan is managing
|
||||
|
||||
Returns:
|
||||
An empty context object
|
||||
"""
|
||||
yield {}
|
||||
|
||||
|
||||
class Server(Generic[LifespanResultT, RequestT]):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
version: str | None = None,
|
||||
instructions: str | None = None,
|
||||
website_url: str | None = None,
|
||||
icons: list[types.Icon] | None = None,
|
||||
lifespan: Callable[
|
||||
[Server[LifespanResultT, RequestT]],
|
||||
AbstractAsyncContextManager[LifespanResultT],
|
||||
] = lifespan,
|
||||
):
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.instructions = instructions
|
||||
self.website_url = website_url
|
||||
self.icons = icons
|
||||
self.lifespan = lifespan
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
|
||||
types.PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
self._tool_cache: dict[str, types.Tool] = {}
|
||||
self._experimental_handlers: ExperimentalHandlers | None = None
|
||||
logger.debug("Initializing server %r", name)
|
||||
|
||||
def create_initialization_options(
|
||||
self,
|
||||
notification_options: NotificationOptions | None = None,
|
||||
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
|
||||
) -> InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
|
||||
def pkg_version(package: str) -> str:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
return version(package)
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
|
||||
return "unknown" # pragma: no cover
|
||||
|
||||
return InitializationOptions(
|
||||
server_name=self.name,
|
||||
server_version=self.version if self.version else pkg_version("mcp"),
|
||||
capabilities=self.get_capabilities(
|
||||
notification_options or NotificationOptions(),
|
||||
experimental_capabilities or {},
|
||||
),
|
||||
instructions=self.instructions,
|
||||
website_url=self.website_url,
|
||||
icons=self.icons,
|
||||
)
|
||||
|
||||
def get_capabilities(
|
||||
self,
|
||||
notification_options: NotificationOptions,
|
||||
experimental_capabilities: dict[str, dict[str, Any]],
|
||||
) -> types.ServerCapabilities:
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
prompts_capability = None
|
||||
resources_capability = None
|
||||
tools_capability = None
|
||||
logging_capability = None
|
||||
completions_capability = None
|
||||
|
||||
# Set prompt capabilities if handler exists
|
||||
if types.ListPromptsRequest in self.request_handlers:
|
||||
prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed)
|
||||
|
||||
# Set resource capabilities if handler exists
|
||||
if types.ListResourcesRequest in self.request_handlers:
|
||||
resources_capability = types.ResourcesCapability(
|
||||
subscribe=False, listChanged=notification_options.resources_changed
|
||||
)
|
||||
|
||||
# Set tool capabilities if handler exists
|
||||
if types.ListToolsRequest in self.request_handlers:
|
||||
tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed)
|
||||
|
||||
# Set logging capabilities if handler exists
|
||||
if types.SetLevelRequest in self.request_handlers: # pragma: no cover
|
||||
logging_capability = types.LoggingCapability()
|
||||
|
||||
# Set completions capabilities if handler exists
|
||||
if types.CompleteRequest in self.request_handlers:
|
||||
completions_capability = types.CompletionsCapability()
|
||||
|
||||
capabilities = types.ServerCapabilities(
|
||||
prompts=prompts_capability,
|
||||
resources=resources_capability,
|
||||
tools=tools_capability,
|
||||
logging=logging_capability,
|
||||
experimental=experimental_capabilities,
|
||||
completions=completions_capability,
|
||||
)
|
||||
if self._experimental_handlers:
|
||||
self._experimental_handlers.update_capabilities(capabilities)
|
||||
return capabilities
|
||||
|
||||
@property
|
||||
def request_context(
|
||||
self,
|
||||
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
@property
|
||||
def experimental(self) -> ExperimentalHandlers:
|
||||
"""Experimental APIs for tasks and other features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
# We create this inline so we only add these capabilities _if_ they're actually used
|
||||
if self._experimental_handlers is None:
|
||||
self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers)
|
||||
return self._experimental_handlers
|
||||
|
||||
def list_prompts(self):
|
||||
def decorator(
|
||||
func: Callable[[], Awaitable[list[types.Prompt]]]
|
||||
| Callable[[types.ListPromptsRequest], Awaitable[types.ListPromptsResult]],
|
||||
):
|
||||
logger.debug("Registering handler for PromptListRequest")
|
||||
|
||||
wrapper = create_call_wrapper(func, types.ListPromptsRequest)
|
||||
|
||||
async def handler(req: types.ListPromptsRequest):
|
||||
result = await wrapper(req)
|
||||
# Handle both old style (list[Prompt]) and new style (ListPromptsResult)
|
||||
if isinstance(result, types.ListPromptsResult):
|
||||
return types.ServerResult(result)
|
||||
else:
|
||||
# Old style returns list[Prompt]
|
||||
return types.ServerResult(types.ListPromptsResult(prompts=result))
|
||||
|
||||
self.request_handlers[types.ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
def decorator(
|
||||
func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]],
|
||||
):
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: types.GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
return types.ServerResult(prompt_get)
|
||||
|
||||
self.request_handlers[types.GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(
|
||||
func: Callable[[], Awaitable[list[types.Resource]]]
|
||||
| Callable[[types.ListResourcesRequest], Awaitable[types.ListResourcesResult]],
|
||||
):
|
||||
logger.debug("Registering handler for ListResourcesRequest")
|
||||
|
||||
wrapper = create_call_wrapper(func, types.ListResourcesRequest)
|
||||
|
||||
async def handler(req: types.ListResourcesRequest):
|
||||
result = await wrapper(req)
|
||||
# Handle both old style (list[Resource]) and new style (ListResourcesResult)
|
||||
if isinstance(result, types.ListResourcesResult):
|
||||
return types.ServerResult(result)
|
||||
else:
|
||||
# Old style returns list[Resource]
|
||||
return types.ServerResult(types.ListResourcesResult(resources=result))
|
||||
|
||||
self.request_handlers[types.ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resource_templates(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
|
||||
logger.debug("Registering handler for ListResourceTemplatesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
templates = await func()
|
||||
return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates))
|
||||
|
||||
self.request_handlers[types.ListResourceTemplatesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
def decorator(
|
||||
func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]],
|
||||
):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: types.ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
|
||||
def create_content(data: str | bytes, mime_type: str | None):
|
||||
match data:
|
||||
case str() as data:
|
||||
return types.TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=data,
|
||||
mimeType=mime_type or "text/plain",
|
||||
)
|
||||
case bytes() as data: # pragma: no cover
|
||||
return types.BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.b64encode(data).decode(),
|
||||
mimeType=mime_type or "application/octet-stream",
|
||||
)
|
||||
|
||||
match result:
|
||||
case str() | bytes() as data: # pragma: no cover
|
||||
warnings.warn(
|
||||
"Returning str or bytes from read_resource is deprecated. "
|
||||
"Use Iterable[ReadResourceContents] instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
content = create_content(data, None)
|
||||
case Iterable() as contents:
|
||||
contents_list = [
|
||||
create_content(content_item.content, content_item.mime_type) for content_item in contents
|
||||
]
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
contents=contents_list,
|
||||
)
|
||||
)
|
||||
case _: # pragma: no cover
|
||||
raise ValueError(f"Unexpected return type from read_resource: {type(result)}")
|
||||
|
||||
return types.ServerResult( # pragma: no cover
|
||||
types.ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self): # pragma: no cover
|
||||
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: types.SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self): # pragma: no cover
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: types.SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self): # pragma: no cover
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: types.UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_tools(self):
|
||||
def decorator(
|
||||
func: Callable[[], Awaitable[list[types.Tool]]]
|
||||
| Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]],
|
||||
):
|
||||
logger.debug("Registering handler for ListToolsRequest")
|
||||
|
||||
wrapper = create_call_wrapper(func, types.ListToolsRequest)
|
||||
|
||||
async def handler(req: types.ListToolsRequest):
|
||||
result = await wrapper(req)
|
||||
|
||||
# Handle both old style (list[Tool]) and new style (ListToolsResult)
|
||||
if isinstance(result, types.ListToolsResult): # pragma: no cover
|
||||
# Refresh the tool cache with returned tools
|
||||
for tool in result.tools:
|
||||
validate_and_warn_tool_name(tool.name)
|
||||
self._tool_cache[tool.name] = tool
|
||||
return types.ServerResult(result)
|
||||
else:
|
||||
# Old style returns list[Tool]
|
||||
# Clear and refresh the entire tool cache
|
||||
self._tool_cache.clear()
|
||||
for tool in result:
|
||||
validate_and_warn_tool_name(tool.name)
|
||||
self._tool_cache[tool.name] = tool
|
||||
return types.ServerResult(types.ListToolsResult(tools=result))
|
||||
|
||||
self.request_handlers[types.ListToolsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def _make_error_result(self, error_message: str) -> types.ServerResult:
|
||||
"""Create a ServerResult with an error CallToolResult."""
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text=error_message)],
|
||||
isError=True,
|
||||
)
|
||||
)
|
||||
|
||||
async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None:
|
||||
"""Get tool definition from cache, refreshing if necessary.
|
||||
|
||||
Returns the Tool object if found, None otherwise.
|
||||
"""
|
||||
if tool_name not in self._tool_cache:
|
||||
if types.ListToolsRequest in self.request_handlers:
|
||||
logger.debug("Tool cache miss for %s, refreshing cache", tool_name)
|
||||
await self.request_handlers[types.ListToolsRequest](None)
|
||||
|
||||
tool = self._tool_cache.get(tool_name)
|
||||
if tool is None:
|
||||
logger.warning("Tool '%s' not listed, no validation will be performed", tool_name)
|
||||
|
||||
return tool
|
||||
|
||||
def call_tool(self, *, validate_input: bool = True):
|
||||
"""Register a tool call handler.
|
||||
|
||||
Args:
|
||||
validate_input: If True, validates input against inputSchema. Default is True.
|
||||
|
||||
The handler validates input against inputSchema (if validate_input=True), calls the tool function,
|
||||
and builds a CallToolResult with the results:
|
||||
- Unstructured content (iterable of ContentBlock): returned in content
|
||||
- Structured content (dict): returned in structuredContent, serialized JSON text returned in content
|
||||
- Both: returned in content and structuredContent
|
||||
|
||||
If outputSchema is defined, validates structuredContent or errors if missing.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
UnstructuredContent
|
||||
| StructuredContent
|
||||
| CombinationContent
|
||||
| types.CallToolResult
|
||||
| types.CreateTaskResult
|
||||
],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: types.CallToolRequest):
|
||||
try:
|
||||
tool_name = req.params.name
|
||||
arguments = req.params.arguments or {}
|
||||
tool = await self._get_cached_tool_definition(tool_name)
|
||||
|
||||
# input validation
|
||||
if validate_input and tool:
|
||||
try:
|
||||
jsonschema.validate(instance=arguments, schema=tool.inputSchema)
|
||||
except jsonschema.ValidationError as e:
|
||||
return self._make_error_result(f"Input validation error: {e.message}")
|
||||
|
||||
# tool call
|
||||
results = await func(tool_name, arguments)
|
||||
|
||||
# output normalization
|
||||
unstructured_content: UnstructuredContent
|
||||
maybe_structured_content: StructuredContent | None
|
||||
if isinstance(results, types.CallToolResult):
|
||||
return types.ServerResult(results)
|
||||
elif isinstance(results, types.CreateTaskResult):
|
||||
# Task-augmented execution returns task info instead of result
|
||||
return types.ServerResult(results)
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
# tool returned both structured and unstructured content
|
||||
unstructured_content, maybe_structured_content = cast(CombinationContent, results)
|
||||
elif isinstance(results, dict):
|
||||
# tool returned structured content only
|
||||
maybe_structured_content = cast(StructuredContent, results)
|
||||
unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))]
|
||||
elif hasattr(results, "__iter__"): # pragma: no cover
|
||||
# tool returned unstructured content only
|
||||
unstructured_content = cast(UnstructuredContent, results)
|
||||
maybe_structured_content = None
|
||||
else: # pragma: no cover
|
||||
return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}")
|
||||
|
||||
# output validation
|
||||
if tool and tool.outputSchema is not None:
|
||||
if maybe_structured_content is None:
|
||||
return self._make_error_result(
|
||||
"Output validation error: outputSchema defined but no structured output returned"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema)
|
||||
except jsonschema.ValidationError as e:
|
||||
return self._make_error_result(f"Output validation error: {e.message}")
|
||||
|
||||
# result
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=list(unstructured_content),
|
||||
structuredContent=maybe_structured_content,
|
||||
isError=False,
|
||||
)
|
||||
)
|
||||
except UrlElicitationRequiredError:
|
||||
# Re-raise UrlElicitationRequiredError so it can be properly handled
|
||||
# by _handle_request, which converts it to an error response with code -32042
|
||||
raise
|
||||
except Exception as e:
|
||||
return self._make_error_result(str(e))
|
||||
|
||||
self.request_handlers[types.CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: types.ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken,
|
||||
req.params.progress,
|
||||
req.params.total,
|
||||
req.params.message,
|
||||
)
|
||||
|
||||
self.notification_handlers[types.ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[
|
||||
types.PromptReference | types.ResourceTemplateReference,
|
||||
types.CompletionArgument,
|
||||
types.CompletionContext | None,
|
||||
],
|
||||
Awaitable[types.Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: types.CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument, req.params.context)
|
||||
return types.ServerResult(
|
||||
types.CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else types.Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
initialization_options: InitializationOptions,
|
||||
# When False, exceptions are returned as messages to the client.
|
||||
# When True, exceptions are raised, which will cause the server to shut down
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
# When True, the server is stateless and
|
||||
# clients can perform initialization with any node. The client must still follow
|
||||
# the initialization lifecycle, but can do so with any available node
|
||||
# rather than requiring initialization for each connection.
|
||||
stateless: bool = False,
|
||||
):
|
||||
async with AsyncExitStack() as stack:
|
||||
lifespan_context = await stack.enter_async_context(self.lifespan(self))
|
||||
session = await stack.enter_async_context(
|
||||
ServerSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
initialization_options,
|
||||
stateless=stateless,
|
||||
)
|
||||
)
|
||||
|
||||
# Configure task support for this session if enabled
|
||||
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
|
||||
if task_support is not None:
|
||||
task_support.configure_session(session)
|
||||
await stack.enter_async_context(task_support.run())
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug("Received message: %s", message)
|
||||
|
||||
tg.start_soon(
|
||||
self._handle_message,
|
||||
message,
|
||||
session,
|
||||
lifespan_context,
|
||||
raise_exceptions,
|
||||
)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
match message:
|
||||
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
|
||||
with responder:
|
||||
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
case Exception(): # pragma: no cover
|
||||
logger.error(f"Received exception from stream: {message}")
|
||||
await session.send_log_message(
|
||||
level="error",
|
||||
data="Internal Server Error",
|
||||
logger="mcp.server.exception_handler",
|
||||
)
|
||||
if raise_exceptions:
|
||||
raise message
|
||||
|
||||
for warning in w: # pragma: no cover
|
||||
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult],
|
||||
req: types.ClientRequestType,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool,
|
||||
):
|
||||
logger.info("Processing request of type %s", type(req).__name__)
|
||||
|
||||
if handler := self.request_handlers.get(type(req)):
|
||||
logger.debug("Dispatching request of type %s", type(req).__name__)
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Extract request context and close_sse_stream from message metadata
|
||||
request_data = None
|
||||
close_sse_stream_cb = None
|
||||
close_standalone_sse_stream_cb = None
|
||||
if message.message_metadata is not None and isinstance(
|
||||
message.message_metadata, ServerMessageMetadata
|
||||
): # pragma: no cover
|
||||
request_data = message.message_metadata.request_context
|
||||
close_sse_stream_cb = message.message_metadata.close_sse_stream
|
||||
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
|
||||
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
client_capabilities = session.client_params.capabilities if session.client_params else None
|
||||
task_support = self._experimental_handlers.task_support if self._experimental_handlers else None
|
||||
# Get task metadata from request params if present
|
||||
task_metadata = None
|
||||
if hasattr(req, "params") and req.params is not None:
|
||||
task_metadata = getattr(req.params, "task", None)
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
lifespan_context,
|
||||
Experimental(
|
||||
task_metadata=task_metadata,
|
||||
_client_capabilities=client_capabilities,
|
||||
_session=session,
|
||||
_task_support=task_support,
|
||||
),
|
||||
request=request_data,
|
||||
close_sse_stream=close_sse_stream_cb,
|
||||
close_standalone_sse_stream=close_standalone_sse_stream_cb,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
except McpError as err: # pragma: no cover
|
||||
response = err.error
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
logger.info(
|
||||
"Request %s cancelled - duplicate response suppressed",
|
||||
message.request_id,
|
||||
)
|
||||
return
|
||||
except Exception as err: # pragma: no cover
|
||||
if raise_exceptions:
|
||||
raise err
|
||||
response = types.ErrorData(code=0, message=str(err), data=None)
|
||||
finally:
|
||||
# Reset the global state after we are done
|
||||
if token is not None: # pragma: no branch
|
||||
request_ctx.reset(token)
|
||||
|
||||
await message.respond(response)
|
||||
else: # pragma: no cover
|
||||
await message.respond(
|
||||
types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
|
||||
async def _handle_notification(self, notify: Any):
|
||||
if handler := self.notification_handlers.get(type(notify)): # type: ignore
|
||||
logger.debug("Dispatching notification of type %s", type(notify).__name__)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Uncaught exception in notification handler")
|
||||
|
||||
|
||||
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
Reference in New Issue
Block a user