Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from .fastmcp import FastMCP
|
||||
from .lowlevel import NotificationOptions, Server
|
||||
from .models import InitializationOptions
|
||||
|
||||
__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"]
|
||||
50
.venv/lib/python3.11/site-packages/mcp/server/__main__.py
Normal file
50
.venv/lib/python3.11/site-packages/mcp/server/__main__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import ServerCapabilities
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("server")
|
||||
|
||||
|
||||
async def receive_loop(session: ServerSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from client: %s", message)
|
||||
|
||||
|
||||
async def main():
|
||||
version = importlib.metadata.version("mcp")
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
async with (
|
||||
ServerSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="mcp",
|
||||
server_version=version,
|
||||
capabilities=ServerCapabilities(),
|
||||
),
|
||||
) as session,
|
||||
write_stream,
|
||||
):
|
||||
await receive_loop(session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
anyio.run(main, backend="trio")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
MCP OAuth server authorization components.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def stringify_pydantic_error(validation_error: ValidationError) -> str:
|
||||
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Request handlers for MCP authorization endpoints.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,224 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.datastructures import FormData, QueryParams
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import (
|
||||
AuthorizationErrorCode,
|
||||
AuthorizationParams,
|
||||
AuthorizeError,
|
||||
OAuthAuthorizationServerProvider,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
client_id: str = Field(..., description="The client ID")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
|
||||
|
||||
# see OAuthClientMetadata; we only support `code`
|
||||
response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge")
|
||||
code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
|
||||
state: str | None = Field(None, description="Optional state parameter")
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope; if specified, should be a space-separated list of scope strings",
|
||||
)
|
||||
resource: str | None = Field(
|
||||
None,
|
||||
description="RFC 8707 resource indicator - the MCP server this token will be used with",
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationErrorResponse(BaseModel):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None
|
||||
error_uri: AnyUrl | None = None
|
||||
# must be set if provided in the request
|
||||
state: str | None = None
|
||||
|
||||
|
||||
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
|
||||
if params is None: # pragma: no cover
|
||||
return None
|
||||
value = params.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class AnyUrlModel(RootModel[AnyUrl]):
|
||||
root: AnyUrl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# implements authorization requests for grant_type=code;
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
|
||||
state = None
|
||||
redirect_uri = None
|
||||
client = None
|
||||
params = None
|
||||
|
||||
async def error_response(
|
||||
error: AuthorizationErrorCode,
|
||||
error_description: str | None,
|
||||
attempt_load_client: bool = True,
|
||||
):
|
||||
# Error responses take two different formats:
|
||||
# 1. The request has a valid client ID & redirect_uri: we issue a redirect
|
||||
# back to the redirect_uri with the error response fields as query
|
||||
# parameters. This allows the client to be notified of the error.
|
||||
# 2. Otherwise, we return an error response directly to the end user;
|
||||
# we choose to do so in JSON, but this is left undefined in the
|
||||
# specification.
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
|
||||
#
|
||||
# This logic is a bit awkward to handle, because the error might be thrown
|
||||
# very early in request validation, before we've done the usual Pydantic
|
||||
# validation, loaded the client, etc. To handle this, error_response()
|
||||
# contains fallback logic which attempts to load the parameters directly
|
||||
# from the request.
|
||||
|
||||
nonlocal client, redirect_uri, state
|
||||
if client is None and attempt_load_client:
|
||||
# make last-ditch attempt to load the client
|
||||
client_id = best_effort_extract_string("client_id", params)
|
||||
client = await self.provider.get_client(client_id) if client_id else None
|
||||
if redirect_uri is None and client:
|
||||
# make last-ditch effort to load the redirect uri
|
||||
try:
|
||||
if params is not None and "redirect_uri" not in params:
|
||||
raw_redirect_uri = None
|
||||
else:
|
||||
raw_redirect_uri = AnyUrlModel.model_validate(
|
||||
best_effort_extract_string("redirect_uri", params)
|
||||
).root
|
||||
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
|
||||
except (ValidationError, InvalidRedirectUriError):
|
||||
# if the redirect URI is invalid, ignore it & just return the
|
||||
# initial error
|
||||
pass
|
||||
|
||||
# the error response MUST contain the state specified by the client, if any
|
||||
if state is None: # pragma: no cover
|
||||
# make last-ditch effort to load state
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
error_resp = AuthorizationErrorResponse(
|
||||
error=error,
|
||||
error_description=error_description,
|
||||
state=state,
|
||||
)
|
||||
|
||||
if redirect_uri and client:
|
||||
return RedirectResponse(
|
||||
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
else:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=error_resp,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse request parameters
|
||||
if request.method == "GET":
|
||||
# Convert query_params to dict for pydantic validation
|
||||
params = request.query_params
|
||||
else:
|
||||
# Parse form data for POST requests
|
||||
params = await request.form()
|
||||
|
||||
# Save state if it exists, even before validation
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
try:
|
||||
auth_request = AuthorizationRequest.model_validate(params)
|
||||
state = auth_request.state # Update with validated state
|
||||
except ValidationError as validation_error:
|
||||
error: AuthorizationErrorCode = "invalid_request"
|
||||
for e in validation_error.errors():
|
||||
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
|
||||
error = "unsupported_response_type"
|
||||
break
|
||||
return await error_response(error, stringify_pydantic_error(validation_error))
|
||||
|
||||
# Get client information
|
||||
client = await self.provider.get_client(
|
||||
auth_request.client_id,
|
||||
)
|
||||
if not client:
|
||||
# For client_id validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=f"Client ID '{auth_request.client_id}' not found",
|
||||
attempt_load_client=False,
|
||||
)
|
||||
|
||||
# Validate redirect_uri against client's registered URIs
|
||||
try:
|
||||
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
|
||||
except InvalidRedirectUriError as validation_error:
|
||||
# For redirect_uri validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Validate scope - for scope errors, we can redirect
|
||||
try:
|
||||
scopes = client.validate_scope(auth_request.scope)
|
||||
except InvalidScopeError as validation_error:
|
||||
# For scope errors, redirect with error parameters
|
||||
return await error_response(
|
||||
error="invalid_scope",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Setup authorization parameters
|
||||
auth_params = AuthorizationParams(
|
||||
state=state,
|
||||
scopes=scopes,
|
||||
code_challenge=auth_request.code_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
|
||||
resource=auth_request.resource, # RFC 8707
|
||||
)
|
||||
|
||||
try:
|
||||
# Let the provider pick the next URI to redirect to
|
||||
return RedirectResponse(
|
||||
url=await self.provider.authorize(
|
||||
client,
|
||||
auth_params,
|
||||
),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
except AuthorizeError as e:
|
||||
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
|
||||
return await error_response(error=e.error, error_description=e.error_description)
|
||||
|
||||
except Exception as validation_error: # pragma: no cover
|
||||
# Catch-all for unexpected errors
|
||||
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
|
||||
return await error_response(error="server_error", error_description="An unexpected error occurred")
|
||||
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataHandler:
|
||||
metadata: OAuthMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtectedResourceMetadataHandler:
|
||||
metadata: ProtectedResourceMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
@@ -0,0 +1,136 @@
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationRequest(RootModel[OAuthClientMetadata]):
|
||||
# this wrapper is a no-op; it's just to separate out the types exposed to the
|
||||
# provider from what we use in the HTTP handler
|
||||
root: OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
options: ClientRegistrationOptions
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
|
||||
try:
|
||||
# Parse request body as JSON
|
||||
body = await request.json()
|
||||
client_metadata = OAuthClientMetadata.model_validate(body)
|
||||
|
||||
# Scope validation is handled below
|
||||
except ValidationError as validation_error:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id = str(uuid4())
|
||||
|
||||
# If auth method is None, default to client_secret_post
|
||||
if client_metadata.token_endpoint_auth_method is None:
|
||||
client_metadata.token_endpoint_auth_method = "client_secret_post"
|
||||
|
||||
client_secret = None
|
||||
if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch
|
||||
# cryptographically secure random 32-byte hex string
|
||||
client_secret = secrets.token_hex(32)
|
||||
|
||||
if client_metadata.scope is None and self.options.default_scopes is not None:
|
||||
client_metadata.scope = " ".join(self.options.default_scopes)
|
||||
elif client_metadata.scope is not None and self.options.valid_scopes is not None:
|
||||
requested_scopes = set(client_metadata.scope.split())
|
||||
valid_scopes = set(self.options.valid_scopes)
|
||||
if not requested_scopes.issubset(valid_scopes): # pragma: no branch
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="Requested scopes are not valid: "
|
||||
f"{', '.join(requested_scopes - valid_scopes)}",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
if not {"authorization_code", "refresh_token"}.issubset(set(client_metadata.grant_types)):
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="grant_types must be authorization_code and refresh_token",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# The MCP spec requires servers to use the authorization `code` flow
|
||||
# with PKCE
|
||||
if "code" not in client_metadata.response_types:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="response_types must include 'code' for authorization_code grant",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id_issued_at = int(time.time())
|
||||
client_secret_expires_at = (
|
||||
client_id_issued_at + self.options.client_secret_expiry_seconds
|
||||
if self.options.client_secret_expiry_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
client_info = OAuthClientInformationFull(
|
||||
client_id=client_id,
|
||||
client_id_issued_at=client_id_issued_at,
|
||||
client_secret=client_secret,
|
||||
client_secret_expires_at=client_secret_expires_at,
|
||||
# passthrough information from the client request
|
||||
redirect_uris=client_metadata.redirect_uris,
|
||||
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
|
||||
grant_types=client_metadata.grant_types,
|
||||
response_types=client_metadata.response_types,
|
||||
client_name=client_metadata.client_name,
|
||||
client_uri=client_metadata.client_uri,
|
||||
logo_uri=client_metadata.logo_uri,
|
||||
scope=client_metadata.scope,
|
||||
contacts=client_metadata.contacts,
|
||||
tos_uri=client_metadata.tos_uri,
|
||||
policy_uri=client_metadata.policy_uri,
|
||||
jwks_uri=client_metadata.jwks_uri,
|
||||
jwks=client_metadata.jwks,
|
||||
software_id=client_metadata.software_id,
|
||||
software_version=client_metadata.software_version,
|
||||
)
|
||||
try:
|
||||
# Register client
|
||||
await self.provider.register_client(client_info)
|
||||
|
||||
# Return client information
|
||||
return PydanticJSONResponse(content=client_info, status_code=201)
|
||||
except RegistrationError as e:
|
||||
# Handle registration errors as defined in RFC 7591 Section 3.2.2
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
|
||||
status_code=400,
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
|
||||
|
||||
|
||||
class RevocationRequest(BaseModel):
|
||||
"""
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
|
||||
"""
|
||||
|
||||
token: str
|
||||
token_type_hint: Literal["access_token", "refresh_token"] | None = None
|
||||
client_id: str
|
||||
client_secret: str | None
|
||||
|
||||
|
||||
class RevocationErrorResponse(BaseModel):
|
||||
error: Literal["invalid_request", "unauthorized_client"]
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevocationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
"""
|
||||
Handler for the OAuth 2.0 Token Revocation endpoint.
|
||||
"""
|
||||
try:
|
||||
client = await self.client_authenticator.authenticate_request(request)
|
||||
except AuthenticationError as e: # pragma: no cover
|
||||
return PydanticJSONResponse(
|
||||
status_code=401,
|
||||
content=RevocationErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
revocation_request = RevocationRequest.model_validate(dict(form_data))
|
||||
except ValidationError as e:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=RevocationErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(e),
|
||||
),
|
||||
)
|
||||
|
||||
loaders = [
|
||||
self.provider.load_access_token,
|
||||
partial(self.provider.load_refresh_token, client),
|
||||
]
|
||||
if revocation_request.token_type_hint == "refresh_token": # pragma: no cover
|
||||
loaders = reversed(loaders)
|
||||
|
||||
token: None | AccessToken | RefreshToken = None
|
||||
for loader in loaders:
|
||||
token = await loader(revocation_request.token)
|
||||
if token is not None:
|
||||
break
|
||||
|
||||
# if token is not found, just return HTTP 200 per the RFC
|
||||
if token and token.client_id == client.client_id:
|
||||
# Revoke token; provider is not meant to be able to do validation
|
||||
# at this point that would result in an error
|
||||
await self.provider.revoke_token(token)
|
||||
|
||||
# Return successful empty response
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,241 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
|
||||
class AuthorizationCodeRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(..., description="The authorization code")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
|
||||
code_verifier: str = Field(..., description="PKCE code verifier")
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str = Field(..., description="The refresh token")
|
||||
scope: str | None = Field(None, description="Optional scope parameter")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class TokenRequest(
|
||||
RootModel[
|
||||
Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
]
|
||||
):
|
||||
root: Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
|
||||
|
||||
class TokenErrorResponse(BaseModel):
|
||||
"""
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
"""
|
||||
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
error_uri: AnyHttpUrl | None = None
|
||||
|
||||
|
||||
class TokenSuccessResponse(RootModel[OAuthToken]):
|
||||
# this is just a wrapper over OAuthToken; the only reason we do this
|
||||
# is to have some separation between the HTTP response type, and the
|
||||
# type returned by the provider
|
||||
root: OAuthToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
|
||||
status_code = 200
|
||||
if isinstance(obj, TokenErrorResponse):
|
||||
status_code = 400
|
||||
|
||||
return PydanticJSONResponse(
|
||||
content=obj,
|
||||
status_code=status_code,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
async def handle(self, request: Request):
|
||||
try:
|
||||
client_info = await self.client_authenticator.authenticate_request(request)
|
||||
except AuthenticationError as e:
|
||||
# Authentication failures should return 401
|
||||
return PydanticJSONResponse(
|
||||
content=TokenErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
status_code=401,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
token_request = TokenRequest.model_validate(dict(form_data)).root
|
||||
except ValidationError as validation_error: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
)
|
||||
)
|
||||
|
||||
if token_request.grant_type not in client_info.grant_types: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="unsupported_grant_type",
|
||||
error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
|
||||
)
|
||||
)
|
||||
|
||||
tokens: OAuthToken
|
||||
|
||||
match token_request:
|
||||
case AuthorizationCodeRequest():
|
||||
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
|
||||
if auth_code is None or auth_code.client_id != token_request.client_id:
|
||||
# if code belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
# make auth codes expire after a deadline
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
|
||||
if auth_code.expires_at < time.time():
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# verify redirect_uri doesn't change between /authorize and /tokens
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
|
||||
if auth_code.redirect_uri_provided_explicitly:
|
||||
authorize_request_redirect_uri = auth_code.redirect_uri
|
||||
else: # pragma: no cover
|
||||
authorize_request_redirect_uri = None
|
||||
|
||||
# Convert both sides to strings for comparison to handle AnyUrl vs string issues
|
||||
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
|
||||
auth_redirect_str = (
|
||||
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
|
||||
)
|
||||
|
||||
if token_redirect_str != auth_redirect_str:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=("redirect_uri did not match the one used when creating auth code"),
|
||||
)
|
||||
)
|
||||
|
||||
# Verify PKCE code verifier
|
||||
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
|
||||
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
|
||||
if hashed_code_verifier != auth_code.code_challenge:
|
||||
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="incorrect code_verifier",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange authorization code for tokens
|
||||
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
case RefreshTokenRequest(): # pragma: no cover
|
||||
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
|
||||
if refresh_token is None or refresh_token.client_id != token_request.client_id:
|
||||
# if token belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
if refresh_token.expires_at and refresh_token.expires_at < time.time():
|
||||
# if the refresh token has expired, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# Parse scopes if provided
|
||||
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
|
||||
|
||||
for scope in scopes:
|
||||
if scope not in refresh_token.scopes:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_scope",
|
||||
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange refresh token for new tokens
|
||||
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
return self.response(TokenSuccessResponse(root=tokens))
|
||||
@@ -0,0 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class PydanticJSONResponse(JSONResponse):
|
||||
# use pydantic json serialization instead of the stock `json.dumps`,
|
||||
# so that we can handle serializing pydantic models like AnyHttpUrl
|
||||
def render(self, content: Any) -> bytes:
|
||||
return content.model_dump_json(exclude_none=True).encode("utf-8")
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Middleware for MCP authorization.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
|
||||
import contextvars
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
# Create a contextvar to store the authenticated user
|
||||
# The default is None, indicating no authenticated user is present
|
||||
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
|
||||
|
||||
|
||||
def get_access_token() -> AccessToken | None:
|
||||
"""
|
||||
Get the access token from the current context.
|
||||
|
||||
Returns:
|
||||
The access token if an authenticated user is available, None otherwise.
|
||||
"""
|
||||
auth_user = auth_context_var.get()
|
||||
return auth_user.access_token if auth_user else None
|
||||
|
||||
|
||||
class AuthContextMiddleware:
|
||||
"""
|
||||
Middleware that extracts the authenticated user from the request
|
||||
and sets it in a contextvar for easy access throughout the request lifecycle.
|
||||
|
||||
This middleware should be added after the AuthenticationMiddleware in the
|
||||
middleware stack to ensure that the user is properly authenticated before
|
||||
being stored in the context.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
user = scope.get("user")
|
||||
if isinstance(user, AuthenticatedUser):
|
||||
# Set the authenticated user in the contextvar
|
||||
token = auth_context_var.set(user)
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
auth_context_var.reset(token)
|
||||
else:
|
||||
# No authenticated user, just process the request
|
||||
await self.app(scope, receive, send)
|
||||
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||
|
||||
|
||||
class AuthenticatedUser(SimpleUser):
|
||||
"""User with authentication info."""
|
||||
|
||||
def __init__(self, auth_info: AccessToken):
|
||||
super().__init__(auth_info.client_id)
|
||||
self.access_token = auth_info
|
||||
self.scopes = auth_info.scopes
|
||||
|
||||
|
||||
class BearerAuthBackend(AuthenticationBackend):
|
||||
"""
|
||||
Authentication backend that validates Bearer tokens using a TokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, token_verifier: TokenVerifier):
|
||||
self.token_verifier = token_verifier
|
||||
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
auth_header = next(
|
||||
(conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
|
||||
None,
|
||||
)
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Validate the token with the verifier
|
||||
auth_info = await self.token_verifier.verify_token(token)
|
||||
|
||||
if not auth_info:
|
||||
return None
|
||||
|
||||
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
|
||||
return None
|
||||
|
||||
return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
|
||||
|
||||
|
||||
class RequireAuthMiddleware:
|
||||
"""
|
||||
Middleware that requires a valid Bearer token in the Authorization header.
|
||||
|
||||
This will validate the token with the auth provider and store the resulting
|
||||
auth info in the request state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: Any,
|
||||
required_scopes: list[str],
|
||||
resource_metadata_url: AnyHttpUrl | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
required_scopes: List of scopes that the token must have
|
||||
resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header
|
||||
"""
|
||||
self.app = app
|
||||
self.required_scopes = required_scopes
|
||||
self.resource_metadata_url = resource_metadata_url
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
auth_user = scope.get("user")
|
||||
if not isinstance(auth_user, AuthenticatedUser):
|
||||
await self._send_auth_error(
|
||||
send, status_code=401, error="invalid_token", description="Authentication required"
|
||||
)
|
||||
return
|
||||
|
||||
auth_credentials = scope.get("auth")
|
||||
|
||||
for required_scope in self.required_scopes:
|
||||
# auth_credentials should always be provided; this is just paranoia
|
||||
if auth_credentials is None or required_scope not in auth_credentials.scopes:
|
||||
await self._send_auth_error(
|
||||
send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}"
|
||||
)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None:
|
||||
"""Send an authentication error response with WWW-Authenticate header."""
|
||||
# Build WWW-Authenticate header value
|
||||
www_auth_parts = [f'error="{error}"', f'error_description="{description}"']
|
||||
if self.resource_metadata_url: # pragma: no cover
|
||||
www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"')
|
||||
|
||||
www_authenticate = f"Bearer {', '.join(www_auth_parts)}"
|
||||
|
||||
# Send response
|
||||
body = {"error": error, "error_description": description}
|
||||
body_bytes = json.dumps(body).encode()
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"content-length", str(len(body_bytes)).encode()),
|
||||
(b"www-authenticate", www_authenticate.encode()),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": body_bytes,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,115 @@
|
||||
import base64
|
||||
import binascii
|
||||
import hmac
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message # pragma: no cover
|
||||
|
||||
|
||||
class ClientAuthenticator:
|
||||
"""
|
||||
ClientAuthenticator is a callable which validates requests from a client
|
||||
application, used to verify /token calls.
|
||||
If, during registration, the client requested to be issued a secret, the
|
||||
authenticator asserts that /token calls must be authenticated with
|
||||
that same token.
|
||||
NOTE: clients can opt for no authentication during registration, in which case this
|
||||
logic is skipped.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""
|
||||
Initialize the dependency.
|
||||
|
||||
Args:
|
||||
provider: Provider to look up client information
|
||||
"""
|
||||
self.provider = provider
|
||||
|
||||
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
|
||||
"""
|
||||
Authenticate a client from an HTTP request.
|
||||
|
||||
Extracts client credentials from the appropriate location based on the
|
||||
client's registered authentication method and validates them.
|
||||
|
||||
Args:
|
||||
request: The HTTP request containing client credentials
|
||||
|
||||
Returns:
|
||||
The authenticated client information
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
form_data = await request.form()
|
||||
client_id = form_data.get("client_id")
|
||||
if not client_id:
|
||||
raise AuthenticationError("Missing client_id")
|
||||
|
||||
client = await self.provider.get_client(str(client_id))
|
||||
if not client:
|
||||
raise AuthenticationError("Invalid client_id") # pragma: no cover
|
||||
|
||||
request_client_secret: str | None = None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
if client.token_endpoint_auth_method == "client_secret_basic":
|
||||
if not auth_header.startswith("Basic "):
|
||||
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
|
||||
|
||||
try:
|
||||
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
|
||||
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
|
||||
if ":" not in decoded:
|
||||
raise ValueError("Invalid Basic auth format")
|
||||
basic_client_id, request_client_secret = decoded.split(":", 1)
|
||||
|
||||
# URL-decode both parts per RFC 6749 Section 2.3.1
|
||||
basic_client_id = unquote(basic_client_id)
|
||||
request_client_secret = unquote(request_client_secret)
|
||||
|
||||
if basic_client_id != client_id:
|
||||
raise AuthenticationError("Client ID mismatch in Basic auth")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise AuthenticationError("Invalid Basic authentication header")
|
||||
|
||||
elif client.token_endpoint_auth_method == "client_secret_post":
|
||||
raw_form_data = form_data.get("client_secret")
|
||||
# form_data.get() can return a UploadFile or None, so we need to check if it's a string
|
||||
if isinstance(raw_form_data, str):
|
||||
request_client_secret = str(raw_form_data)
|
||||
|
||||
elif client.token_endpoint_auth_method == "none":
|
||||
request_client_secret = None
|
||||
else:
|
||||
raise AuthenticationError( # pragma: no cover
|
||||
f"Unsupported auth method: {client.token_endpoint_auth_method}"
|
||||
)
|
||||
|
||||
# If client from the store expects a secret, validate that the request provides
|
||||
# that secret
|
||||
if client.client_secret: # pragma: no branch
|
||||
if not request_client_secret:
|
||||
raise AuthenticationError("Client secret is required") # pragma: no cover
|
||||
|
||||
# hmac.compare_digest requires that both arguments are either bytes or a `str` containing
|
||||
# only ASCII characters. Since we do not control `request_client_secret`, we encode both
|
||||
# arguments to bytes.
|
||||
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
|
||||
raise AuthenticationError("Invalid client_secret") # pragma: no cover
|
||||
|
||||
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
|
||||
raise AuthenticationError("Client secret has expired") # pragma: no cover
|
||||
|
||||
return client
|
||||
301
.venv/lib/python3.11/site-packages/mcp/server/auth/provider.py
Normal file
301
.venv/lib/python3.11/site-packages/mcp/server/auth/provider.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Literal, Protocol, TypeVar
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
|
||||
class AuthorizationParams(BaseModel):
|
||||
state: str | None
|
||||
scopes: list[str] | None
|
||||
code_challenge: str
|
||||
redirect_uri: AnyUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
class AuthorizationCode(BaseModel):
|
||||
code: str
|
||||
scopes: list[str]
|
||||
expires_at: float
|
||||
client_id: str
|
||||
code_challenge: str
|
||||
redirect_uri: AnyUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
class RefreshToken(BaseModel):
|
||||
token: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
expires_at: int | None = None
|
||||
|
||||
|
||||
class AccessToken(BaseModel):
|
||||
token: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
expires_at: int | None = None
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
RegistrationErrorCode = Literal[
|
||||
"invalid_redirect_uri",
|
||||
"invalid_client_metadata",
|
||||
"invalid_software_statement",
|
||||
"unapproved_software_statement",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistrationError(Exception):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
AuthorizationErrorCode = Literal[
|
||||
"invalid_request",
|
||||
"unauthorized_client",
|
||||
"access_denied",
|
||||
"unsupported_response_type",
|
||||
"invalid_scope",
|
||||
"server_error",
|
||||
"temporarily_unavailable",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthorizeError(Exception):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
TokenErrorCode = Literal[
|
||||
"invalid_request",
|
||||
"invalid_client",
|
||||
"invalid_grant",
|
||||
"unauthorized_client",
|
||||
"unsupported_grant_type",
|
||||
"invalid_scope",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenError(Exception):
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
class TokenVerifier(Protocol):
|
||||
"""Protocol for verifying bearer tokens."""
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify a bearer token and return access info if valid."""
|
||||
|
||||
|
||||
# NOTE: FastMCP doesn't render any of these types in the user response, so it's
|
||||
# OK to add fields to subclasses which should not be exposed externally.
|
||||
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
|
||||
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
|
||||
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
|
||||
|
||||
|
||||
class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""
|
||||
Retrieves client information by client ID.
|
||||
|
||||
Implementors MAY raise NotImplementedError if dynamic client registration is
|
||||
disabled in ClientRegistrationOptions.
|
||||
|
||||
Args:
|
||||
client_id: The ID of the client to retrieve.
|
||||
|
||||
Returns:
|
||||
The client information, or None if the client does not exist.
|
||||
"""
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""
|
||||
Saves client information as part of registering it.
|
||||
|
||||
Implementors MAY raise NotImplementedError if dynamic client registration is
|
||||
disabled in ClientRegistrationOptions.
|
||||
|
||||
Args:
|
||||
client_info: The client metadata to register.
|
||||
|
||||
Raises:
|
||||
RegistrationError: If the client metadata is invalid.
|
||||
"""
|
||||
|
||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
||||
"""
|
||||
Called as part of the /authorize endpoint, and returns a URL that the client
|
||||
will be redirected to.
|
||||
Many MCP implementations will redirect to a third-party provider to perform
|
||||
a second OAuth exchange with that provider. In this sort of setup, the client
|
||||
has an OAuth connection with the MCP server, and the MCP server has an OAuth
|
||||
connection with the 3rd-party provider. At the end of this flow, the client
|
||||
should be redirected to the redirect_uri from params.redirect_uri.
|
||||
|
||||
+--------+ +------------+ +-------------------+
|
||||
| | | | | |
|
||||
| Client | --> | MCP Server | --> | 3rd Party OAuth |
|
||||
| | | | | Server |
|
||||
+--------+ +------------+ +-------------------+
|
||||
| ^ |
|
||||
+------------+ | | |
|
||||
| | | | Redirect |
|
||||
|redirect_uri|<-----+ +------------------+
|
||||
| |
|
||||
+------------+
|
||||
|
||||
Implementations will need to define another handler on the MCP server return
|
||||
flow to perform the second redirect, and generate and store an authorization
|
||||
code as part of completing the OAuth authorization step.
|
||||
|
||||
Implementations SHOULD generate an authorization code with at least 160 bits of
|
||||
entropy,
|
||||
and MUST generate an authorization code with at least 128 bits of entropy.
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10.
|
||||
|
||||
Args:
|
||||
client: The client requesting authorization.
|
||||
params: The parameters of the authorization request.
|
||||
|
||||
Returns:
|
||||
A URL to redirect the client to for authorization.
|
||||
|
||||
Raises:
|
||||
AuthorizeError: If the authorization request is invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
) -> AuthorizationCodeT | None:
|
||||
"""
|
||||
Loads an AuthorizationCode by its code.
|
||||
|
||||
Args:
|
||||
client: The client that requested the authorization code.
|
||||
authorization_code: The authorization code to get the challenge for.
|
||||
|
||||
Returns:
|
||||
The AuthorizationCode, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
|
||||
) -> OAuthToken:
|
||||
"""
|
||||
Exchanges an authorization code for an access token and refresh token.
|
||||
|
||||
Args:
|
||||
client: The client exchanging the authorization code.
|
||||
authorization_code: The authorization code to exchange.
|
||||
|
||||
Returns:
|
||||
The OAuth token, containing access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
TokenError: If the request is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
|
||||
"""
|
||||
Loads a RefreshToken by its token string.
|
||||
|
||||
Args:
|
||||
client: The client that is requesting to load the refresh token.
|
||||
refresh_token: The refresh token string to load.
|
||||
|
||||
Returns:
|
||||
The RefreshToken object if found, or None if not found.
|
||||
"""
|
||||
...
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self,
|
||||
client: OAuthClientInformationFull,
|
||||
refresh_token: RefreshTokenT,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
"""
|
||||
Exchanges a refresh token for an access token and refresh token.
|
||||
|
||||
Implementations SHOULD rotate both the access token and refresh token.
|
||||
|
||||
Args:
|
||||
client: The client exchanging the refresh token.
|
||||
refresh_token: The refresh token to exchange.
|
||||
scopes: Optional scopes to request with the new access token.
|
||||
|
||||
Returns:
|
||||
The OAuth token, containing access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
TokenError: If the request is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_access_token(self, token: str) -> AccessTokenT | None:
|
||||
"""
|
||||
Loads an access token by its token.
|
||||
|
||||
Args:
|
||||
token: The access token to verify.
|
||||
|
||||
Returns:
|
||||
The AuthInfo, or None if the token is invalid.
|
||||
"""
|
||||
|
||||
async def revoke_token(
|
||||
self,
|
||||
token: AccessTokenT | RefreshTokenT,
|
||||
) -> None:
|
||||
"""
|
||||
Revokes an access or refresh token.
|
||||
|
||||
If the given token is invalid or already revoked, this method should do nothing.
|
||||
|
||||
Implementations SHOULD revoke both the access token and its corresponding
|
||||
refresh token, regardless of which of the access token or refresh token is
|
||||
provided.
|
||||
|
||||
Args:
|
||||
token: the token to revoke
|
||||
"""
|
||||
|
||||
|
||||
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
|
||||
parsed_uri = urlparse(redirect_uri_base)
|
||||
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs]
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
query_params.append((k, v))
|
||||
|
||||
redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params)))
|
||||
return redirect_uri
|
||||
|
||||
|
||||
class ProviderTokenVerifier(TokenVerifier):
|
||||
"""Token verifier that uses an OAuthAuthorizationServerProvider.
|
||||
|
||||
This is provided for backwards compatibility with existing auth_server_provider
|
||||
configurations. For new implementations using AS/RS separation, consider using
|
||||
the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"):
|
||||
self.provider = provider
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify token using the provider's load_access_token method."""
|
||||
return await self.provider.load_access_token(token)
|
||||
253
.venv/lib/python3.11/site-packages/mcp/server/auth/routes.py
Normal file
253
.venv/lib/python3.11/site-packages/mcp/server/auth/routes.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Route, request_response # type: ignore
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationHandler
|
||||
from mcp.server.auth.handlers.metadata import MetadataHandler
|
||||
from mcp.server.auth.handlers.register import RegistrationHandler
|
||||
from mcp.server.auth.handlers.revoke import RevocationHandler
|
||||
from mcp.server.auth.handlers.token import TokenHandler
|
||||
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
||||
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
|
||||
from mcp.shared.auth import OAuthMetadata
|
||||
|
||||
|
||||
def validate_issuer_url(url: AnyHttpUrl):
|
||||
"""
|
||||
Validate that the issuer URL meets OAuth 2.0 requirements.
|
||||
|
||||
Args:
|
||||
url: The issuer URL to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If the issuer URL is invalid
|
||||
"""
|
||||
|
||||
# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
|
||||
if (
|
||||
url.scheme != "https"
|
||||
and url.host != "localhost"
|
||||
and (url.host is not None and not url.host.startswith("127.0.0.1"))
|
||||
):
|
||||
raise ValueError("Issuer URL must be HTTPS") # pragma: no cover
|
||||
|
||||
# No fragments or query parameters allowed
|
||||
if url.fragment:
|
||||
raise ValueError("Issuer URL must not have a fragment") # pragma: no cover
|
||||
if url.query:
|
||||
raise ValueError("Issuer URL must not have a query string") # pragma: no cover
|
||||
|
||||
|
||||
AUTHORIZATION_PATH = "/authorize"
|
||||
TOKEN_PATH = "/token"
|
||||
REGISTRATION_PATH = "/register"
|
||||
REVOCATION_PATH = "/revoke"
|
||||
|
||||
|
||||
def cors_middleware(
|
||||
handler: Callable[[Request], Response | Awaitable[Response]],
|
||||
allow_methods: list[str],
|
||||
) -> ASGIApp:
|
||||
cors_app = CORSMiddleware(
|
||||
app=request_response(handler),
|
||||
allow_origins="*",
|
||||
allow_methods=allow_methods,
|
||||
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
|
||||
)
|
||||
return cors_app
|
||||
|
||||
|
||||
def create_auth_routes(
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
issuer_url: AnyHttpUrl,
|
||||
service_documentation_url: AnyHttpUrl | None = None,
|
||||
client_registration_options: ClientRegistrationOptions | None = None,
|
||||
revocation_options: RevocationOptions | None = None,
|
||||
) -> list[Route]:
|
||||
validate_issuer_url(issuer_url)
|
||||
|
||||
client_registration_options = client_registration_options or ClientRegistrationOptions()
|
||||
revocation_options = revocation_options or RevocationOptions()
|
||||
metadata = build_metadata(
|
||||
issuer_url,
|
||||
service_documentation_url,
|
||||
client_registration_options,
|
||||
revocation_options,
|
||||
)
|
||||
client_authenticator = ClientAuthenticator(provider)
|
||||
|
||||
# Create routes
|
||||
# Allow CORS requests for endpoints meant to be hit by the OAuth client
|
||||
# (with the client secret). This is intended to support things like MCP Inspector,
|
||||
# where the client runs in a web browser.
|
||||
routes = [
|
||||
Route(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
endpoint=cors_middleware(
|
||||
MetadataHandler(metadata).handle,
|
||||
["GET", "OPTIONS"],
|
||||
),
|
||||
methods=["GET", "OPTIONS"],
|
||||
),
|
||||
Route(
|
||||
AUTHORIZATION_PATH,
|
||||
# do not allow CORS for authorization endpoint;
|
||||
# clients should just redirect to this
|
||||
endpoint=AuthorizationHandler(provider).handle,
|
||||
methods=["GET", "POST"],
|
||||
),
|
||||
Route(
|
||||
TOKEN_PATH,
|
||||
endpoint=cors_middleware(
|
||||
TokenHandler(provider, client_authenticator).handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
),
|
||||
]
|
||||
|
||||
if client_registration_options.enabled: # pragma: no branch
|
||||
registration_handler = RegistrationHandler(
|
||||
provider,
|
||||
options=client_registration_options,
|
||||
)
|
||||
routes.append(
|
||||
Route(
|
||||
REGISTRATION_PATH,
|
||||
endpoint=cors_middleware(
|
||||
registration_handler.handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
if revocation_options.enabled: # pragma: no branch
|
||||
revocation_handler = RevocationHandler(provider, client_authenticator)
|
||||
routes.append(
|
||||
Route(
|
||||
REVOCATION_PATH,
|
||||
endpoint=cors_middleware(
|
||||
revocation_handler.handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
||||
|
||||
def build_metadata(
|
||||
issuer_url: AnyHttpUrl,
|
||||
service_documentation_url: AnyHttpUrl | None,
|
||||
client_registration_options: ClientRegistrationOptions,
|
||||
revocation_options: RevocationOptions,
|
||||
) -> OAuthMetadata:
|
||||
authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH)
|
||||
token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH)
|
||||
|
||||
# Create metadata
|
||||
metadata = OAuthMetadata(
|
||||
issuer=issuer_url,
|
||||
authorization_endpoint=authorization_url,
|
||||
token_endpoint=token_url,
|
||||
scopes_supported=client_registration_options.valid_scopes,
|
||||
response_types_supported=["code"],
|
||||
response_modes_supported=None,
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
|
||||
token_endpoint_auth_signing_alg_values_supported=None,
|
||||
service_documentation=service_documentation_url,
|
||||
ui_locales_supported=None,
|
||||
op_policy_uri=None,
|
||||
op_tos_uri=None,
|
||||
introspection_endpoint=None,
|
||||
code_challenge_methods_supported=["S256"],
|
||||
)
|
||||
|
||||
# Add registration endpoint if supported
|
||||
if client_registration_options.enabled: # pragma: no branch
|
||||
metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)
|
||||
|
||||
# Add revocation endpoint if supported
|
||||
if revocation_options.enabled: # pragma: no branch
|
||||
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
|
||||
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def build_resource_metadata_url(resource_server_url: AnyHttpUrl) -> AnyHttpUrl:
|
||||
"""
|
||||
Build RFC 9728 compliant protected resource metadata URL.
|
||||
|
||||
Inserts /.well-known/oauth-protected-resource between host and resource path
|
||||
as specified in RFC 9728 §3.1.
|
||||
|
||||
Args:
|
||||
resource_server_url: The resource server URL (e.g., https://example.com/mcp)
|
||||
|
||||
Returns:
|
||||
The metadata URL (e.g., https://example.com/.well-known/oauth-protected-resource/mcp)
|
||||
"""
|
||||
parsed = urlparse(str(resource_server_url))
|
||||
# Handle trailing slash: if path is just "/", treat as empty
|
||||
resource_path = parsed.path if parsed.path != "/" else ""
|
||||
return AnyHttpUrl(f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-protected-resource{resource_path}")
|
||||
|
||||
|
||||
def create_protected_resource_routes(
|
||||
resource_url: AnyHttpUrl,
|
||||
authorization_servers: list[AnyHttpUrl],
|
||||
scopes_supported: list[str] | None = None,
|
||||
resource_name: str | None = None,
|
||||
resource_documentation: AnyHttpUrl | None = None,
|
||||
) -> list[Route]:
|
||||
"""
|
||||
Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728).
|
||||
|
||||
Args:
|
||||
resource_url: The URL of this resource server
|
||||
authorization_servers: List of authorization servers that can issue tokens
|
||||
scopes_supported: Optional list of scopes supported by this resource
|
||||
|
||||
Returns:
|
||||
List of Starlette routes for protected resource metadata
|
||||
"""
|
||||
from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler
|
||||
from mcp.shared.auth import ProtectedResourceMetadata
|
||||
|
||||
metadata = ProtectedResourceMetadata(
|
||||
resource=resource_url,
|
||||
authorization_servers=authorization_servers,
|
||||
scopes_supported=scopes_supported,
|
||||
resource_name=resource_name,
|
||||
resource_documentation=resource_documentation,
|
||||
# bearer_methods_supported defaults to ["header"] in the model
|
||||
)
|
||||
|
||||
handler = ProtectedResourceMetadataHandler(metadata)
|
||||
|
||||
# RFC 9728 §3.1: Register route at /.well-known/oauth-protected-resource + resource path
|
||||
metadata_url = build_resource_metadata_url(resource_url)
|
||||
# Extract just the path part for route registration
|
||||
parsed = urlparse(str(metadata_url))
|
||||
well_known_path = parsed.path
|
||||
|
||||
return [
|
||||
Route(
|
||||
well_known_path,
|
||||
endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]),
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
|
||||
class ClientRegistrationOptions(BaseModel):
|
||||
enabled: bool = False
|
||||
client_secret_expiry_seconds: int | None = None
|
||||
valid_scopes: list[str] | None = None
|
||||
default_scopes: list[str] | None = None
|
||||
|
||||
|
||||
class RevocationOptions(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
|
||||
class AuthSettings(BaseModel):
|
||||
issuer_url: AnyHttpUrl = Field(
|
||||
...,
|
||||
description="OAuth authorization server URL that issues tokens for this resource server.",
|
||||
)
|
||||
service_documentation_url: AnyHttpUrl | None = None
|
||||
client_registration_options: ClientRegistrationOptions | None = None
|
||||
revocation_options: RevocationOptions | None = None
|
||||
required_scopes: list[str] | None = None
|
||||
|
||||
# Resource Server settings (when operating as RS only)
|
||||
resource_server_url: AnyHttpUrl | None = Field(
|
||||
...,
|
||||
description="The URL of the MCP server to be used as the resource identifier "
|
||||
"and base route to look up OAuth Protected Resource Metadata.",
|
||||
)
|
||||
190
.venv/lib/python3.11/site-packages/mcp/server/elicitation.py
Normal file
190
.venv/lib/python3.11/site-packages/mcp/server/elicitation.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Elicitation utilities for MCP servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.types import RequestId
|
||||
|
||||
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
|
||||
|
||||
|
||||
class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
|
||||
"""Result when user accepts the elicitation."""
|
||||
|
||||
action: Literal["accept"] = "accept"
|
||||
data: ElicitSchemaModelT
|
||||
|
||||
|
||||
class DeclinedElicitation(BaseModel):
|
||||
"""Result when user declines the elicitation."""
|
||||
|
||||
action: Literal["decline"] = "decline"
|
||||
|
||||
|
||||
class CancelledElicitation(BaseModel):
|
||||
"""Result when user cancels the elicitation."""
|
||||
|
||||
action: Literal["cancel"] = "cancel"
|
||||
|
||||
|
||||
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation
|
||||
|
||||
|
||||
class AcceptedUrlElicitation(BaseModel):
|
||||
"""Result when user accepts a URL mode elicitation."""
|
||||
|
||||
action: Literal["accept"] = "accept"
|
||||
|
||||
|
||||
UrlElicitationResult = AcceptedUrlElicitation | DeclinedElicitation | CancelledElicitation
|
||||
|
||||
|
||||
# Primitive types allowed in elicitation schemas
|
||||
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
|
||||
|
||||
|
||||
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
|
||||
"""Validate that a Pydantic model only contains primitive field types."""
|
||||
for field_name, field_info in schema.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
|
||||
if annotation is None or annotation is types.NoneType: # pragma: no cover
|
||||
continue
|
||||
elif _is_primitive_field(annotation):
|
||||
continue
|
||||
elif _is_string_sequence(annotation):
|
||||
continue
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Elicitation schema field '{field_name}' must be a primitive type "
|
||||
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
|
||||
f"or Optional of these types. Nested models and complex types are not allowed."
|
||||
)
|
||||
|
||||
|
||||
def _is_string_sequence(annotation: type) -> bool:
|
||||
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
|
||||
origin = get_origin(annotation)
|
||||
# Check if it's a sequence-like type with str elements
|
||||
if origin:
|
||||
try:
|
||||
if issubclass(origin, Sequence):
|
||||
args = get_args(annotation)
|
||||
# Should have single str type arg
|
||||
return len(args) == 1 and args[0] is str
|
||||
except TypeError: # pragma: no cover
|
||||
# origin is not a class, so it can't be a subclass of Sequence
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _is_primitive_field(annotation: type) -> bool:
|
||||
"""Check if a field is a primitive type allowed in elicitation schemas."""
|
||||
# Handle basic primitive types
|
||||
if annotation in _ELICITATION_PRIMITIVE_TYPES:
|
||||
return True
|
||||
|
||||
# Handle Union types
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is types.UnionType:
|
||||
args = get_args(annotation)
|
||||
# All args must be primitive types, None, or string sequences
|
||||
return all(
|
||||
arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def elicit_with_validation(
|
||||
session: ServerSession,
|
||||
message: str,
|
||||
schema: type[ElicitSchemaModelT],
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> ElicitationResult[ElicitSchemaModelT]:
|
||||
"""Elicit information from the client/user with schema validation (form mode).
|
||||
|
||||
This method can be used to interactively ask for additional information from the
|
||||
client within a tool's execution. The client might display the message to the
|
||||
user and collect a response according to the provided schema. Or in case a
|
||||
client is an agent, it might decide how to handle the elicitation -- either by asking
|
||||
the user or automatically generating a response.
|
||||
|
||||
For sensitive data like credentials or OAuth flows, use elicit_url() instead.
|
||||
"""
|
||||
# Validate that schema only contains primitive types and fail loudly if not
|
||||
_validate_elicitation_schema(schema)
|
||||
|
||||
json_schema = schema.model_json_schema()
|
||||
|
||||
result = await session.elicit_form(
|
||||
message=message,
|
||||
requestedSchema=json_schema,
|
||||
related_request_id=related_request_id,
|
||||
)
|
||||
|
||||
if result.action == "accept" and result.content is not None:
|
||||
# Validate and parse the content using the schema
|
||||
validated_data = schema.model_validate(result.content)
|
||||
return AcceptedElicitation(data=validated_data)
|
||||
elif result.action == "decline":
|
||||
return DeclinedElicitation()
|
||||
elif result.action == "cancel": # pragma: no cover
|
||||
return CancelledElicitation()
|
||||
else: # pragma: no cover
|
||||
# This should never happen, but handle it just in case
|
||||
raise ValueError(f"Unexpected elicitation action: {result.action}")
|
||||
|
||||
|
||||
async def elicit_url(
|
||||
session: ServerSession,
|
||||
message: str,
|
||||
url: str,
|
||||
elicitation_id: str,
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> UrlElicitationResult:
|
||||
"""Elicit information from the user via out-of-band URL navigation (URL mode).
|
||||
|
||||
This method directs the user to an external URL where sensitive interactions can
|
||||
occur without passing data through the MCP client. Use this for:
|
||||
- Collecting sensitive credentials (API keys, passwords)
|
||||
- OAuth authorization flows with third-party services
|
||||
- Payment and subscription flows
|
||||
- Any interaction where data should not pass through the LLM context
|
||||
|
||||
The response indicates whether the user consented to navigate to the URL.
|
||||
The actual interaction happens out-of-band. When the elicitation completes,
|
||||
the server should send an ElicitCompleteNotification to notify the client.
|
||||
|
||||
Args:
|
||||
session: The server session
|
||||
message: Human-readable explanation of why the interaction is needed
|
||||
url: The URL the user should navigate to
|
||||
elicitation_id: Unique identifier for tracking this elicitation
|
||||
related_request_id: Optional ID of the request that triggered this elicitation
|
||||
|
||||
Returns:
|
||||
UrlElicitationResult indicating accept, decline, or cancel
|
||||
"""
|
||||
result = await session.elicit_url(
|
||||
message=message,
|
||||
url=url,
|
||||
elicitation_id=elicitation_id,
|
||||
related_request_id=related_request_id,
|
||||
)
|
||||
|
||||
if result.action == "accept":
|
||||
return AcceptedUrlElicitation()
|
||||
elif result.action == "decline":
|
||||
return DeclinedElicitation()
|
||||
elif result.action == "cancel":
|
||||
return CancelledElicitation()
|
||||
else: # pragma: no cover
|
||||
# This should never happen, but handle it just in case
|
||||
raise ValueError(f"Unexpected elicitation action: {result.action}")
|
||||
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Server-side experimental features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Import directly from submodules:
|
||||
- mcp.server.experimental.task_context.ServerTaskContext
|
||||
- mcp.server.experimental.task_support.TaskSupport
|
||||
- mcp.server.experimental.task_result_handler.TaskResultHandler
|
||||
- mcp.server.experimental.request_context.Experimental
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Experimental request context features.
|
||||
|
||||
This module provides the Experimental class which gives access to experimental
|
||||
features within a request context, such as task-augmented request handling.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.experimental.task_context import ServerTaskContext
|
||||
from mcp.server.experimental.task_support import TaskSupport
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal
|
||||
from mcp.types import (
|
||||
METHOD_NOT_FOUND,
|
||||
TASK_FORBIDDEN,
|
||||
TASK_REQUIRED,
|
||||
ClientCapabilities,
|
||||
CreateTaskResult,
|
||||
ErrorData,
|
||||
Result,
|
||||
TaskExecutionMode,
|
||||
TaskMetadata,
|
||||
Tool,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Experimental:
|
||||
"""
|
||||
Experimental features context for task-augmented requests.
|
||||
|
||||
Provides helpers for validating task execution compatibility and
|
||||
running tasks with automatic lifecycle management.
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
task_metadata: TaskMetadata | None = None
|
||||
_client_capabilities: ClientCapabilities | None = field(default=None, repr=False)
|
||||
_session: ServerSession | None = field(default=None, repr=False)
|
||||
_task_support: TaskSupport | None = field(default=None, repr=False)
|
||||
|
||||
@property
|
||||
def is_task(self) -> bool:
|
||||
"""Check if this request is task-augmented."""
|
||||
return self.task_metadata is not None
|
||||
|
||||
@property
|
||||
def client_supports_tasks(self) -> bool:
|
||||
"""Check if the client declared task support."""
|
||||
if self._client_capabilities is None:
|
||||
return False
|
||||
return self._client_capabilities.tasks is not None
|
||||
|
||||
def validate_task_mode(
|
||||
self,
|
||||
tool_task_mode: TaskExecutionMode | None,
|
||||
*,
|
||||
raise_error: bool = True,
|
||||
) -> ErrorData | None:
|
||||
"""
|
||||
Validate that the request is compatible with the tool's task execution mode.
|
||||
|
||||
Per MCP spec:
|
||||
- "required": Clients MUST invoke as task. Server returns -32601 if not.
|
||||
- "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do.
|
||||
- "optional": Either is acceptable.
|
||||
|
||||
Args:
|
||||
tool_task_mode: The tool's execution.taskSupport value
|
||||
("forbidden", "optional", "required", or None)
|
||||
raise_error: If True, raises McpError on validation failure. If False, returns ErrorData.
|
||||
|
||||
Returns:
|
||||
None if valid, ErrorData if invalid and raise_error=False
|
||||
|
||||
Raises:
|
||||
McpError: If invalid and raise_error=True
|
||||
"""
|
||||
|
||||
mode = tool_task_mode or TASK_FORBIDDEN
|
||||
|
||||
error: ErrorData | None = None
|
||||
|
||||
if mode == TASK_REQUIRED and not self.is_task:
|
||||
error = ErrorData(
|
||||
code=METHOD_NOT_FOUND,
|
||||
message="This tool requires task-augmented invocation",
|
||||
)
|
||||
elif mode == TASK_FORBIDDEN and self.is_task:
|
||||
error = ErrorData(
|
||||
code=METHOD_NOT_FOUND,
|
||||
message="This tool does not support task-augmented invocation",
|
||||
)
|
||||
|
||||
if error is not None and raise_error:
|
||||
raise McpError(error)
|
||||
|
||||
return error
|
||||
|
||||
def validate_for_tool(
|
||||
self,
|
||||
tool: Tool,
|
||||
*,
|
||||
raise_error: bool = True,
|
||||
) -> ErrorData | None:
|
||||
"""
|
||||
Validate that the request is compatible with the given tool.
|
||||
|
||||
Convenience wrapper around validate_task_mode that extracts the mode from a Tool.
|
||||
|
||||
Args:
|
||||
tool: The Tool definition
|
||||
raise_error: If True, raises McpError on validation failure.
|
||||
|
||||
Returns:
|
||||
None if valid, ErrorData if invalid and raise_error=False
|
||||
"""
|
||||
mode = tool.execution.taskSupport if tool.execution else None
|
||||
return self.validate_task_mode(mode, raise_error=raise_error)
|
||||
|
||||
def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool:
|
||||
"""
|
||||
Check if this client can use a tool with the given task mode.
|
||||
|
||||
Useful for filtering tool lists or providing warnings.
|
||||
Returns False if tool requires "required" but client doesn't support tasks.
|
||||
|
||||
Args:
|
||||
tool_task_mode: The tool's execution.taskSupport value
|
||||
|
||||
Returns:
|
||||
True if the client can use this tool, False otherwise
|
||||
"""
|
||||
mode = tool_task_mode or TASK_FORBIDDEN
|
||||
if mode == TASK_REQUIRED and not self.client_supports_tasks:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
work: Callable[[ServerTaskContext], Awaitable[Result]],
|
||||
*,
|
||||
task_id: str | None = None,
|
||||
model_immediate_response: str | None = None,
|
||||
) -> CreateTaskResult:
|
||||
"""
|
||||
Create a task, spawn background work, and return CreateTaskResult immediately.
|
||||
|
||||
This is the recommended way to handle task-augmented tool calls. It:
|
||||
1. Creates a task in the store
|
||||
2. Spawns the work function in a background task
|
||||
3. Returns CreateTaskResult immediately
|
||||
|
||||
The work function receives a ServerTaskContext with:
|
||||
- elicit() for sending elicitation requests
|
||||
- create_message() for sampling requests
|
||||
- update_status() for progress updates
|
||||
- complete()/fail() for finishing the task
|
||||
|
||||
When work() returns a Result, the task is auto-completed with that result.
|
||||
If work() raises an exception, the task is auto-failed.
|
||||
|
||||
Args:
|
||||
work: Async function that does the actual work
|
||||
task_id: Optional task ID (generated if not provided)
|
||||
model_immediate_response: Optional string to include in _meta as
|
||||
io.modelcontextprotocol/model-immediate-response
|
||||
|
||||
Returns:
|
||||
CreateTaskResult to return to the client
|
||||
|
||||
Raises:
|
||||
RuntimeError: If task support is not enabled or task_metadata is missing
|
||||
|
||||
Example:
|
||||
@server.call_tool()
|
||||
async def handle_tool(name: str, args: dict):
|
||||
ctx = server.request_context
|
||||
|
||||
async def work(task: ServerTaskContext) -> CallToolResult:
|
||||
result = await task.elicit(
|
||||
message="Are you sure?",
|
||||
requestedSchema={"type": "object", ...}
|
||||
)
|
||||
confirmed = result.content.get("confirm", False)
|
||||
return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")])
|
||||
|
||||
return await ctx.experimental.run_task(work)
|
||||
|
||||
WARNING: This API is experimental and may change without notice.
|
||||
"""
|
||||
if self._task_support is None:
|
||||
raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.")
|
||||
if self._session is None:
|
||||
raise RuntimeError("Session not available.")
|
||||
if self.task_metadata is None:
|
||||
raise RuntimeError(
|
||||
"Request is not task-augmented (no task field in params). "
|
||||
"The client must send a task-augmented request."
|
||||
)
|
||||
|
||||
support = self._task_support
|
||||
# Access task_group via TaskSupport - raises if not in run() context
|
||||
task_group = support.task_group
|
||||
|
||||
task = await support.store.create_task(self.task_metadata, task_id)
|
||||
|
||||
task_ctx = ServerTaskContext(
|
||||
task=task,
|
||||
store=support.store,
|
||||
session=self._session,
|
||||
queue=support.queue,
|
||||
handler=support.handler,
|
||||
)
|
||||
|
||||
async def execute() -> None:
|
||||
try:
|
||||
result = await work(task_ctx)
|
||||
if not is_terminal(task_ctx.task.status):
|
||||
await task_ctx.complete(result)
|
||||
except Exception as e:
|
||||
if not is_terminal(task_ctx.task.status):
|
||||
await task_ctx.fail(str(e))
|
||||
|
||||
task_group.start_soon(execute)
|
||||
|
||||
meta: dict[str, Any] | None = None
|
||||
if model_immediate_response is not None:
|
||||
meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response}
|
||||
|
||||
return CreateTaskResult(task=task, **{"_meta": meta} if meta else {})
|
||||
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Experimental server session features for server→client task operations.
|
||||
|
||||
This module provides the server-side equivalent of ExperimentalClientFeatures,
|
||||
allowing the server to send task-augmented requests to the client and poll for results.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
|
||||
from mcp.shared.experimental.tasks.capabilities import (
|
||||
require_task_augmented_elicitation,
|
||||
require_task_augmented_sampling,
|
||||
)
|
||||
from mcp.shared.experimental.tasks.polling import poll_until_terminal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.session import ServerSession
|
||||
|
||||
ResultT = TypeVar("ResultT", bound=types.Result)
|
||||
|
||||
|
||||
class ExperimentalServerSessionFeatures:
|
||||
"""
|
||||
Experimental server session features for server→client task operations.
|
||||
|
||||
This provides the server-side equivalent of ExperimentalClientFeatures,
|
||||
allowing the server to send task-augmented requests to the client and
|
||||
poll for results.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Access via session.experimental:
|
||||
result = await session.experimental.elicit_as_task(...)
|
||||
"""
|
||||
|
||||
def __init__(self, session: "ServerSession") -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_task(self, task_id: str) -> types.GetTaskResult:
|
||||
"""
|
||||
Send tasks/get to the client to get task status.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
GetTaskResult containing the task status
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))),
|
||||
types.GetTaskResult,
|
||||
)
|
||||
|
||||
async def get_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
result_type: type[ResultT],
|
||||
) -> ResultT:
|
||||
"""
|
||||
Send tasks/result to the client to retrieve the final result.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
result_type: The expected result type
|
||||
|
||||
Returns:
|
||||
The task result, validated against result_type
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))),
|
||||
result_type,
|
||||
)
|
||||
|
||||
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
|
||||
"""
|
||||
Poll a client task until it reaches terminal status.
|
||||
|
||||
Yields GetTaskResult for each poll, allowing the caller to react to
|
||||
status changes. Exits when task reaches a terminal status.
|
||||
|
||||
Respects the pollInterval hint from the client.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Yields:
|
||||
GetTaskResult for each poll
|
||||
"""
|
||||
async for status in poll_until_terminal(self.get_task, task_id):
|
||||
yield status
|
||||
|
||||
async def elicit_as_task(
|
||||
self,
|
||||
message: str,
|
||||
requestedSchema: types.ElicitRequestedSchema,
|
||||
*,
|
||||
ttl: int = 60000,
|
||||
) -> types.ElicitResult:
|
||||
"""
|
||||
Send a task-augmented elicitation to the client and poll until complete.
|
||||
|
||||
The client will create a local task, process the elicitation asynchronously,
|
||||
and return the result when ready. This method handles the full flow:
|
||||
1. Send elicitation with task field
|
||||
2. Receive CreateTaskResult from client
|
||||
3. Poll client's task until terminal
|
||||
4. Retrieve and return the final ElicitResult
|
||||
|
||||
Args:
|
||||
message: The message to present to the user
|
||||
requestedSchema: Schema defining the expected response
|
||||
ttl: Task time-to-live in milliseconds
|
||||
|
||||
Returns:
|
||||
The client's elicitation response
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented elicitation
|
||||
"""
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
require_task_augmented_elicitation(client_caps)
|
||||
|
||||
create_result = await self._session.send_request(
|
||||
types.ServerRequest(
|
||||
types.ElicitRequest(
|
||||
params=types.ElicitRequestFormParams(
|
||||
message=message,
|
||||
requestedSchema=requestedSchema,
|
||||
task=types.TaskMetadata(ttl=ttl),
|
||||
)
|
||||
)
|
||||
),
|
||||
types.CreateTaskResult,
|
||||
)
|
||||
|
||||
task_id = create_result.task.taskId
|
||||
|
||||
async for _ in self.poll_task(task_id):
|
||||
pass
|
||||
|
||||
return await self.get_task_result(task_id, types.ElicitResult)
|
||||
|
||||
async def create_message_as_task(
|
||||
self,
|
||||
messages: list[types.SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
ttl: int = 60000,
|
||||
system_prompt: str | None = None,
|
||||
include_context: types.IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: types.ModelPreferences | None = None,
|
||||
tools: list[types.Tool] | None = None,
|
||||
tool_choice: types.ToolChoice | None = None,
|
||||
) -> types.CreateMessageResult:
|
||||
"""
|
||||
Send a task-augmented sampling request and poll until complete.
|
||||
|
||||
The client will create a local task, process the sampling request
|
||||
asynchronously, and return the result when ready.
|
||||
|
||||
Args:
|
||||
messages: The conversation messages for sampling
|
||||
max_tokens: Maximum tokens in the response
|
||||
ttl: Task time-to-live in milliseconds
|
||||
system_prompt: Optional system prompt
|
||||
include_context: Context inclusion strategy
|
||||
temperature: Sampling temperature
|
||||
stop_sequences: Stop sequences
|
||||
metadata: Additional metadata
|
||||
model_preferences: Model selection preferences
|
||||
tools: Optional list of tools the LLM can use during sampling
|
||||
tool_choice: Optional control over tool usage behavior
|
||||
|
||||
Returns:
|
||||
The sampling result from the client
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented sampling or tools
|
||||
ValueError: If tool_use or tool_result message structure is invalid
|
||||
"""
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
require_task_augmented_sampling(client_caps)
|
||||
validate_sampling_tools(client_caps, tools, tool_choice)
|
||||
validate_tool_use_result_messages(messages)
|
||||
|
||||
create_result = await self._session.send_request(
|
||||
types.ServerRequest(
|
||||
types.CreateMessageRequest(
|
||||
params=types.CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
maxTokens=max_tokens,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
temperature=temperature,
|
||||
stopSequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
modelPreferences=model_preferences,
|
||||
tools=tools,
|
||||
toolChoice=tool_choice,
|
||||
task=types.TaskMetadata(ttl=ttl),
|
||||
)
|
||||
)
|
||||
),
|
||||
types.CreateTaskResult,
|
||||
)
|
||||
|
||||
task_id = create_result.task.taskId
|
||||
|
||||
async for _ in self.poll_task(task_id):
|
||||
pass
|
||||
|
||||
return await self.get_task_result(task_id, types.CreateMessageResult)
|
||||
@@ -0,0 +1,612 @@
|
||||
"""
|
||||
ServerTaskContext - Server-integrated task context with elicitation and sampling.
|
||||
|
||||
This wraps the pure TaskContext and adds server-specific functionality:
|
||||
- Elicitation (task.elicit())
|
||||
- Sampling (task.create_message())
|
||||
- Status notifications
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.experimental.task_result_handler import TaskResultHandler
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.capabilities import (
|
||||
require_task_augmented_elicitation,
|
||||
require_task_augmented_sampling,
|
||||
)
|
||||
from mcp.shared.experimental.tasks.context import TaskContext
|
||||
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import (
|
||||
INVALID_REQUEST,
|
||||
TASK_STATUS_INPUT_REQUIRED,
|
||||
TASK_STATUS_WORKING,
|
||||
ClientCapabilities,
|
||||
CreateMessageResult,
|
||||
CreateTaskResult,
|
||||
ElicitationCapability,
|
||||
ElicitRequestedSchema,
|
||||
ElicitResult,
|
||||
ErrorData,
|
||||
IncludeContext,
|
||||
ModelPreferences,
|
||||
RequestId,
|
||||
Result,
|
||||
SamplingCapability,
|
||||
SamplingMessage,
|
||||
ServerNotification,
|
||||
Task,
|
||||
TaskMetadata,
|
||||
TaskStatusNotification,
|
||||
TaskStatusNotificationParams,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
|
||||
class ServerTaskContext:
|
||||
"""
|
||||
Server-integrated task context with elicitation and sampling.
|
||||
|
||||
This wraps a pure TaskContext and adds server-specific functionality:
|
||||
- elicit() for sending elicitation requests to the client
|
||||
- create_message() for sampling requests
|
||||
- Status notifications via the session
|
||||
|
||||
Example:
|
||||
async def my_task_work(task: ServerTaskContext) -> CallToolResult:
|
||||
await task.update_status("Starting...")
|
||||
|
||||
result = await task.elicit(
|
||||
message="Continue?",
|
||||
requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}}
|
||||
)
|
||||
|
||||
if result.content.get("ok"):
|
||||
return CallToolResult(content=[TextContent(text="Done!")])
|
||||
else:
|
||||
return CallToolResult(content=[TextContent(text="Cancelled")])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
task: Task,
|
||||
store: TaskStore,
|
||||
session: ServerSession,
|
||||
queue: TaskMessageQueue,
|
||||
handler: TaskResultHandler | None = None,
|
||||
):
|
||||
"""
|
||||
Create a ServerTaskContext.
|
||||
|
||||
Args:
|
||||
task: The Task object
|
||||
store: The task store
|
||||
session: The server session
|
||||
queue: The message queue for elicitation/sampling
|
||||
handler: The result handler for response routing (required for elicit/create_message)
|
||||
"""
|
||||
self._ctx = TaskContext(task=task, store=store)
|
||||
self._session = session
|
||||
self._queue = queue
|
||||
self._handler = handler
|
||||
self._store = store
|
||||
|
||||
# Delegate pure properties to inner context
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""The task identifier."""
|
||||
return self._ctx.task_id
|
||||
|
||||
@property
|
||||
def task(self) -> Task:
|
||||
"""The current task state."""
|
||||
return self._ctx.task
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Whether cancellation has been requested."""
|
||||
return self._ctx.is_cancelled
|
||||
|
||||
def request_cancellation(self) -> None:
|
||||
"""Request cancellation of this task."""
|
||||
self._ctx.request_cancellation()
|
||||
|
||||
# Enhanced methods with notifications
|
||||
|
||||
async def update_status(self, message: str, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Update the task's status message.
|
||||
|
||||
Args:
|
||||
message: The new status message
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.update_status(message)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def complete(self, result: Result, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Mark the task as completed with the given result.
|
||||
|
||||
Args:
|
||||
result: The task result
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.complete(result)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def fail(self, error: str, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Mark the task as failed with an error message.
|
||||
|
||||
Args:
|
||||
error: The error message
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.fail(error)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def _send_notification(self) -> None:
|
||||
"""Send a task status notification to the client."""
|
||||
task = self._ctx.task
|
||||
await self._session.send_notification(
|
||||
ServerNotification(
|
||||
TaskStatusNotification(
|
||||
params=TaskStatusNotificationParams(
|
||||
taskId=task.taskId,
|
||||
status=task.status,
|
||||
statusMessage=task.statusMessage,
|
||||
createdAt=task.createdAt,
|
||||
lastUpdatedAt=task.lastUpdatedAt,
|
||||
ttl=task.ttl,
|
||||
pollInterval=task.pollInterval,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Server-specific methods: elicitation and sampling
|
||||
|
||||
def _check_elicitation_capability(self) -> None:
|
||||
"""Check if the client supports elicitation."""
|
||||
if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support elicitation capability",
|
||||
)
|
||||
)
|
||||
|
||||
def _check_sampling_capability(self) -> None:
|
||||
"""Check if the client supports sampling."""
|
||||
if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support sampling capability",
|
||||
)
|
||||
)
|
||||
|
||||
async def elicit(
|
||||
self,
|
||||
message: str,
|
||||
requestedSchema: ElicitRequestedSchema,
|
||||
) -> ElicitResult:
|
||||
"""
|
||||
Send an elicitation request via the task message queue.
|
||||
|
||||
This method:
|
||||
1. Checks client capability
|
||||
2. Updates task status to "input_required"
|
||||
3. Queues the elicitation request
|
||||
4. Waits for the response (delivered via tasks/result round-trip)
|
||||
5. Updates task status back to "working"
|
||||
6. Returns the result
|
||||
|
||||
Args:
|
||||
message: The message to present to the user
|
||||
requestedSchema: Schema defining the expected response structure
|
||||
|
||||
Returns:
|
||||
The client's response
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support elicitation capability
|
||||
"""
|
||||
self._check_elicitation_capability()
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build the request using session's helper
|
||||
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
|
||||
message=message,
|
||||
requestedSchema=requestedSchema,
|
||||
related_task_id=self.task_id,
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for response (routed back via TaskResultHandler)
|
||||
response_data = await resolver.wait()
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return ElicitResult.model_validate(response_data)
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
# Coverage can't track async exception handlers reliably.
|
||||
# This path is tested in test_elicit_restores_status_on_cancellation
|
||||
# which verifies status is restored to "working" after cancellation.
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def elicit_url(
|
||||
self,
|
||||
message: str,
|
||||
url: str,
|
||||
elicitation_id: str,
|
||||
) -> ElicitResult:
|
||||
"""
|
||||
Send a URL mode elicitation request via the task message queue.
|
||||
|
||||
This directs the user to an external URL for out-of-band interactions
|
||||
like OAuth flows, credential collection, or payment processing.
|
||||
|
||||
This method:
|
||||
1. Checks client capability
|
||||
2. Updates task status to "input_required"
|
||||
3. Queues the elicitation request
|
||||
4. Waits for the response (delivered via tasks/result round-trip)
|
||||
5. Updates task status back to "working"
|
||||
6. Returns the result
|
||||
|
||||
Args:
|
||||
message: Human-readable explanation of why the interaction is needed
|
||||
url: The URL the user should navigate to
|
||||
elicitation_id: Unique identifier for tracking this elicitation
|
||||
|
||||
Returns:
|
||||
The client's response indicating acceptance, decline, or cancellation
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support elicitation capability
|
||||
RuntimeError: If handler is not configured
|
||||
"""
|
||||
self._check_elicitation_capability()
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build the request using session's helper
|
||||
request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage]
|
||||
message=message,
|
||||
url=url,
|
||||
elicitation_id=elicitation_id,
|
||||
related_task_id=self.task_id,
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for response (routed back via TaskResultHandler)
|
||||
response_data = await resolver.wait()
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return ElicitResult.model_validate(response_data)
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: ModelPreferences | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""
|
||||
Send a sampling request via the task message queue.
|
||||
|
||||
This method:
|
||||
1. Checks client capability
|
||||
2. Updates task status to "input_required"
|
||||
3. Queues the sampling request
|
||||
4. Waits for the response (delivered via tasks/result round-trip)
|
||||
5. Updates task status back to "working"
|
||||
6. Returns the result
|
||||
|
||||
Args:
|
||||
messages: The conversation messages for sampling
|
||||
max_tokens: Maximum tokens in the response
|
||||
system_prompt: Optional system prompt
|
||||
include_context: Context inclusion strategy
|
||||
temperature: Sampling temperature
|
||||
stop_sequences: Stop sequences
|
||||
metadata: Additional metadata
|
||||
model_preferences: Model selection preferences
|
||||
tools: Optional list of tools the LLM can use during sampling
|
||||
tool_choice: Optional control over tool usage behavior
|
||||
|
||||
Returns:
|
||||
The sampling result from the client
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support sampling capability or tools
|
||||
ValueError: If tool_use or tool_result message structure is invalid
|
||||
"""
|
||||
self._check_sampling_capability()
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
validate_sampling_tools(client_caps, tools, tool_choice)
|
||||
validate_tool_use_result_messages(messages)
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build the request using session's helper
|
||||
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
include_context=include_context,
|
||||
temperature=temperature,
|
||||
stop_sequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
model_preferences=model_preferences,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
related_task_id=self.task_id,
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for response (routed back via TaskResultHandler)
|
||||
response_data = await resolver.wait()
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return CreateMessageResult.model_validate(response_data)
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
# Coverage can't track async exception handlers reliably.
|
||||
# This path is tested in test_create_message_restores_status_on_cancellation
|
||||
# which verifies status is restored to "working" after cancellation.
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def elicit_as_task(
|
||||
self,
|
||||
message: str,
|
||||
requestedSchema: ElicitRequestedSchema,
|
||||
*,
|
||||
ttl: int = 60000,
|
||||
) -> ElicitResult:
|
||||
"""
|
||||
Send a task-augmented elicitation via the queue, then poll client.
|
||||
|
||||
This is for use inside a task-augmented tool call when you want the client
|
||||
to handle the elicitation as its own task. The elicitation request is queued
|
||||
and delivered when the client calls tasks/result. After the client responds
|
||||
with CreateTaskResult, we poll the client's task until complete.
|
||||
|
||||
Args:
|
||||
message: The message to present to the user
|
||||
requestedSchema: Schema defining the expected response structure
|
||||
ttl: Task time-to-live in milliseconds for the client's task
|
||||
|
||||
Returns:
|
||||
The client's elicitation response
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented elicitation
|
||||
RuntimeError: If handler is not configured
|
||||
"""
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
require_task_augmented_elicitation(client_caps)
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for elicit_as_task()")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
|
||||
message=message,
|
||||
requestedSchema=requestedSchema,
|
||||
related_task_id=self.task_id,
|
||||
task=TaskMetadata(ttl=ttl),
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for initial response (CreateTaskResult from client)
|
||||
response_data = await resolver.wait()
|
||||
create_result = CreateTaskResult.model_validate(response_data)
|
||||
client_task_id = create_result.task.taskId
|
||||
|
||||
# Poll the client's task using session.experimental
|
||||
async for _ in self._session.experimental.poll_task(client_task_id):
|
||||
pass
|
||||
|
||||
# Get final result from client
|
||||
result = await self._session.experimental.get_task_result(
|
||||
client_task_id,
|
||||
ElicitResult,
|
||||
)
|
||||
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return result
|
||||
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def create_message_as_task(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
ttl: int = 60000,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: ModelPreferences | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""
|
||||
Send a task-augmented sampling request via the queue, then poll client.
|
||||
|
||||
This is for use inside a task-augmented tool call when you want the client
|
||||
to handle the sampling as its own task. The request is queued and delivered
|
||||
when the client calls tasks/result. After the client responds with
|
||||
CreateTaskResult, we poll the client's task until complete.
|
||||
|
||||
Args:
|
||||
messages: The conversation messages for sampling
|
||||
max_tokens: Maximum tokens in the response
|
||||
ttl: Task time-to-live in milliseconds for the client's task
|
||||
system_prompt: Optional system prompt
|
||||
include_context: Context inclusion strategy
|
||||
temperature: Sampling temperature
|
||||
stop_sequences: Stop sequences
|
||||
metadata: Additional metadata
|
||||
model_preferences: Model selection preferences
|
||||
tools: Optional list of tools the LLM can use during sampling
|
||||
tool_choice: Optional control over tool usage behavior
|
||||
|
||||
Returns:
|
||||
The sampling result from the client
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented sampling or tools
|
||||
ValueError: If tool_use or tool_result message structure is invalid
|
||||
RuntimeError: If handler is not configured
|
||||
"""
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
require_task_augmented_sampling(client_caps)
|
||||
validate_sampling_tools(client_caps, tools, tool_choice)
|
||||
validate_tool_use_result_messages(messages)
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for create_message_as_task()")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build request WITH task field for task-augmented sampling
|
||||
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
include_context=include_context,
|
||||
temperature=temperature,
|
||||
stop_sequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
model_preferences=model_preferences,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
related_task_id=self.task_id,
|
||||
task=TaskMetadata(ttl=ttl),
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for initial response (CreateTaskResult from client)
|
||||
response_data = await resolver.wait()
|
||||
create_result = CreateTaskResult.model_validate(response_data)
|
||||
client_task_id = create_result.task.taskId
|
||||
|
||||
# Poll the client's task using session.experimental
|
||||
async for _ in self._session.experimental.poll_task(client_task_id):
|
||||
pass
|
||||
|
||||
# Get final result from client
|
||||
result = await self._session.experimental.get_task_result(
|
||||
client_task_id,
|
||||
CreateMessageResult,
|
||||
)
|
||||
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return result
|
||||
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
TaskResultHandler - Integrated handler for tasks/result endpoint.
|
||||
|
||||
This implements the dequeue-send-wait pattern from the MCP Tasks spec:
|
||||
1. Dequeue all pending messages for the task
|
||||
2. Send them to the client via transport with relatedRequestId routing
|
||||
3. Wait if task is not in terminal state
|
||||
4. Return final result when task completes
|
||||
|
||||
This is the core of the task message queue pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
|
||||
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
INVALID_PARAMS,
|
||||
ErrorData,
|
||||
GetTaskPayloadRequest,
|
||||
GetTaskPayloadResult,
|
||||
JSONRPCMessage,
|
||||
RelatedTaskMetadata,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskResultHandler:
|
||||
"""
|
||||
Handler for tasks/result that implements the message queue pattern.
|
||||
|
||||
This handler:
|
||||
1. Dequeues pending messages (elicitations, notifications) for the task
|
||||
2. Sends them to the client via the response stream
|
||||
3. Waits for responses and resolves them back to callers
|
||||
4. Blocks until task reaches terminal state
|
||||
5. Returns the final result
|
||||
|
||||
Usage:
|
||||
# Create handler with store and queue
|
||||
handler = TaskResultHandler(task_store, message_queue)
|
||||
|
||||
# Register it with the server
|
||||
@server.experimental.get_task_result()
|
||||
async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult:
|
||||
ctx = server.request_context
|
||||
return await handler.handle(req, ctx.session, ctx.request_id)
|
||||
|
||||
# Or use the convenience method
|
||||
handler.register(server)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: TaskStore,
|
||||
queue: TaskMessageQueue,
|
||||
):
|
||||
self._store = store
|
||||
self._queue = queue
|
||||
# Map from internal request ID to resolver for routing responses
|
||||
self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {}
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
session: ServerSession,
|
||||
message: SessionMessage,
|
||||
) -> None:
|
||||
"""
|
||||
Send a message via the session.
|
||||
|
||||
This is a helper for delivering queued task messages.
|
||||
"""
|
||||
await session.send_message(message)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
request: GetTaskPayloadRequest,
|
||||
session: ServerSession,
|
||||
request_id: RequestId,
|
||||
) -> GetTaskPayloadResult:
|
||||
"""
|
||||
Handle a tasks/result request.
|
||||
|
||||
This implements the dequeue-send-wait loop:
|
||||
1. Dequeue all pending messages
|
||||
2. Send each via transport with relatedRequestId = this request's ID
|
||||
3. If task not terminal, wait for status change
|
||||
4. Loop until task is terminal
|
||||
5. Return final result
|
||||
|
||||
Args:
|
||||
request: The GetTaskPayloadRequest
|
||||
session: The server session for sending messages
|
||||
request_id: The request ID for relatedRequestId routing
|
||||
|
||||
Returns:
|
||||
GetTaskPayloadResult with the task's final payload
|
||||
"""
|
||||
task_id = request.params.taskId
|
||||
|
||||
while True:
|
||||
task = await self._store.get_task(task_id)
|
||||
if task is None:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message=f"Task not found: {task_id}",
|
||||
)
|
||||
)
|
||||
|
||||
await self._deliver_queued_messages(task_id, session, request_id)
|
||||
|
||||
# If task is terminal, return result
|
||||
if is_terminal(task.status):
|
||||
result = await self._store.get_result(task_id)
|
||||
# GetTaskPayloadResult is a Result with extra="allow"
|
||||
# The stored result contains the actual payload data
|
||||
# Per spec: tasks/result MUST include _meta with related-task metadata
|
||||
related_task = RelatedTaskMetadata(taskId=task_id)
|
||||
related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)}
|
||||
if result is not None:
|
||||
result_data = result.model_dump(by_alias=True)
|
||||
existing_meta: dict[str, Any] = result_data.get("_meta") or {}
|
||||
result_data["_meta"] = {**existing_meta, **related_task_meta}
|
||||
return GetTaskPayloadResult.model_validate(result_data)
|
||||
return GetTaskPayloadResult.model_validate({"_meta": related_task_meta})
|
||||
|
||||
# Wait for task update (status change or new messages)
|
||||
await self._wait_for_task_update(task_id)
|
||||
|
||||
async def _deliver_queued_messages(
|
||||
self,
|
||||
task_id: str,
|
||||
session: ServerSession,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""
|
||||
Dequeue and send all pending messages for a task.
|
||||
|
||||
Each message is sent via the session's write stream with
|
||||
relatedRequestId set so responses route back to this stream.
|
||||
"""
|
||||
while True:
|
||||
message = await self._queue.dequeue(task_id)
|
||||
if message is None:
|
||||
break
|
||||
|
||||
# If this is a request (not notification), wait for response
|
||||
if message.type == "request" and message.resolver is not None:
|
||||
# Store the resolver so we can route the response back
|
||||
original_id = message.original_request_id
|
||||
if original_id is not None:
|
||||
self._pending_requests[original_id] = message.resolver
|
||||
|
||||
logger.debug("Delivering queued message for task %s: %s", task_id, message.type)
|
||||
|
||||
# Send the message with relatedRequestId for routing
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(message.message),
|
||||
metadata=ServerMessageMetadata(related_request_id=request_id),
|
||||
)
|
||||
await self.send_message(session, session_message)
|
||||
|
||||
async def _wait_for_task_update(self, task_id: str) -> None:
|
||||
"""
|
||||
Wait for task to be updated (status change or new message).
|
||||
|
||||
Races between store update and queue message - first one wins.
|
||||
"""
|
||||
async with anyio.create_task_group() as tg:
|
||||
|
||||
async def wait_for_store() -> None:
|
||||
try:
|
||||
await self._store.wait_for_update(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
async def wait_for_queue() -> None:
|
||||
try:
|
||||
await self._queue.wait_for_message(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
|
||||
tg.start_soon(wait_for_store)
|
||||
tg.start_soon(wait_for_queue)
|
||||
|
||||
def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Route a response back to the waiting resolver.
|
||||
|
||||
This is called when a response arrives for a queued request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID from the response
|
||||
response: The response data
|
||||
|
||||
Returns:
|
||||
True if response was routed, False if no pending request
|
||||
"""
|
||||
resolver = self._pending_requests.pop(request_id, None)
|
||||
if resolver is not None and not resolver.done():
|
||||
resolver.set_result(response)
|
||||
return True
|
||||
return False
|
||||
|
||||
def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
|
||||
"""
|
||||
Route an error back to the waiting resolver.
|
||||
|
||||
Args:
|
||||
request_id: The request ID from the error response
|
||||
error: The error data
|
||||
|
||||
Returns:
|
||||
True if error was routed, False if no pending request
|
||||
"""
|
||||
resolver = self._pending_requests.pop(request_id, None)
|
||||
if resolver is not None and not resolver.done():
|
||||
resolver.set_exception(McpError(error))
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
TaskSupport - Configuration for experimental task support.
|
||||
|
||||
This module provides the TaskSupport class which encapsulates all the
|
||||
infrastructure needed for task-augmented requests: store, queue, and handler.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
|
||||
from mcp.server.experimental.task_result_handler import TaskResultHandler
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
|
||||
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskSupport:
|
||||
"""
|
||||
Configuration for experimental task support.
|
||||
|
||||
Encapsulates the task store, message queue, result handler, and task group
|
||||
for spawning background work.
|
||||
|
||||
When enabled on a server, this automatically:
|
||||
- Configures response routing for each session
|
||||
- Provides default handlers for task operations
|
||||
- Manages a task group for background task execution
|
||||
|
||||
Example:
|
||||
# Simple in-memory setup
|
||||
server.experimental.enable_tasks()
|
||||
|
||||
# Custom store/queue for distributed systems
|
||||
server.experimental.enable_tasks(
|
||||
store=RedisTaskStore(redis_url),
|
||||
queue=RedisTaskMessageQueue(redis_url),
|
||||
)
|
||||
"""
|
||||
|
||||
store: TaskStore
|
||||
queue: TaskMessageQueue
|
||||
handler: TaskResultHandler = field(init=False)
|
||||
_task_group: TaskGroup | None = field(init=False, default=None)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create the result handler from store and queue."""
|
||||
self.handler = TaskResultHandler(self.store, self.queue)
|
||||
|
||||
@property
|
||||
def task_group(self) -> TaskGroup:
|
||||
"""Get the task group for spawning background work.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not within a run() context
|
||||
"""
|
||||
if self._task_group is None:
|
||||
raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.")
|
||||
return self._task_group
|
||||
|
||||
@asynccontextmanager
|
||||
async def run(self) -> AsyncIterator[None]:
|
||||
"""
|
||||
Run the task support lifecycle.
|
||||
|
||||
This creates a task group for spawning background task work.
|
||||
Called automatically by Server.run().
|
||||
|
||||
Usage:
|
||||
async with task_support.run():
|
||||
# Task group is now available
|
||||
...
|
||||
"""
|
||||
async with anyio.create_task_group() as tg:
|
||||
self._task_group = tg
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._task_group = None
|
||||
|
||||
def configure_session(self, session: ServerSession) -> None:
|
||||
"""
|
||||
Configure a session for task support.
|
||||
|
||||
This registers the result handler as a response router so that
|
||||
responses to queued requests (elicitation, sampling) are routed
|
||||
back to the waiting resolvers.
|
||||
|
||||
Called automatically by Server.run() for each new session.
|
||||
|
||||
Args:
|
||||
session: The session to configure
|
||||
"""
|
||||
session.add_response_router(self.handler)
|
||||
|
||||
@classmethod
|
||||
def in_memory(cls) -> "TaskSupport":
|
||||
"""
|
||||
Create in-memory task support.
|
||||
|
||||
Suitable for development, testing, and single-process servers.
|
||||
For distributed systems, provide custom store and queue implementations.
|
||||
|
||||
Returns:
|
||||
TaskSupport configured with in-memory store and queue
|
||||
"""
|
||||
return cls(
|
||||
store=InMemoryTaskStore(),
|
||||
queue=InMemoryTaskMessageQueue(),
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""FastMCP - A more ergonomic interface for MCP servers."""
|
||||
|
||||
from importlib.metadata import version
|
||||
|
||||
from mcp.types import Icon
|
||||
|
||||
from .server import Context, FastMCP
|
||||
from .utilities.types import Audio, Image
|
||||
|
||||
__version__ = version("mcp")
|
||||
__all__ = ["FastMCP", "Context", "Image", "Audio", "Icon"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
"""Custom exceptions for FastMCP."""
|
||||
|
||||
|
||||
class FastMCPError(Exception):
|
||||
"""Base error for FastMCP."""
|
||||
|
||||
|
||||
class ValidationError(FastMCPError):
|
||||
"""Error in validating parameters or return values."""
|
||||
|
||||
|
||||
class ResourceError(FastMCPError):
|
||||
"""Error in resource operations."""
|
||||
|
||||
|
||||
class ToolError(FastMCPError):
|
||||
"""Error in tool operations."""
|
||||
|
||||
|
||||
class InvalidSignature(Exception):
|
||||
"""Invalid signature for use with FastMCP."""
|
||||
@@ -0,0 +1,4 @@
|
||||
from .base import Prompt
|
||||
from .manager import PromptManager
|
||||
|
||||
__all__ = ["Prompt", "PromptManager"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,183 @@
|
||||
"""Base classes for FastMCP prompts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import pydantic_core
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
|
||||
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
|
||||
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
|
||||
from mcp.types import ContentBlock, Icon, TextContent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base class for all prompt messages."""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: ContentBlock
|
||||
|
||||
def __init__(self, content: str | ContentBlock, **kwargs: Any):
|
||||
if isinstance(content, str):
|
||||
content = TextContent(type="text", text=content)
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
"""A message from the user."""
|
||||
|
||||
role: Literal["user", "assistant"] = "user"
|
||||
|
||||
def __init__(self, content: str | ContentBlock, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class AssistantMessage(Message):
|
||||
"""A message from the assistant."""
|
||||
|
||||
role: Literal["user", "assistant"] = "assistant"
|
||||
|
||||
def __init__(self, content: str | ContentBlock, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)
|
||||
|
||||
SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
|
||||
|
||||
|
||||
class PromptArgument(BaseModel):
|
||||
"""An argument that can be passed to a prompt."""
|
||||
|
||||
name: str = Field(description="Name of the argument")
|
||||
description: str | None = Field(None, description="Description of what the argument does")
|
||||
required: bool = Field(default=False, description="Whether the argument is required")
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt template that can be rendered with parameters."""
|
||||
|
||||
name: str = Field(description="Name of the prompt")
|
||||
title: str | None = Field(None, description="Human-readable title of the prompt")
|
||||
description: str | None = Field(None, description="Description of what the prompt does")
|
||||
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
|
||||
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
|
||||
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
|
||||
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
icons: list[Icon] | None = None,
|
||||
context_kwarg: str | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a Prompt from a function.
|
||||
|
||||
The function can return:
|
||||
- A string (converted to a message)
|
||||
- A Message object
|
||||
- A dict (converted to a message)
|
||||
- A sequence of any of the above
|
||||
"""
|
||||
func_name = name or fn.__name__
|
||||
|
||||
if func_name == "<lambda>": # pragma: no cover
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
# Find context parameter if it exists
|
||||
if context_kwarg is None: # pragma: no branch
|
||||
context_kwarg = find_context_parameter(fn)
|
||||
|
||||
# Get schema from func_metadata, excluding context parameter
|
||||
func_arg_metadata = func_metadata(
|
||||
fn,
|
||||
skip_names=[context_kwarg] if context_kwarg is not None else [],
|
||||
)
|
||||
parameters = func_arg_metadata.arg_model.model_json_schema()
|
||||
|
||||
# Convert parameters to PromptArguments
|
||||
arguments: list[PromptArgument] = []
|
||||
if "properties" in parameters: # pragma: no branch
|
||||
for param_name, param in parameters["properties"].items():
|
||||
required = param_name in parameters.get("required", [])
|
||||
arguments.append(
|
||||
PromptArgument(
|
||||
name=param_name,
|
||||
description=param.get("description"),
|
||||
required=required,
|
||||
)
|
||||
)
|
||||
|
||||
# ensure the arguments are properly cast
|
||||
fn = validate_call(fn)
|
||||
|
||||
return cls(
|
||||
name=func_name,
|
||||
title=title,
|
||||
description=description or fn.__doc__ or "",
|
||||
arguments=arguments,
|
||||
fn=fn,
|
||||
icons=icons,
|
||||
context_kwarg=context_kwarg,
|
||||
)
|
||||
|
||||
async def render(
|
||||
self,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Render the prompt with arguments."""
|
||||
# Validate required arguments
|
||||
if self.arguments:
|
||||
required = {arg.name for arg in self.arguments if arg.required}
|
||||
provided = set(arguments or {})
|
||||
missing = required - provided
|
||||
if missing:
|
||||
raise ValueError(f"Missing required arguments: {missing}")
|
||||
|
||||
try:
|
||||
# Add context to arguments if needed
|
||||
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
|
||||
|
||||
# Call function and check if result is a coroutine
|
||||
result = self.fn(**call_args)
|
||||
if inspect.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
# Validate messages
|
||||
if not isinstance(result, list | tuple):
|
||||
result = [result]
|
||||
|
||||
# Convert result to messages
|
||||
messages: list[Message] = []
|
||||
for msg in result: # type: ignore[reportUnknownVariableType]
|
||||
try:
|
||||
if isinstance(msg, Message):
|
||||
messages.append(msg)
|
||||
elif isinstance(msg, dict):
|
||||
messages.append(message_validator.validate_python(msg))
|
||||
elif isinstance(msg, str):
|
||||
content = TextContent(type="text", text=msg)
|
||||
messages.append(UserMessage(content=content))
|
||||
else: # pragma: no cover
|
||||
content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
|
||||
messages.append(Message(role="user", content=content))
|
||||
except Exception: # pragma: no cover
|
||||
raise ValueError(f"Could not convert prompt result to message: {msg}")
|
||||
|
||||
return messages
|
||||
except Exception as e: # pragma: no cover
|
||||
raise ValueError(f"Error rendering prompt {self.name}: {e}")
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Prompt management functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from mcp.server.fastmcp.prompts.base import Message, Prompt
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Manages FastMCP prompts."""
|
||||
|
||||
def __init__(self, warn_on_duplicate_prompts: bool = True):
|
||||
self._prompts: dict[str, Prompt] = {}
|
||||
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
|
||||
|
||||
def get_prompt(self, name: str) -> Prompt | None:
|
||||
"""Get prompt by name."""
|
||||
return self._prompts.get(name)
|
||||
|
||||
def list_prompts(self) -> list[Prompt]:
|
||||
"""List all registered prompts."""
|
||||
return list(self._prompts.values())
|
||||
|
||||
def add_prompt(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
) -> Prompt:
|
||||
"""Add a prompt to the manager."""
|
||||
|
||||
# Check for duplicates
|
||||
existing = self._prompts.get(prompt.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_prompts:
|
||||
logger.warning(f"Prompt already exists: {prompt.name}")
|
||||
return existing
|
||||
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def render_prompt(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Render a prompt by name with arguments."""
|
||||
prompt = self.get_prompt(name)
|
||||
if not prompt:
|
||||
raise ValueError(f"Unknown prompt: {name}")
|
||||
|
||||
return await prompt.render(arguments, context=context)
|
||||
@@ -0,0 +1,23 @@
|
||||
from .base import Resource
|
||||
from .resource_manager import ResourceManager
|
||||
from .templates import ResourceTemplate
|
||||
from .types import (
|
||||
BinaryResource,
|
||||
DirectoryResource,
|
||||
FileResource,
|
||||
FunctionResource,
|
||||
HttpResource,
|
||||
TextResource,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Resource",
|
||||
"TextResource",
|
||||
"BinaryResource",
|
||||
"FunctionResource",
|
||||
"FileResource",
|
||||
"HttpResource",
|
||||
"DirectoryResource",
|
||||
"ResourceTemplate",
|
||||
"ResourceManager",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,49 @@
|
||||
"""Base classes and interfaces for FastMCP resources."""
|
||||
|
||||
import abc
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import (
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
UrlConstraints,
|
||||
ValidationInfo,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from mcp.types import Annotations, Icon
|
||||
|
||||
|
||||
class Resource(BaseModel, abc.ABC):
|
||||
"""Base class for all resources."""
|
||||
|
||||
model_config = ConfigDict(validate_default=True)
|
||||
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
|
||||
name: str | None = Field(description="Name of the resource", default=None)
|
||||
title: str | None = Field(description="Human-readable title of the resource", default=None)
|
||||
description: str | None = Field(description="Description of the resource", default=None)
|
||||
mime_type: str = Field(
|
||||
default="text/plain",
|
||||
description="MIME type of the resource content",
|
||||
pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+(;\s*[a-zA-Z0-9\-_.]+=[a-zA-Z0-9\-_.]+)*$",
|
||||
)
|
||||
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this resource")
|
||||
annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource")
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
|
||||
"""Set default name from URI if not provided."""
|
||||
if name:
|
||||
return name
|
||||
if uri := info.data.get("uri"):
|
||||
return str(uri)
|
||||
raise ValueError("Either name or uri must be provided")
|
||||
|
||||
@abc.abstractmethod
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the resource content."""
|
||||
pass # pragma: no cover
|
||||
@@ -0,0 +1,113 @@
|
||||
"""Resource manager functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp.server.fastmcp.resources.base import Resource
|
||||
from mcp.server.fastmcp.resources.templates import ResourceTemplate
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
from mcp.types import Annotations, Icon
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResourceManager:
|
||||
"""Manages FastMCP resources."""
|
||||
|
||||
def __init__(self, warn_on_duplicate_resources: bool = True):
|
||||
self._resources: dict[str, Resource] = {}
|
||||
self._templates: dict[str, ResourceTemplate] = {}
|
||||
self.warn_on_duplicate_resources = warn_on_duplicate_resources
|
||||
|
||||
def add_resource(self, resource: Resource) -> Resource:
|
||||
"""Add a resource to the manager.
|
||||
|
||||
Args:
|
||||
resource: A Resource instance to add
|
||||
|
||||
Returns:
|
||||
The added resource. If a resource with the same URI already exists,
|
||||
returns the existing resource.
|
||||
"""
|
||||
logger.debug(
|
||||
"Adding resource",
|
||||
extra={
|
||||
"uri": resource.uri,
|
||||
"type": type(resource).__name__,
|
||||
"resource_name": resource.name,
|
||||
},
|
||||
)
|
||||
existing = self._resources.get(str(resource.uri))
|
||||
if existing:
|
||||
if self.warn_on_duplicate_resources:
|
||||
logger.warning(f"Resource already exists: {resource.uri}")
|
||||
return existing
|
||||
self._resources[str(resource.uri)] = resource
|
||||
return resource
|
||||
|
||||
def add_template(
|
||||
self,
|
||||
fn: Callable[..., Any],
|
||||
uri_template: str,
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
icons: list[Icon] | None = None,
|
||||
annotations: Annotations | None = None,
|
||||
) -> ResourceTemplate:
|
||||
"""Add a template from a function."""
|
||||
template = ResourceTemplate.from_function(
|
||||
fn,
|
||||
uri_template=uri_template,
|
||||
name=name,
|
||||
title=title,
|
||||
description=description,
|
||||
mime_type=mime_type,
|
||||
icons=icons,
|
||||
annotations=annotations,
|
||||
)
|
||||
self._templates[template.uri_template] = template
|
||||
return template
|
||||
|
||||
async def get_resource(
|
||||
self,
|
||||
uri: AnyUrl | str,
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
) -> Resource | None:
|
||||
"""Get resource by URI, checking concrete resources first, then templates."""
|
||||
uri_str = str(uri)
|
||||
logger.debug("Getting resource", extra={"uri": uri_str})
|
||||
|
||||
# First check concrete resources
|
||||
if resource := self._resources.get(uri_str):
|
||||
return resource
|
||||
|
||||
# Then check templates
|
||||
for template in self._templates.values():
|
||||
if params := template.matches(uri_str):
|
||||
try:
|
||||
return await template.create_resource(uri_str, params, context=context)
|
||||
except Exception as e: # pragma: no cover
|
||||
raise ValueError(f"Error creating resource from template: {e}")
|
||||
|
||||
raise ValueError(f"Unknown resource: {uri}")
|
||||
|
||||
def list_resources(self) -> list[Resource]:
|
||||
"""List all registered resources."""
|
||||
logger.debug("Listing resources", extra={"count": len(self._resources)})
|
||||
return list(self._resources.values())
|
||||
|
||||
def list_templates(self) -> list[ResourceTemplate]:
|
||||
"""List all registered templates."""
|
||||
logger.debug("Listing templates", extra={"count": len(self._templates)})
|
||||
return list(self._templates.values())
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Resource template functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field, validate_call
|
||||
|
||||
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
|
||||
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
|
||||
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
|
||||
from mcp.types import Annotations, Icon
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
|
||||
|
||||
class ResourceTemplate(BaseModel):
|
||||
"""A template for dynamically creating resources."""
|
||||
|
||||
uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)")
|
||||
name: str = Field(description="Name of the resource")
|
||||
title: str | None = Field(description="Human-readable title of the resource", default=None)
|
||||
description: str | None = Field(description="Description of what the resource does")
|
||||
mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
|
||||
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for the resource template")
|
||||
annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource template")
|
||||
fn: Callable[..., Any] = Field(exclude=True)
|
||||
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
|
||||
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable[..., Any],
|
||||
uri_template: str,
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
icons: list[Icon] | None = None,
|
||||
annotations: Annotations | None = None,
|
||||
context_kwarg: str | None = None,
|
||||
) -> ResourceTemplate:
|
||||
"""Create a template from a function."""
|
||||
func_name = name or fn.__name__
|
||||
if func_name == "<lambda>":
|
||||
raise ValueError("You must provide a name for lambda functions") # pragma: no cover
|
||||
|
||||
# Find context parameter if it exists
|
||||
if context_kwarg is None: # pragma: no branch
|
||||
context_kwarg = find_context_parameter(fn)
|
||||
|
||||
# Get schema from func_metadata, excluding context parameter
|
||||
func_arg_metadata = func_metadata(
|
||||
fn,
|
||||
skip_names=[context_kwarg] if context_kwarg is not None else [],
|
||||
)
|
||||
parameters = func_arg_metadata.arg_model.model_json_schema()
|
||||
|
||||
# ensure the arguments are properly cast
|
||||
fn = validate_call(fn)
|
||||
|
||||
return cls(
|
||||
uri_template=uri_template,
|
||||
name=func_name,
|
||||
title=title,
|
||||
description=description or fn.__doc__ or "",
|
||||
mime_type=mime_type or "text/plain",
|
||||
icons=icons,
|
||||
annotations=annotations,
|
||||
fn=fn,
|
||||
parameters=parameters,
|
||||
context_kwarg=context_kwarg,
|
||||
)
|
||||
|
||||
def matches(self, uri: str) -> dict[str, Any] | None:
|
||||
"""Check if URI matches template and extract parameters."""
|
||||
# Convert template to regex pattern
|
||||
pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
|
||||
match = re.match(f"^{pattern}$", uri)
|
||||
if match:
|
||||
return match.groupdict()
|
||||
return None
|
||||
|
||||
async def create_resource(
|
||||
self,
|
||||
uri: str,
|
||||
params: dict[str, Any],
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
) -> Resource:
|
||||
"""Create a resource from the template with the given parameters."""
|
||||
try:
|
||||
# Add context to params if needed
|
||||
params = inject_context(self.fn, params, context, self.context_kwarg)
|
||||
|
||||
# Call function and check if result is a coroutine
|
||||
result = self.fn(**params)
|
||||
if inspect.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
return FunctionResource(
|
||||
uri=uri, # type: ignore
|
||||
name=self.name,
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
mime_type=self.mime_type,
|
||||
icons=self.icons,
|
||||
annotations=self.annotations,
|
||||
fn=lambda: result, # Capture result in closure
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating resource from template: {e}")
|
||||
@@ -0,0 +1,201 @@
|
||||
"""Concrete resource implementations."""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
import httpx
|
||||
import pydantic
|
||||
import pydantic_core
|
||||
from pydantic import AnyUrl, Field, ValidationInfo, validate_call
|
||||
|
||||
from mcp.server.fastmcp.resources.base import Resource
|
||||
from mcp.types import Annotations, Icon
|
||||
|
||||
|
||||
class TextResource(Resource):
|
||||
"""A resource that reads from a string."""
|
||||
|
||||
text: str = Field(description="Text content of the resource")
|
||||
|
||||
async def read(self) -> str:
|
||||
"""Read the text content."""
|
||||
return self.text # pragma: no cover
|
||||
|
||||
|
||||
class BinaryResource(Resource):
|
||||
"""A resource that reads from bytes."""
|
||||
|
||||
data: bytes = Field(description="Binary content of the resource")
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Read the binary content."""
|
||||
return self.data # pragma: no cover
|
||||
|
||||
|
||||
class FunctionResource(Resource):
|
||||
"""A resource that defers data loading by wrapping a function.
|
||||
|
||||
The function is only called when the resource is read, allowing for lazy loading
|
||||
of potentially expensive data. This is particularly useful when listing resources,
|
||||
as the function won't be called until the resource is actually accessed.
|
||||
|
||||
The function can return:
|
||||
- str for text content (default)
|
||||
- bytes for binary content
|
||||
- other types will be converted to JSON
|
||||
"""
|
||||
|
||||
fn: Callable[[], Any] = Field(exclude=True)
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the resource by calling the wrapped function."""
|
||||
try:
|
||||
# Call the function first to see if it returns a coroutine
|
||||
result = self.fn()
|
||||
# If it's a coroutine, await it
|
||||
if inspect.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, Resource): # pragma: no cover
|
||||
return await result.read()
|
||||
elif isinstance(result, bytes):
|
||||
return result
|
||||
elif isinstance(result, str):
|
||||
return result
|
||||
else:
|
||||
return pydantic_core.to_json(result, fallback=str, indent=2).decode()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading resource {self.uri}: {e}")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable[..., Any],
|
||||
uri: str,
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
icons: list[Icon] | None = None,
|
||||
annotations: Annotations | None = None,
|
||||
) -> "FunctionResource":
|
||||
"""Create a FunctionResource from a function."""
|
||||
func_name = name or fn.__name__
|
||||
if func_name == "<lambda>": # pragma: no cover
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
# ensure the arguments are properly cast
|
||||
fn = validate_call(fn)
|
||||
|
||||
return cls(
|
||||
uri=AnyUrl(uri),
|
||||
name=func_name,
|
||||
title=title,
|
||||
description=description or fn.__doc__ or "",
|
||||
mime_type=mime_type or "text/plain",
|
||||
fn=fn,
|
||||
icons=icons,
|
||||
annotations=annotations,
|
||||
)
|
||||
|
||||
|
||||
class FileResource(Resource):
|
||||
"""A resource that reads from a file.
|
||||
|
||||
Set is_binary=True to read file as binary data instead of text.
|
||||
"""
|
||||
|
||||
path: Path = Field(description="Path to the file")
|
||||
is_binary: bool = Field(
|
||||
default=False,
|
||||
description="Whether to read the file as binary data",
|
||||
)
|
||||
mime_type: str = Field(
|
||||
default="text/plain",
|
||||
description="MIME type of the resource content",
|
||||
)
|
||||
|
||||
@pydantic.field_validator("path")
|
||||
@classmethod
|
||||
def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover
|
||||
"""Ensure path is absolute."""
|
||||
if not path.is_absolute():
|
||||
raise ValueError("Path must be absolute")
|
||||
return path
|
||||
|
||||
@pydantic.field_validator("is_binary")
|
||||
@classmethod
|
||||
def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool:
|
||||
"""Set is_binary based on mime_type if not explicitly set."""
|
||||
if is_binary:
|
||||
return True
|
||||
mime_type = info.data.get("mime_type", "text/plain")
|
||||
return not mime_type.startswith("text/")
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the file content."""
|
||||
try:
|
||||
if self.is_binary:
|
||||
return await anyio.to_thread.run_sync(self.path.read_bytes)
|
||||
return await anyio.to_thread.run_sync(self.path.read_text)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading file {self.path}: {e}")
|
||||
|
||||
|
||||
class HttpResource(Resource):
|
||||
"""A resource that reads from an HTTP endpoint."""
|
||||
|
||||
url: str = Field(description="URL to fetch content from")
|
||||
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the HTTP content."""
|
||||
async with httpx.AsyncClient() as client: # pragma: no cover
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
|
||||
class DirectoryResource(Resource):
|
||||
"""A resource that lists files in a directory."""
|
||||
|
||||
path: Path = Field(description="Path to the directory")
|
||||
recursive: bool = Field(default=False, description="Whether to list files recursively")
|
||||
pattern: str | None = Field(default=None, description="Optional glob pattern to filter files")
|
||||
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
|
||||
|
||||
@pydantic.field_validator("path")
|
||||
@classmethod
|
||||
def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover
|
||||
"""Ensure path is absolute."""
|
||||
if not path.is_absolute():
|
||||
raise ValueError("Path must be absolute")
|
||||
return path
|
||||
|
||||
def list_files(self) -> list[Path]: # pragma: no cover
|
||||
"""List files in the directory."""
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"Directory not found: {self.path}")
|
||||
if not self.path.is_dir():
|
||||
raise NotADirectoryError(f"Not a directory: {self.path}")
|
||||
|
||||
try:
|
||||
if self.pattern:
|
||||
return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern))
|
||||
return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*"))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error listing directory {self.path}: {e}")
|
||||
|
||||
async def read(self) -> str: # Always returns JSON string # pragma: no cover
|
||||
"""Read the directory listing."""
|
||||
try:
|
||||
files = await anyio.to_thread.run_sync(self.list_files)
|
||||
file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()]
|
||||
return json.dumps({"files": file_list}, indent=2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading directory {self.path}: {e}")
|
||||
1343
.venv/lib/python3.11/site-packages/mcp/server/fastmcp/server.py
Normal file
1343
.venv/lib/python3.11/site-packages/mcp/server/fastmcp/server.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,4 @@
|
||||
from .base import Tool
|
||||
from .tool_manager import ToolManager
|
||||
|
||||
__all__ = ["Tool", "ToolManager"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
|
||||
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
|
||||
from mcp.shared.exceptions import UrlElicitationRequiredError
|
||||
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
|
||||
from mcp.types import Icon, ToolAnnotations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""Internal tool registration info."""
|
||||
|
||||
fn: Callable[..., Any] = Field(exclude=True)
|
||||
name: str = Field(description="Name of the tool")
|
||||
title: str | None = Field(None, description="Human-readable title of the tool")
|
||||
description: str = Field(description="Description of what the tool does")
|
||||
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
|
||||
fn_metadata: FuncMetadata = Field(
|
||||
description="Metadata about the function including a pydantic model for tool arguments"
|
||||
)
|
||||
is_async: bool = Field(description="Whether the tool is async")
|
||||
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
|
||||
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
|
||||
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool")
|
||||
meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool")
|
||||
|
||||
@cached_property
|
||||
def output_schema(self) -> dict[str, Any] | None:
|
||||
return self.fn_metadata.output_schema
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable[..., Any],
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
context_kwarg: str | None = None,
|
||||
annotations: ToolAnnotations | None = None,
|
||||
icons: list[Icon] | None = None,
|
||||
meta: dict[str, Any] | None = None,
|
||||
structured_output: bool | None = None,
|
||||
) -> Tool:
|
||||
"""Create a Tool from a function."""
|
||||
func_name = name or fn.__name__
|
||||
|
||||
validate_and_warn_tool_name(func_name)
|
||||
|
||||
if func_name == "<lambda>":
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
func_doc = description or fn.__doc__ or ""
|
||||
is_async = _is_async_callable(fn)
|
||||
|
||||
if context_kwarg is None: # pragma: no branch
|
||||
context_kwarg = find_context_parameter(fn)
|
||||
|
||||
func_arg_metadata = func_metadata(
|
||||
fn,
|
||||
skip_names=[context_kwarg] if context_kwarg is not None else [],
|
||||
structured_output=structured_output,
|
||||
)
|
||||
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
|
||||
|
||||
return cls(
|
||||
fn=fn,
|
||||
name=func_name,
|
||||
title=title,
|
||||
description=func_doc,
|
||||
parameters=parameters,
|
||||
fn_metadata=func_arg_metadata,
|
||||
is_async=is_async,
|
||||
context_kwarg=context_kwarg,
|
||||
annotations=annotations,
|
||||
icons=icons,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
arguments: dict[str, Any],
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
convert_result: bool = False,
|
||||
) -> Any:
|
||||
"""Run the tool with arguments."""
|
||||
try:
|
||||
result = await self.fn_metadata.call_fn_with_arg_validation(
|
||||
self.fn,
|
||||
self.is_async,
|
||||
arguments,
|
||||
{self.context_kwarg: context} if self.context_kwarg is not None else None,
|
||||
)
|
||||
|
||||
if convert_result:
|
||||
result = self.fn_metadata.convert_result(result)
|
||||
|
||||
return result
|
||||
except UrlElicitationRequiredError:
|
||||
# Re-raise UrlElicitationRequiredError so it can be properly handled
|
||||
# as an MCP error response with code -32042
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
||||
|
||||
|
||||
def _is_async_callable(obj: Any) -> bool:
|
||||
while isinstance(obj, functools.partial): # pragma: no cover
|
||||
obj = obj.func
|
||||
|
||||
return inspect.iscoroutinefunction(obj) or (
|
||||
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
|
||||
)
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.tools.base import Tool
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
from mcp.shared.context import LifespanContextT, RequestT
|
||||
from mcp.types import Icon, ToolAnnotations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
from mcp.server.session import ServerSessionT
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manages FastMCP tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warn_on_duplicate_tools: bool = True,
|
||||
*,
|
||||
tools: list[Tool] | None = None,
|
||||
):
|
||||
self._tools: dict[str, Tool] = {}
|
||||
if tools is not None:
|
||||
for tool in tools:
|
||||
if warn_on_duplicate_tools and tool.name in self._tools:
|
||||
logger.warning(f"Tool already exists: {tool.name}")
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
self.warn_on_duplicate_tools = warn_on_duplicate_tools
|
||||
|
||||
def get_tool(self, name: str) -> Tool | None:
|
||||
"""Get tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""List all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def add_tool(
|
||||
self,
|
||||
fn: Callable[..., Any],
|
||||
name: str | None = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
annotations: ToolAnnotations | None = None,
|
||||
icons: list[Icon] | None = None,
|
||||
meta: dict[str, Any] | None = None,
|
||||
structured_output: bool | None = None,
|
||||
) -> Tool:
|
||||
"""Add a tool to the server."""
|
||||
tool = Tool.from_function(
|
||||
fn,
|
||||
name=name,
|
||||
title=title,
|
||||
description=description,
|
||||
annotations=annotations,
|
||||
icons=icons,
|
||||
meta=meta,
|
||||
structured_output=structured_output,
|
||||
)
|
||||
existing = self._tools.get(tool.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_tools:
|
||||
logger.warning(f"Tool already exists: {tool.name}")
|
||||
return existing
|
||||
self._tools[tool.name] = tool
|
||||
return tool
|
||||
|
||||
def remove_tool(self, name: str) -> None:
|
||||
"""Remove a tool by name."""
|
||||
if name not in self._tools:
|
||||
raise ToolError(f"Unknown tool: {name}")
|
||||
del self._tools[name]
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
|
||||
convert_result: bool = False,
|
||||
) -> Any:
|
||||
"""Call a tool by name with arguments."""
|
||||
tool = self.get_tool(name)
|
||||
if not tool:
|
||||
raise ToolError(f"Unknown tool: {name}")
|
||||
|
||||
return await tool.run(arguments, context=context, convert_result=convert_result)
|
||||
@@ -0,0 +1 @@
|
||||
"""FastMCP utility modules."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,68 @@
|
||||
"""Context injection utilities for FastMCP."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
def find_context_parameter(fn: Callable[..., Any]) -> str | None:
|
||||
"""Find the parameter that should receive the Context object.
|
||||
|
||||
Searches through the function's signature to find a parameter
|
||||
with a Context type annotation.
|
||||
|
||||
Args:
|
||||
fn: The function to inspect
|
||||
|
||||
Returns:
|
||||
The name of the context parameter, or None if not found
|
||||
"""
|
||||
from mcp.server.fastmcp.server import Context
|
||||
|
||||
# Get type hints to properly resolve string annotations
|
||||
try:
|
||||
hints = typing.get_type_hints(fn)
|
||||
except Exception:
|
||||
# If we can't resolve type hints, we can't find the context parameter
|
||||
return None
|
||||
|
||||
# Check each parameter's type hint
|
||||
for param_name, annotation in hints.items():
|
||||
# Handle direct Context type
|
||||
if inspect.isclass(annotation) and issubclass(annotation, Context):
|
||||
return param_name
|
||||
|
||||
# Handle generic types like Optional[Context]
|
||||
origin = typing.get_origin(annotation)
|
||||
if origin is not None:
|
||||
args = typing.get_args(annotation)
|
||||
for arg in args:
|
||||
if inspect.isclass(arg) and issubclass(arg, Context):
|
||||
return param_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def inject_context(
|
||||
fn: Callable[..., Any],
|
||||
kwargs: dict[str, Any],
|
||||
context: Any | None,
|
||||
context_kwarg: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Inject context into function kwargs if needed.
|
||||
|
||||
Args:
|
||||
fn: The function that will be called
|
||||
kwargs: The current keyword arguments
|
||||
context: The context object to inject (if any)
|
||||
context_kwarg: The name of the parameter to inject into
|
||||
|
||||
Returns:
|
||||
Updated kwargs with context injected if applicable
|
||||
"""
|
||||
if context_kwarg is not None and context is not None:
|
||||
return {**kwargs, context_kwarg: context}
|
||||
return kwargs
|
||||
@@ -0,0 +1,533 @@
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from itertools import chain
|
||||
from types import GenericAlias
|
||||
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
import pydantic_core
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
WithJsonSchema,
|
||||
create_model,
|
||||
)
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind
|
||||
from typing_extensions import is_typeddict
|
||||
from typing_inspection.introspection import (
|
||||
UNKNOWN,
|
||||
AnnotationSource,
|
||||
ForbiddenQualifier,
|
||||
inspect_annotation,
|
||||
is_union_origin,
|
||||
)
|
||||
|
||||
from mcp.server.fastmcp.exceptions import InvalidSignature
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
from mcp.server.fastmcp.utilities.types import Audio, Image
|
||||
from mcp.types import CallToolResult, ContentBlock, TextContent
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StrictJsonSchema(GenerateJsonSchema):
|
||||
"""A JSON schema generator that raises exceptions instead of emitting warnings.
|
||||
|
||||
This is used to detect non-serializable types during schema generation.
|
||||
"""
|
||||
|
||||
def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None:
|
||||
# Raise an exception instead of emitting a warning
|
||||
raise ValueError(f"JSON schema warning: {kind} - {detail}")
|
||||
|
||||
|
||||
class ArgModelBase(BaseModel):
|
||||
"""A model representing the arguments to a function."""
|
||||
|
||||
def model_dump_one_level(self) -> dict[str, Any]:
|
||||
"""Return a dict of the model's fields, one level deep.
|
||||
|
||||
That is, sub-models etc are not dumped - they are kept as pydantic models.
|
||||
"""
|
||||
kwargs: dict[str, Any] = {}
|
||||
for field_name, field_info in self.__class__.model_fields.items():
|
||||
value = getattr(self, field_name)
|
||||
# Use the alias if it exists, otherwise use the field name
|
||||
output_name = field_info.alias if field_info.alias else field_name
|
||||
kwargs[output_name] = value
|
||||
return kwargs
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
class FuncMetadata(BaseModel):
|
||||
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
|
||||
output_schema: dict[str, Any] | None = None
|
||||
output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
|
||||
wrap_output: bool = False
|
||||
|
||||
async def call_fn_with_arg_validation(
|
||||
self,
|
||||
fn: Callable[..., Any | Awaitable[Any]],
|
||||
fn_is_async: bool,
|
||||
arguments_to_validate: dict[str, Any],
|
||||
arguments_to_pass_directly: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Call the given function with arguments validated and injected.
|
||||
|
||||
Arguments are first attempted to be parsed from JSON, then validated against
|
||||
the argument model, before being passed to the function.
|
||||
"""
|
||||
arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
|
||||
arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
|
||||
arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()
|
||||
|
||||
arguments_parsed_dict |= arguments_to_pass_directly or {}
|
||||
|
||||
if fn_is_async:
|
||||
return await fn(**arguments_parsed_dict)
|
||||
else:
|
||||
return fn(**arguments_parsed_dict)
|
||||
|
||||
def convert_result(self, result: Any) -> Any:
|
||||
"""
|
||||
Convert the result of a function call to the appropriate format for
|
||||
the lowlevel server tool call handler:
|
||||
|
||||
- If output_model is None, return the unstructured content directly.
|
||||
- If output_model is not None, convert the result to structured output format
|
||||
(dict[str, Any]) and return both unstructured and structured content.
|
||||
|
||||
Note: we return unstructured content here **even though the lowlevel server
|
||||
tool call handler provides generic backwards compatibility serialization of
|
||||
structured content**. This is for FastMCP backwards compatibility: we need to
|
||||
retain FastMCP's ad hoc conversion logic for constructing unstructured output
|
||||
from function return values, whereas the lowlevel server simply serializes
|
||||
the structured output.
|
||||
"""
|
||||
if isinstance(result, CallToolResult):
|
||||
if self.output_schema is not None:
|
||||
assert self.output_model is not None, "Output model must be set if output schema is defined"
|
||||
self.output_model.model_validate(result.structuredContent)
|
||||
return result
|
||||
|
||||
unstructured_content = _convert_to_content(result)
|
||||
|
||||
if self.output_schema is None:
|
||||
return unstructured_content
|
||||
else:
|
||||
if self.wrap_output:
|
||||
result = {"result": result}
|
||||
|
||||
assert self.output_model is not None, "Output model must be set if output schema is defined"
|
||||
validated = self.output_model.model_validate(result)
|
||||
structured_content = validated.model_dump(mode="json", by_alias=True)
|
||||
|
||||
return (unstructured_content, structured_content)
|
||||
|
||||
def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Pre-parse data from JSON.
|
||||
|
||||
Return a dict with same keys as input but with values parsed from JSON
|
||||
if appropriate.
|
||||
|
||||
This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
|
||||
a string rather than an actual list. Claude desktop is prone to this - in fact
|
||||
it seems incapable of NOT doing this. For sub-models, it tends to pass
|
||||
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
|
||||
"""
|
||||
new_data = data.copy() # Shallow copy
|
||||
|
||||
# Build a mapping from input keys (including aliases) to field info
|
||||
key_to_field_info: dict[str, FieldInfo] = {}
|
||||
for field_name, field_info in self.arg_model.model_fields.items():
|
||||
# Map both the field name and its alias (if any) to the field info
|
||||
key_to_field_info[field_name] = field_info
|
||||
if field_info.alias:
|
||||
key_to_field_info[field_info.alias] = field_info
|
||||
|
||||
for data_key, data_value in data.items():
|
||||
if data_key not in key_to_field_info: # pragma: no cover
|
||||
continue
|
||||
|
||||
field_info = key_to_field_info[data_key]
|
||||
if isinstance(data_value, str) and field_info.annotation is not str:
|
||||
try:
|
||||
pre_parsed = json.loads(data_value)
|
||||
except json.JSONDecodeError:
|
||||
continue # Not JSON - skip
|
||||
if isinstance(pre_parsed, str | int | float):
|
||||
# This is likely that the raw value is e.g. `"hello"` which we
|
||||
# Should really be parsed as '"hello"' in Python - but if we parse
|
||||
# it as JSON it'll turn into just 'hello'. So we skip it.
|
||||
continue
|
||||
new_data[data_key] = pre_parsed
|
||||
assert new_data.keys() == data.keys()
|
||||
return new_data
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
def func_metadata(
|
||||
func: Callable[..., Any],
|
||||
skip_names: Sequence[str] = (),
|
||||
structured_output: bool | None = None,
|
||||
) -> FuncMetadata:
|
||||
"""Given a function, return metadata including a pydantic model representing its
|
||||
signature.
|
||||
|
||||
The use case for this is
|
||||
```
|
||||
meta = func_metadata(func)
|
||||
validated_args = meta.arg_model.model_validate(some_raw_data_dict)
|
||||
return func(**validated_args.model_dump_one_level())
|
||||
```
|
||||
|
||||
**critically** it also provides pre-parse helper to attempt to parse things from
|
||||
JSON.
|
||||
|
||||
Args:
|
||||
func: The function to convert to a pydantic model
|
||||
skip_names: A list of parameter names to skip. These will not be included in
|
||||
the model.
|
||||
structured_output: Controls whether the tool's output is structured or unstructured
|
||||
- If None, auto-detects based on the function's return type annotation
|
||||
- If True, creates a structured tool (return type annotation permitting)
|
||||
- If False, unconditionally creates an unstructured tool
|
||||
|
||||
If structured, creates a Pydantic model for the function's result based on its annotation.
|
||||
Supports various return types:
|
||||
- BaseModel subclasses (used directly)
|
||||
- Primitive types (str, int, float, bool, bytes, None) - wrapped in a
|
||||
model with a 'result' field
|
||||
- TypedDict - converted to a Pydantic model with same fields
|
||||
- Dataclasses and other annotated classes - converted to Pydantic models
|
||||
- Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field
|
||||
|
||||
Returns:
|
||||
A FuncMetadata object containing:
|
||||
- arg_model: A pydantic model representing the function's arguments
|
||||
- output_model: A pydantic model for the return type if output is structured
|
||||
- output_conversion: Records how function output should be converted before returning.
|
||||
"""
|
||||
try:
|
||||
sig = inspect.signature(func, eval_str=True)
|
||||
except NameError as e: # pragma: no cover
|
||||
# This raise could perhaps be skipped, and we (FastMCP) just call
|
||||
# model_rebuild right before using it 🤷
|
||||
raise InvalidSignature(f"Unable to evaluate type annotations for callable {func.__name__!r}") from e
|
||||
params = sig.parameters
|
||||
dynamic_pydantic_model_params: dict[str, Any] = {}
|
||||
for param in params.values():
|
||||
if param.name.startswith("_"): # pragma: no cover
|
||||
raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
|
||||
if param.name in skip_names:
|
||||
continue
|
||||
|
||||
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
|
||||
field_name = param.name
|
||||
field_kwargs: dict[str, Any] = {}
|
||||
field_metadata: list[Any] = []
|
||||
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
field_metadata.append(WithJsonSchema({"title": param.name, "type": "string"}))
|
||||
# Check if the parameter name conflicts with BaseModel attributes
|
||||
# This is necessary because Pydantic warns about shadowing parent attributes
|
||||
if hasattr(BaseModel, field_name) and callable(getattr(BaseModel, field_name)):
|
||||
# Use an alias to avoid the shadowing warning
|
||||
field_kwargs["alias"] = field_name
|
||||
# Use a prefixed field name
|
||||
field_name = f"field_{field_name}"
|
||||
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
dynamic_pydantic_model_params[field_name] = (
|
||||
Annotated[(annotation, *field_metadata, Field(**field_kwargs))],
|
||||
param.default,
|
||||
)
|
||||
else:
|
||||
dynamic_pydantic_model_params[field_name] = Annotated[(annotation, *field_metadata, Field(**field_kwargs))]
|
||||
|
||||
arguments_model = create_model(
|
||||
f"{func.__name__}Arguments",
|
||||
__base__=ArgModelBase,
|
||||
**dynamic_pydantic_model_params,
|
||||
)
|
||||
|
||||
if structured_output is False:
|
||||
return FuncMetadata(arg_model=arguments_model)
|
||||
|
||||
# set up structured output support based on return type annotation
|
||||
|
||||
if sig.return_annotation is inspect.Parameter.empty and structured_output is True:
|
||||
raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output")
|
||||
|
||||
try:
|
||||
inspected_return_ann = inspect_annotation(sig.return_annotation, annotation_source=AnnotationSource.FUNCTION)
|
||||
except ForbiddenQualifier as e:
|
||||
raise InvalidSignature(f"Function {func.__name__}: return annotation contains an invalid type qualifier") from e
|
||||
|
||||
return_type_expr = inspected_return_ann.type
|
||||
|
||||
# `AnnotationSource.FUNCTION` allows no type qualifier to be used, so `return_type_expr` is guaranteed to *not* be
|
||||
# unknown (i.e. a bare `Final`).
|
||||
assert return_type_expr is not UNKNOWN
|
||||
|
||||
if is_union_origin(get_origin(return_type_expr)):
|
||||
args = get_args(return_type_expr)
|
||||
# Check if CallToolResult appears in the union (excluding None for Optional check)
|
||||
if any(isinstance(arg, type) and issubclass(arg, CallToolResult) for arg in args if arg is not type(None)):
|
||||
raise InvalidSignature(
|
||||
f"Function {func.__name__}: CallToolResult cannot be used in Union or Optional types. "
|
||||
"To return empty results, use: CallToolResult(content=[])"
|
||||
)
|
||||
|
||||
original_annotation: Any
|
||||
# if the typehint is CallToolResult, the user either intends to return without validation
|
||||
# or they provided validation as Annotated metadata
|
||||
if isinstance(return_type_expr, type) and issubclass(return_type_expr, CallToolResult):
|
||||
if inspected_return_ann.metadata:
|
||||
return_type_expr = inspected_return_ann.metadata[0]
|
||||
if len(inspected_return_ann.metadata) >= 2:
|
||||
# Reconstruct the original annotation, by preserving the remaining metadata,
|
||||
# i.e. from `Annotated[CallToolResult, ReturnType, Gt(1)]` to
|
||||
# `Annotated[ReturnType, Gt(1)]`:
|
||||
original_annotation = Annotated[
|
||||
(return_type_expr, *inspected_return_ann.metadata[1:])
|
||||
] # pragma: no cover
|
||||
else:
|
||||
# We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation
|
||||
# as beging `ReturnType`:
|
||||
original_annotation = return_type_expr
|
||||
else:
|
||||
return FuncMetadata(arg_model=arguments_model)
|
||||
else:
|
||||
original_annotation = sig.return_annotation
|
||||
|
||||
output_model, output_schema, wrap_output = _try_create_model_and_schema(
|
||||
original_annotation, return_type_expr, func.__name__
|
||||
)
|
||||
|
||||
if output_model is None and structured_output is True:
|
||||
# Model creation failed or produced warnings - no structured output
|
||||
raise InvalidSignature(
|
||||
f"Function {func.__name__}: return type {return_type_expr} is not serializable for structured output"
|
||||
)
|
||||
|
||||
return FuncMetadata(
|
||||
arg_model=arguments_model,
|
||||
output_schema=output_schema,
|
||||
output_model=output_model,
|
||||
wrap_output=wrap_output,
|
||||
)
|
||||
|
||||
|
||||
def _try_create_model_and_schema(
|
||||
original_annotation: Any,
|
||||
type_expr: Any,
|
||||
func_name: str,
|
||||
) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]:
|
||||
"""Try to create a model and schema for the given annotation without warnings.
|
||||
|
||||
Args:
|
||||
original_annotation: The original return annotation (may be wrapped in `Annotated`).
|
||||
type_expr: The underlying type expression derived from the return annotation
|
||||
(`Annotated` and type qualifiers were stripped).
|
||||
func_name: The name of the function.
|
||||
|
||||
Returns:
|
||||
tuple of (model or None, schema or None, wrap_output)
|
||||
Model and schema are None if warnings occur or creation fails.
|
||||
wrap_output is True if the result needs to be wrapped in {"result": ...}
|
||||
"""
|
||||
model = None
|
||||
wrap_output = False
|
||||
|
||||
# First handle special case: None
|
||||
if type_expr is None:
|
||||
model = _create_wrapped_model(func_name, original_annotation)
|
||||
wrap_output = True
|
||||
|
||||
# Handle GenericAlias types (list[str], dict[str, int], Union[str, int], etc.)
|
||||
elif isinstance(type_expr, GenericAlias):
|
||||
origin = get_origin(type_expr)
|
||||
|
||||
# Special case: dict with string keys can use RootModel
|
||||
if origin is dict:
|
||||
args = get_args(type_expr)
|
||||
if len(args) == 2 and args[0] is str:
|
||||
# TODO: should we use the original annotation? We are loosing any potential `Annotated`
|
||||
# metadata for Pydantic here:
|
||||
model = _create_dict_model(func_name, type_expr)
|
||||
else:
|
||||
# dict with non-str keys needs wrapping
|
||||
model = _create_wrapped_model(func_name, original_annotation)
|
||||
wrap_output = True
|
||||
else:
|
||||
# All other generic types need wrapping (list, tuple, Union, Optional, etc.)
|
||||
model = _create_wrapped_model(func_name, original_annotation)
|
||||
wrap_output = True
|
||||
|
||||
# Handle regular type objects
|
||||
elif isinstance(type_expr, type):
|
||||
type_annotation = cast(type[Any], type_expr)
|
||||
|
||||
# Case 1: BaseModel subclasses (can be used directly)
|
||||
if issubclass(type_annotation, BaseModel):
|
||||
model = type_annotation
|
||||
|
||||
# Case 2: TypedDicts:
|
||||
elif is_typeddict(type_annotation):
|
||||
model = _create_model_from_typeddict(type_annotation)
|
||||
|
||||
# Case 3: Primitive types that need wrapping
|
||||
elif type_annotation in (str, int, float, bool, bytes, type(None)):
|
||||
model = _create_wrapped_model(func_name, original_annotation)
|
||||
wrap_output = True
|
||||
|
||||
# Case 4: Other class types (dataclasses, regular classes with annotations)
|
||||
else:
|
||||
type_hints = get_type_hints(type_annotation)
|
||||
if type_hints:
|
||||
# Classes with type hints can be converted to Pydantic models
|
||||
model = _create_model_from_class(type_annotation, type_hints)
|
||||
# Classes without type hints are not serializable - model remains None
|
||||
|
||||
# Handle any other types not covered above
|
||||
else:
|
||||
# This includes typing constructs that aren't GenericAlias in Python 3.10
|
||||
# (e.g., Union, Optional in some Python versions)
|
||||
model = _create_wrapped_model(func_name, original_annotation)
|
||||
wrap_output = True
|
||||
|
||||
if model:
|
||||
# If we successfully created a model, try to get its schema
|
||||
# Use StrictJsonSchema to raise exceptions instead of warnings
|
||||
try:
|
||||
schema = model.model_json_schema(schema_generator=StrictJsonSchema)
|
||||
except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e:
|
||||
# These are expected errors when a type can't be converted to a Pydantic schema
|
||||
# TypeError: When Pydantic can't handle the type
|
||||
# ValueError: When there are issues with the type definition (including our custom warnings)
|
||||
# SchemaError: When Pydantic can't build a schema
|
||||
# ValidationError: When validation fails
|
||||
logger.info(f"Cannot create schema for type {type_expr} in {func_name}: {type(e).__name__}: {e}")
|
||||
return None, None, False
|
||||
|
||||
return model, schema, wrap_output
|
||||
|
||||
return None, None, False
|
||||
|
||||
|
||||
_no_default = object()
|
||||
|
||||
|
||||
def _create_model_from_class(cls: type[Any], type_hints: dict[str, Any]) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from an ordinary class.
|
||||
|
||||
The created model will:
|
||||
- Have the same name as the class
|
||||
- Have fields with the same names and types as the class's fields
|
||||
- Include all fields whose type does not include None in the set of required fields
|
||||
|
||||
Precondition: cls must have type hints (i.e., `type_hints` is non-empty)
|
||||
"""
|
||||
model_fields: dict[str, Any] = {}
|
||||
for field_name, field_type in type_hints.items():
|
||||
if field_name.startswith("_"): # pragma: no cover
|
||||
continue
|
||||
|
||||
default = getattr(cls, field_name, _no_default)
|
||||
if default is _no_default:
|
||||
model_fields[field_name] = field_type
|
||||
else:
|
||||
model_fields[field_name] = (field_type, default)
|
||||
|
||||
return create_model(cls.__name__, __config__=ConfigDict(from_attributes=True), **model_fields)
|
||||
|
||||
|
||||
def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]:
|
||||
"""Create a Pydantic model from a TypedDict.
|
||||
|
||||
The created model will have the same name and fields as the TypedDict.
|
||||
"""
|
||||
type_hints = get_type_hints(td_type)
|
||||
required_keys = getattr(td_type, "__required_keys__", set(type_hints.keys()))
|
||||
|
||||
model_fields: dict[str, Any] = {}
|
||||
for field_name, field_type in type_hints.items():
|
||||
if field_name not in required_keys:
|
||||
# For optional TypedDict fields, set default=None
|
||||
# This makes them not required in the Pydantic model
|
||||
# The model should use exclude_unset=True when dumping to get TypedDict semantics
|
||||
model_fields[field_name] = (field_type, None)
|
||||
else:
|
||||
model_fields[field_name] = field_type
|
||||
|
||||
return create_model(td_type.__name__, **model_fields)
|
||||
|
||||
|
||||
def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]:
|
||||
"""Create a model that wraps a type in a 'result' field.
|
||||
|
||||
This is used for primitive types, generic types like list/dict, etc.
|
||||
"""
|
||||
model_name = f"{func_name}Output"
|
||||
|
||||
return create_model(model_name, result=annotation)
|
||||
|
||||
|
||||
def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]:
|
||||
"""Create a RootModel for dict[str, T] types."""
|
||||
|
||||
class DictModel(RootModel[dict_annotation]):
|
||||
pass
|
||||
|
||||
# Give it a meaningful name
|
||||
DictModel.__name__ = f"{func_name}DictOutput"
|
||||
DictModel.__qualname__ = f"{func_name}DictOutput"
|
||||
|
||||
return DictModel
|
||||
|
||||
|
||||
def _convert_to_content(
|
||||
result: Any,
|
||||
) -> Sequence[ContentBlock]:
|
||||
"""
|
||||
Convert a result to a sequence of content objects.
|
||||
|
||||
Note: This conversion logic comes from previous versions of FastMCP and is being
|
||||
retained for purposes of backwards compatibility. It produces different unstructured
|
||||
output than the lowlevel server tool call handler, which just serializes structured
|
||||
content verbatim.
|
||||
"""
|
||||
if result is None: # pragma: no cover
|
||||
return []
|
||||
|
||||
if isinstance(result, ContentBlock):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
return [result.to_image_content()]
|
||||
|
||||
if isinstance(result, Audio):
|
||||
return [result.to_audio_content()]
|
||||
|
||||
if isinstance(result, list | tuple):
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
_convert_to_content(item)
|
||||
for item in result # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
if not isinstance(result, str):
|
||||
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()
|
||||
|
||||
return [TextContent(type="text", text=result)]
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Logging utilities for FastMCP."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger nested under MCPnamespace.
|
||||
|
||||
Args:
|
||||
name: the name of the logger, which will be prefixed with 'FastMCP.'
|
||||
|
||||
Returns:
|
||||
a configured logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def configure_logging(
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
|
||||
) -> None:
|
||||
"""Configure logging for MCP.
|
||||
|
||||
Args:
|
||||
level: the log level to use
|
||||
"""
|
||||
handlers: list[logging.Handler] = []
|
||||
try: # pragma: no cover
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True))
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
|
||||
if not handlers: # pragma: no cover
|
||||
handlers.append(logging.StreamHandler())
|
||||
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(message)s",
|
||||
handlers=handlers,
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Common types used across FastMCP."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from mcp.types import AudioContent, ImageContent
|
||||
|
||||
|
||||
class Image:
|
||||
"""Helper class for returning images from tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
data: bytes | None = None,
|
||||
format: str | None = None,
|
||||
):
|
||||
if path is None and data is None: # pragma: no cover
|
||||
raise ValueError("Either path or data must be provided")
|
||||
if path is not None and data is not None: # pragma: no cover
|
||||
raise ValueError("Only one of path or data can be provided")
|
||||
|
||||
self.path = Path(path) if path else None
|
||||
self.data = data
|
||||
self._format = format
|
||||
self._mime_type = self._get_mime_type()
|
||||
|
||||
def _get_mime_type(self) -> str:
|
||||
"""Get MIME type from format or guess from file extension."""
|
||||
if self._format: # pragma: no cover
|
||||
return f"image/{self._format.lower()}"
|
||||
|
||||
if self.path:
|
||||
suffix = self.path.suffix.lower()
|
||||
return {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
}.get(suffix, "application/octet-stream")
|
||||
return "image/png" # pragma: no cover # default for raw binary data
|
||||
|
||||
def to_image_content(self) -> ImageContent:
|
||||
"""Convert to MCP ImageContent."""
|
||||
if self.path:
|
||||
with open(self.path, "rb") as f:
|
||||
data = base64.b64encode(f.read()).decode()
|
||||
elif self.data is not None: # pragma: no cover
|
||||
data = base64.b64encode(self.data).decode()
|
||||
else: # pragma: no cover
|
||||
raise ValueError("No image data available")
|
||||
|
||||
return ImageContent(type="image", data=data, mimeType=self._mime_type)
|
||||
|
||||
|
||||
class Audio:
|
||||
"""Helper class for returning audio from tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
data: bytes | None = None,
|
||||
format: str | None = None,
|
||||
):
|
||||
if not bool(path) ^ bool(data): # pragma: no cover
|
||||
raise ValueError("Either path or data can be provided")
|
||||
|
||||
self.path = Path(path) if path else None
|
||||
self.data = data
|
||||
self._format = format
|
||||
self._mime_type = self._get_mime_type()
|
||||
|
||||
def _get_mime_type(self) -> str:
|
||||
"""Get MIME type from format or guess from file extension."""
|
||||
if self._format: # pragma: no cover
|
||||
return f"audio/{self._format.lower()}"
|
||||
|
||||
if self.path:
|
||||
suffix = self.path.suffix.lower()
|
||||
return {
|
||||
".wav": "audio/wav",
|
||||
".mp3": "audio/mpeg",
|
||||
".ogg": "audio/ogg",
|
||||
".flac": "audio/flac",
|
||||
".aac": "audio/aac",
|
||||
".m4a": "audio/mp4",
|
||||
}.get(suffix, "application/octet-stream")
|
||||
return "audio/wav" # pragma: no cover # default for raw binary data
|
||||
|
||||
def to_audio_content(self) -> AudioContent:
|
||||
"""Convert to MCP AudioContent."""
|
||||
if self.path:
|
||||
with open(self.path, "rb") as f:
|
||||
data = base64.b64encode(f.read()).decode()
|
||||
elif self.data is not None: # pragma: no cover
|
||||
data = base64.b64encode(self.data).decode()
|
||||
else: # pragma: no cover
|
||||
raise ValueError("No audio data available")
|
||||
|
||||
return AudioContent(type="audio", data=data, mimeType=self._mime_type)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .server import NotificationOptions, Server
|
||||
|
||||
__all__ = ["Server", "NotificationOptions"]
|
||||
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user