Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,135 @@
from .client.session import ClientSession
from .client.session_group import ClientSessionGroup
from .client.stdio import StdioServerParameters, stdio_client
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import McpError, UrlElicitationRequiredError
from .types import (
CallToolRequest,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteRequest,
CreateMessageRequest,
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
GetPromptRequest,
GetPromptResult,
Implementation,
IncludeContext,
InitializedNotification,
InitializeRequest,
InitializeResult,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
ListPromptsRequest,
ListPromptsResult,
ListResourcesRequest,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
LoggingMessageNotification,
Notification,
PingRequest,
ProgressNotification,
PromptsCapability,
ReadResourceRequest,
ReadResourceResult,
Resource,
ResourcesCapability,
ResourceUpdatedNotification,
RootsCapability,
SamplingCapability,
SamplingContent,
SamplingContextCapability,
SamplingMessage,
SamplingMessageContentBlock,
SamplingToolsCapability,
ServerCapabilities,
ServerNotification,
ServerRequest,
ServerResult,
SetLevelRequest,
StopReason,
SubscribeRequest,
Tool,
ToolChoice,
ToolResultContent,
ToolsCapability,
ToolUseContent,
UnsubscribeRequest,
)
from .types import (
Role as SamplingRole,
)
__all__ = [
"CallToolRequest",
"ClientCapabilities",
"ClientNotification",
"ClientRequest",
"ClientResult",
"ClientSession",
"ClientSessionGroup",
"CompleteRequest",
"CreateMessageRequest",
"CreateMessageResult",
"CreateMessageResultWithTools",
"ErrorData",
"GetPromptRequest",
"GetPromptResult",
"Implementation",
"IncludeContext",
"InitializeRequest",
"InitializeResult",
"InitializedNotification",
"JSONRPCError",
"JSONRPCRequest",
"JSONRPCResponse",
"ListPromptsRequest",
"ListPromptsResult",
"ListResourcesRequest",
"ListResourcesResult",
"ListToolsResult",
"LoggingLevel",
"LoggingMessageNotification",
"McpError",
"Notification",
"PingRequest",
"ProgressNotification",
"PromptsCapability",
"ReadResourceRequest",
"ReadResourceResult",
"Resource",
"ResourcesCapability",
"ResourceUpdatedNotification",
"RootsCapability",
"SamplingCapability",
"SamplingContent",
"SamplingContextCapability",
"SamplingMessage",
"SamplingMessageContentBlock",
"SamplingRole",
"SamplingToolsCapability",
"ServerCapabilities",
"ServerNotification",
"ServerRequest",
"ServerResult",
"ServerSession",
"SetLevelRequest",
"StdioServerParameters",
"StopReason",
"SubscribeRequest",
"Tool",
"ToolChoice",
"ToolResultContent",
"ToolsCapability",
"ToolUseContent",
"UnsubscribeRequest",
"UrlElicitationRequiredError",
"stdio_client",
"stdio_server",
]

View File

@@ -0,0 +1,6 @@
"""FastMCP CLI package."""
from .cli import app
if __name__ == "__main__": # pragma: no cover
app()

View File

@@ -0,0 +1,148 @@
"""Claude app integration utilities."""
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Any
from mcp.server.fastmcp.utilities.logging import get_logger
logger = get_logger(__name__)
MCP_PACKAGE = "mcp[cli]"
def get_claude_config_path() -> Path | None: # pragma: no cover
"""Get the Claude config directory based on platform."""
if sys.platform == "win32":
path = Path(Path.home(), "AppData", "Roaming", "Claude")
elif sys.platform == "darwin":
path = Path(Path.home(), "Library", "Application Support", "Claude")
elif sys.platform.startswith("linux"):
path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude")
else:
return None
if path.exists():
return path
return None
def get_uv_path() -> str:
    """Resolve the full path to the ``uv`` executable.

    Returns the bare string "uv" (after logging an error) when the
    executable cannot be located on PATH.
    """
    located = shutil.which("uv")
    if located:
        return located
    # pragma: no cover - only reachable when uv is not installed
    logger.error(
        "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH"
    )
    return "uv"
def update_claude_config(
    file_spec: str,
    server_name: str,
    *,
    with_editable: Path | None = None,
    with_packages: list[str] | None = None,
    env_vars: dict[str, str] | None = None,
) -> bool:
    """Add or update a FastMCP server in Claude's configuration.

    Args:
        file_spec: Path to the server file, optionally with :object suffix
        server_name: Name for the server in Claude's config
        with_editable: Optional directory to install in editable mode
        with_packages: Optional list of additional packages to install
        env_vars: Optional dictionary of environment variables. These are merged with
            any existing variables, with new values taking precedence.

    Returns:
        True when the config file was written successfully, False on any
        read/write failure (failures are logged, not raised).

    Raises:
        RuntimeError: If Claude Desktop's config directory is not found, indicating
            Claude Desktop may not be installed or properly set up.
    """
    config_dir = get_claude_config_path()
    uv_path = get_uv_path()
    if not config_dir:  # pragma: no cover
        raise RuntimeError(
            "Claude Desktop config directory not found. Please ensure Claude Desktop"
            " is installed and has been run at least once to initialize its config."
        )
    config_file = config_dir / "claude_desktop_config.json"
    if not config_file.exists():  # pragma: no cover
        try:
            # Explicit encoding: the platform default (e.g. cp1252 on Windows)
            # is not guaranteed to round-trip JSON content.
            config_file.write_text("{}", encoding="utf-8")
        except Exception:
            logger.exception(
                "Failed to create Claude config file",
                extra={
                    "config_file": str(config_file),
                },
            )
            return False
    try:
        config = json.loads(config_file.read_text(encoding="utf-8"))
        if "mcpServers" not in config:
            config["mcpServers"] = {}
        # Always preserve existing env vars and merge with new ones
        if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]:  # pragma: no cover
            existing_env = config["mcpServers"][server_name]["env"]
            if env_vars:
                # New vars take precedence over existing ones
                env_vars = {**existing_env, **env_vars}
            else:
                env_vars = existing_env
        # Build uv run command
        args = ["run", "--frozen"]
        # Collect all packages in a set to deduplicate
        packages = {MCP_PACKAGE}
        if with_packages:  # pragma: no cover
            packages.update(pkg for pkg in with_packages if pkg)
        # Add all packages with --with
        for pkg in sorted(packages):
            args.extend(["--with", pkg])
        if with_editable:  # pragma: no cover
            args.extend(["--with-editable", str(with_editable)])
        # Convert file path to absolute before adding to command.
        # Split off any :object suffix first, but do not mistake a Windows
        # drive-letter colon ("C:\...") for the object separator.
        has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":"
        if ":" in (file_spec[2:] if has_windows_drive else file_spec):
            file_path, server_object = file_spec.rsplit(":", 1)
            file_spec = f"{Path(file_path).resolve()}:{server_object}"
        else:  # pragma: no cover
            file_spec = str(Path(file_spec).resolve())
        # Add fastmcp run command
        args.extend(["mcp", "run", file_spec])
        server_config: dict[str, Any] = {"command": uv_path, "args": args}
        # Add environment variables if specified
        if env_vars:  # pragma: no cover
            server_config["env"] = env_vars
        config["mcpServers"][server_name] = server_config
        config_file.write_text(json.dumps(config, indent=2), encoding="utf-8")
        logger.info(
            f"Added server '{server_name}' to Claude config",
            extra={"config_file": str(config_file)},
        )
        return True
    except Exception:  # pragma: no cover
        logger.exception(
            "Failed to update Claude config",
            extra={
                "config_file": str(config_file),
            },
        )
        return False

View File

@@ -0,0 +1,488 @@
"""MCP CLI tools."""
import importlib.metadata
import importlib.util
import os
import subprocess
import sys
from pathlib import Path
from typing import Annotated, Any
from mcp.server import FastMCP
from mcp.server import Server as LowLevelServer
try:
import typer
except ImportError: # pragma: no cover
print("Error: typer is required. Install with 'pip install mcp[cli]'")
sys.exit(1)
try:
from mcp.cli import claude
from mcp.server.fastmcp.utilities.logging import get_logger
except ImportError: # pragma: no cover
print("Error: mcp.server.fastmcp is not installed or not in PYTHONPATH")
sys.exit(1)
try:
import dotenv
except ImportError: # pragma: no cover
dotenv = None
logger = get_logger("cli")
app = typer.Typer(
name="mcp",
help="MCP development tools",
add_completion=False,
no_args_is_help=True, # Show help if no args provided
)
def _get_npx_command():
"""Get the correct npx command for the current platform."""
if sys.platform == "win32":
# Try both npx.cmd and npx.exe on Windows
for cmd in ["npx.cmd", "npx.exe", "npx"]:
try:
subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True)
return cmd
except subprocess.CalledProcessError:
continue
return None
return "npx" # On Unix-like systems, just use npx
def _parse_env_var(env_var: str) -> tuple[str, str]: # pragma: no cover
"""Parse environment variable string in format KEY=VALUE."""
if "=" not in env_var:
logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE")
sys.exit(1)
key, value = env_var.split("=", 1)
return key.strip(), value.strip()
def _build_uv_command(
file_spec: str,
with_editable: Path | None = None,
with_packages: list[str] | None = None,
) -> list[str]:
"""Build the uv run command that runs an MCP server through mcp run."""
cmd = ["uv"]
cmd.extend(["run", "--with", "mcp"])
if with_editable:
cmd.extend(["--with-editable", str(with_editable)])
if with_packages:
for pkg in with_packages:
if pkg: # pragma: no cover
cmd.extend(["--with", pkg])
# Add mcp run command
cmd.extend(["mcp", "run", file_spec])
return cmd
def _parse_file_path(file_spec: str) -> tuple[Path, str | None]:
"""Parse a file path that may include a server object specification.
Args:
file_spec: Path to file, optionally with :object suffix
Returns:
Tuple of (file_path, server_object)
"""
# First check if we have a Windows path (e.g., C:\...)
has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":"
# Split on the last colon, but only if it's not part of the Windows drive letter
# and there's actually another colon in the string after the drive letter
if ":" in (file_spec[2:] if has_windows_drive else file_spec):
file_str, server_object = file_spec.rsplit(":", 1)
else:
file_str, server_object = file_spec, None
# Resolve the file path
file_path = Path(file_str).expanduser().resolve()
if not file_path.exists():
logger.error(f"File not found: {file_path}")
sys.exit(1)
if not file_path.is_file():
logger.error(f"Not a file: {file_path}")
sys.exit(1)
return file_path, server_object
def _import_server(file: Path, server_object: str | None = None):  # pragma: no cover
    """Import an MCP server from a file.

    Args:
        file: Path to the file
        server_object: Optional object name in format "module:object" or just "object"

    Returns:
        The server object (a FastMCP instance); exits with status 1 on any
        failure to load the module or locate a valid server object.
    """
    # Add parent directory to Python path so imports can be resolved
    file_dir = str(file.parent)
    if file_dir not in sys.path:
        sys.path.insert(0, file_dir)
    # Import the module
    spec = importlib.util.spec_from_file_location("server_module", file)
    if not spec or not spec.loader:
        logger.error("Could not load module", extra={"file": str(file)})
        sys.exit(1)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    def _check_server_object(server_object: Any, object_name: str):
        """Helper function to check that the server object is supported

        Args:
            server_object: The server object to check.

        Returns:
            True if it's supported.
        """
        if not isinstance(server_object, FastMCP):
            logger.error(f"The server object {object_name} is of type {type(server_object)} (expecting {FastMCP}).")
            if isinstance(server_object, LowLevelServer):
                logger.warning(
                    "Note that only FastMCP server is supported. Low level Server class is not yet supported."
                )
            return False
        return True

    # If no object specified, try common server names
    if not server_object:
        # Look for the most common server object names
        for name in ["mcp", "server", "app"]:
            if hasattr(module, name):
                if not _check_server_object(getattr(module, name), f"{file}:{name}"):
                    logger.error(f"Ignoring object '{file}:{name}' as it's not a valid server object")
                    continue
                return getattr(module, name)
        # Fix: item 2 previously ran into item 3 on one line because the
        # concatenated string literal was missing its "\n" separator.
        logger.error(
            f"No server object found in {file}. Please either:\n"
            "1. Use a standard variable name (mcp, server, or app)\n"
            "2. Specify the object name with file:object syntax\n"
            "3. If the server creates the FastMCP object within main() "
            " or another function, refactor the FastMCP object to be a "
            " global variable named mcp, server, or app.",
            extra={"file": str(file)},
        )
        sys.exit(1)
    # Handle module:object syntax
    if ":" in server_object:
        module_name, object_name = server_object.split(":", 1)
        try:
            server_module = importlib.import_module(module_name)
            server = getattr(server_module, object_name, None)
        except ImportError:
            logger.error(
                f"Could not import module '{module_name}'",
                extra={"file": str(file)},
            )
            sys.exit(1)
    else:
        # Just object name
        server = getattr(module, server_object, None)
    if server is None:
        logger.error(
            f"Server object '{server_object}' not found",
            extra={"file": str(file)},
        )
        sys.exit(1)
    if not _check_server_object(server, server_object):
        sys.exit(1)
    return server
@app.command()
def version() -> None:  # pragma: no cover
    """Show the MCP version."""
    try:
        installed = importlib.metadata.version("mcp")
    except importlib.metadata.PackageNotFoundError:
        print("MCP version unknown (package not installed)")
        sys.exit(1)
    # Printing outside the try keeps the except clause focused on the lookup
    # (print cannot raise PackageNotFoundError) and avoids shadowing the
    # function's own name with a local variable.
    print(f"MCP version {installed}")
@app.command()
def dev(
    file_spec: str = typer.Argument(
        ...,
        help="Python file to run, optionally with :object suffix",
    ),
    with_editable: Annotated[
        Path | None,
        typer.Option(
            "--with-editable",
            "-e",
            help="Directory containing pyproject.toml to install in editable mode",
            exists=True,
            file_okay=False,
            resolve_path=True,
        ),
    ] = None,
    with_packages: Annotated[
        list[str],
        typer.Option(
            "--with",
            help="Additional packages to install",
        ),
    ] = [],
) -> None:  # pragma: no cover
    """Run an MCP server with the MCP Inspector.

    Imports the server to collect its declared dependencies, builds a
    `uv run ... mcp run <file_spec>` command, and launches it under the
    Node-based `@modelcontextprotocol/inspector` via npx. Exits with the
    subprocess's return code, or 1 when npx cannot be found.
    """
    file, server_object = _parse_file_path(file_spec)
    logger.debug(
        "Starting dev server",
        extra={
            "file": str(file),
            "server_object": server_object,
            "with_editable": str(with_editable) if with_editable else None,
            "with_packages": with_packages,
        },
    )
    try:
        # Import server to get dependencies declared on the FastMCP object.
        server = _import_server(file, server_object)
        if hasattr(server, "dependencies"):
            # Merge CLI-provided packages with the server's own, deduplicated.
            with_packages = list(set(with_packages + server.dependencies))
        uv_cmd = _build_uv_command(file_spec, with_editable, with_packages)
        # Get the correct npx command (npx.cmd/npx.exe on Windows).
        npx_cmd = _get_npx_command()
        if not npx_cmd:
            logger.error(
                "npx not found. Please ensure Node.js and npm are properly installed and added to your system PATH."
            )
            sys.exit(1)
        # Run the MCP Inspector command with shell=True on Windows
        shell = sys.platform == "win32"
        process = subprocess.run(
            [npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd,
            check=True,
            shell=shell,
            env=dict(os.environ.items()),  # pass an explicit copy of the current environment
        )
        sys.exit(process.returncode)
    except subprocess.CalledProcessError as e:
        # Inspector (or the server under it) exited non-zero; propagate its code.
        logger.error(
            "Dev server failed",
            extra={
                "file": str(file),
                "error": str(e),
                "returncode": e.returncode,
            },
        )
        sys.exit(e.returncode)
    except FileNotFoundError:
        # Raised when shell=False and the npx binary itself is missing.
        logger.error(
            "npx not found. Please ensure Node.js and npm are properly installed "
            "and added to your system PATH. You may need to restart your terminal "
            "after installation.",
            extra={"file": str(file)},
        )
        sys.exit(1)
@app.command()
def run(
    file_spec: str = typer.Argument(
        ...,
        help="Python file to run, optionally with :object suffix",
    ),
    transport: Annotated[
        str | None,
        typer.Option(
            "--transport",
            "-t",
            help="Transport protocol to use (stdio or sse)",
        ),
    ] = None,
) -> None:  # pragma: no cover
    """Run an MCP server.

    The server can be specified in two ways:\n
    1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n
    2. Import approach: server.py:app - imports and runs the specified server object.\n\n

    Note: This command runs the server directly. You are responsible for ensuring
    all dependencies are available.\n
    For dependency management, use `mcp install` or `mcp dev` instead.
    """  # noqa: E501
    file, server_object = _parse_file_path(file_spec)
    logger.debug(
        "Running server",
        extra={
            "file": str(file),
            "server_object": server_object,
            "transport": transport,
        },
    )
    try:
        # Import and get server object
        server = _import_server(file, server_object)
        # Run the server; only forward transport when explicitly given so the
        # server's own default transport applies otherwise.
        kwargs = {}
        if transport:
            kwargs["transport"] = transport
        server.run(**kwargs)
    except Exception:
        # Catch-all at the CLI boundary: log with traceback and exit non-zero.
        logger.exception(
            "Failed to run server",
            extra={
                "file": str(file),
            },
        )
        sys.exit(1)
@app.command()
def install(
    file_spec: str = typer.Argument(
        ...,
        help="Python file to run, optionally with :object suffix",
    ),
    server_name: Annotated[
        str | None,
        typer.Option(
            "--name",
            "-n",
            help="Custom name for the server (defaults to server's name attribute or file name)",
        ),
    ] = None,
    with_editable: Annotated[
        Path | None,
        typer.Option(
            "--with-editable",
            "-e",
            help="Directory containing pyproject.toml to install in editable mode",
            exists=True,
            file_okay=False,
            resolve_path=True,
        ),
    ] = None,
    with_packages: Annotated[
        list[str],
        typer.Option(
            "--with",
            help="Additional packages to install",
        ),
    ] = [],
    env_vars: Annotated[
        list[str],
        typer.Option(
            "--env-var",
            "-v",
            help="Environment variables in KEY=VALUE format",
        ),
    ] = [],
    env_file: Annotated[
        Path | None,
        typer.Option(
            "--env-file",
            "-f",
            help="Load environment variables from a .env file",
            exists=True,
            file_okay=True,
            dir_okay=False,
            resolve_path=True,
        ),
    ] = None,
) -> None:  # pragma: no cover
    """Install an MCP server in the Claude desktop app.

    Environment variables are preserved once added and only updated if new values
    are explicitly provided.
    """
    # NOTE: mutable list defaults ([]) are the conventional typer idiom for
    # repeatable options; typer re-evaluates them per invocation.
    file, server_object = _parse_file_path(file_spec)
    logger.debug(
        "Installing server",
        extra={
            "file": str(file),
            "server_name": server_name,
            "server_object": server_object,
            "with_editable": str(with_editable) if with_editable else None,
            "with_packages": with_packages,
        },
    )
    if not claude.get_claude_config_path():
        logger.error("Claude app not found")
        sys.exit(1)
    # Try to import server to get its name, but fall back to file name if dependencies
    # missing
    name = server_name
    server = None
    if not name:
        try:
            server = _import_server(file, server_object)
            name = server.name
        except (ImportError, ModuleNotFoundError) as e:
            logger.debug(
                "Could not import server (likely missing dependencies), using file name",
                extra={"error": str(e)},
            )
            name = file.stem
    # Get server dependencies if available (server is None when import failed
    # or a custom --name skipped the import entirely).
    server_dependencies = getattr(server, "dependencies", []) if server else []
    if server_dependencies:
        with_packages = list(set(with_packages + server_dependencies))
    # Process environment variables if provided
    env_dict: dict[str, str] | None = None
    if env_file or env_vars:
        env_dict = {}
        # Load from .env file if specified
        if env_file:
            if dotenv:
                try:
                    # Drop unset (None) values so env_dict stays str -> str.
                    env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None}
                except (OSError, ValueError):
                    logger.exception("Failed to load .env file")
                    sys.exit(1)
            else:
                logger.error("python-dotenv is not installed. Cannot load .env file.")
                sys.exit(1)
        # Add command line environment variables (these override .env entries).
        for env_var in env_vars:
            key, value = _parse_env_var(env_var)
            env_dict[key] = value
    if claude.update_claude_config(
        file_spec,
        name,
        with_editable=with_editable,
        with_packages=with_packages,
        env_vars=env_dict,
    ):
        logger.info(f"Successfully installed {name} in Claude app")
    else:
        logger.error(f"Failed to install {name} in Claude app")
        sys.exit(1)

View File

@@ -0,0 +1,85 @@
import argparse
import logging
import sys
from functools import partial
from urllib.parse import urlparse
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
async def message_handler(
    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Log each incoming server message; exceptions are logged as errors."""
    if isinstance(message, Exception):
        logger.error("Error: %s", message)
    else:
        logger.info("Received message from server: %s", message)
async def run_session(
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
    write_stream: MemoryObjectSendStream[SessionMessage],
    client_info: types.Implementation | None = None,
):
    """Open a ClientSession over the given streams and run its handshake."""
    session = ClientSession(
        read_stream,
        write_stream,
        message_handler=message_handler,
        client_info=client_info,
    )
    async with session:
        logger.info("Initializing session")
        await session.initialize()
        logger.info("Initialized")
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
    """Connect via SSE for http(s) URLs, otherwise spawn a stdio server."""
    if urlparse(command_or_url).scheme in ("http", "https"):
        # Use SSE client for HTTP(S) URLs
        async with sse_client(command_or_url) as streams:
            await run_session(*streams)
        return
    # Use stdio client for commands
    params = StdioServerParameters(command=command_or_url, args=args, env=dict(env))
    async with stdio_client(params) as streams:
        await run_session(*streams)
def cli():
    """Parse command-line arguments and run the client on the trio backend."""
    parser = argparse.ArgumentParser()
    parser.add_argument("command_or_url", help="Command or URL to connect to")
    parser.add_argument("args", nargs="*", help="Additional arguments")
    parser.add_argument(
        "-e",
        "--env",
        nargs=2,
        action="append",
        metavar=("KEY", "VALUE"),
        help="Environment variables to set. Can be used multiple times.",
        default=[],
    )
    parsed = parser.parse_args()
    entry = partial(main, parsed.command_or_url, parsed.args, parsed.env)
    anyio.run(entry, backend="trio")
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,21 @@
"""
OAuth2 Authentication implementation for HTTPX.
Implements authorization code flow with PKCE and automatic token refresh.
"""
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
from mcp.client.auth.oauth2 import (
OAuthClientProvider,
PKCEParameters,
TokenStorage,
)
__all__ = [
"OAuthClientProvider",
"OAuthFlowError",
"OAuthRegistrationError",
"OAuthTokenError",
"PKCEParameters",
"TokenStorage",
]

View File

@@ -0,0 +1,10 @@
class OAuthFlowError(Exception):
    """Base exception for OAuth flow errors."""


class OAuthRegistrationError(OAuthFlowError):
    """Raised when client registration fails."""


class OAuthTokenError(OAuthFlowError):
    """Raised when token operations fail."""

View File

@@ -0,0 +1,487 @@
"""
OAuth client credential extensions for MCP.
Provides OAuth providers for machine-to-machine authentication flows:
- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret
- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication
(typically using a pre-built JWT from workload identity federation)
- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1)
"""
import time
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from uuid import uuid4
import httpx
import jwt
from pydantic import BaseModel, Field
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
class ClientCredentialsOAuthProvider(OAuthClientProvider):
    """OAuth provider for client_credentials grant with client_id + client_secret.

    This provider sets client_info directly, bypassing dynamic client registration.
    Use this when you already have client credentials (client_id and client_secret).

    Example:
        ```python
        provider = ClientCredentialsOAuthProvider(
            server_url="https://api.example.com",
            storage=my_token_storage,
            client_id="my-client-id",
            client_secret="my-client-secret",
        )
        ```
    """

    def __init__(
        self,
        server_url: str,
        storage: TokenStorage,
        client_id: str,
        client_secret: str,
        token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic",
        scopes: str | None = None,
    ) -> None:
        """Initialize client_credentials OAuth provider.

        Args:
            server_url: The MCP server URL.
            storage: Token storage implementation.
            client_id: The OAuth client ID.
            client_secret: The OAuth client secret.
            token_endpoint_auth_method: Authentication method for token endpoint.
                Either "client_secret_basic" (default) or "client_secret_post".
            scopes: Optional space-separated list of scopes to request.
        """
        # Build minimal client_metadata for the base class
        client_metadata = OAuthClientMetadata(
            redirect_uris=None,
            grant_types=["client_credentials"],
            token_endpoint_auth_method=token_endpoint_auth_method,
            scope=scopes,
        )
        # NOTE(review): positional base-class args — the two Nones are presumably
        # the redirect/callback handlers (unused for a non-interactive grant) and
        # 300.0 a timeout; confirm against OAuthClientProvider.__init__.
        super().__init__(server_url, client_metadata, storage, None, None, 300.0)
        # Store client_info to be set during _initialize - no dynamic registration needed
        self._fixed_client_info = OAuthClientInformationFull(
            redirect_uris=None,
            client_id=client_id,
            client_secret=client_secret,
            grant_types=["client_credentials"],
            token_endpoint_auth_method=token_endpoint_auth_method,
            scope=scopes,
        )

    async def _initialize(self) -> None:
        """Load stored tokens and set pre-configured client_info."""
        self.context.current_tokens = await self.context.storage.get_tokens()
        self.context.client_info = self._fixed_client_info
        self._initialized = True

    async def _perform_authorization(self) -> httpx.Request:
        """Perform client_credentials authorization."""
        # Non-interactive grant: skip browser redirects entirely and go
        # straight to the token-endpoint request.
        return await self._exchange_token_client_credentials()

    async def _exchange_token_client_credentials(self) -> httpx.Request:
        """Build token exchange request for client_credentials grant."""
        token_data: dict[str, Any] = {
            "grant_type": "client_credentials",
        }
        headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
        # Use standard auth methods (client_secret_basic, client_secret_post, none)
        token_data, headers = self.context.prepare_token_auth(token_data, headers)
        if self.context.should_include_resource_param(self.context.protocol_version):
            token_data["resource"] = self.context.get_resource_url()
        if self.context.client_metadata.scope:
            token_data["scope"] = self.context.client_metadata.scope
        token_url = self._get_token_endpoint()
        return httpx.Request("POST", token_url, data=token_data, headers=headers)
def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]:
    """Create an assertion provider that returns a static JWT token.

    Use this when you have a pre-built JWT (e.g., from workload identity federation)
    that doesn't need the audience parameter.

    Example:
        ```python
        provider = PrivateKeyJWTOAuthProvider(
            server_url="https://api.example.com",
            storage=my_token_storage,
            client_id="my-client-id",
            assertion_provider=static_assertion_provider(my_prebuilt_jwt),
        )
        ```

    Args:
        token: The pre-built JWT assertion string.

    Returns:
        An async callback suitable for use as an assertion_provider.
    """

    async def _constant(audience: str) -> str:
        # The audience is intentionally ignored: the JWT is pre-built.
        return token

    return _constant
class SignedJWTParameters(BaseModel):
    """Parameters for creating SDK-signed JWT assertions.

    Use `create_assertion_provider()` to create an assertion provider callback
    for use with `PrivateKeyJWTOAuthProvider`.

    Example:
        ```python
        jwt_params = SignedJWTParameters(
            issuer="my-client-id",
            subject="my-client-id",
            signing_key=private_key_pem,
        )
        provider = PrivateKeyJWTOAuthProvider(
            server_url="https://api.example.com",
            storage=my_token_storage,
            client_id="my-client-id",
            assertion_provider=jwt_params.create_assertion_provider(),
        )
        ```
    """

    issuer: str = Field(description="Issuer for JWT assertions (typically client_id).")
    subject: str = Field(description="Subject identifier for JWT assertions (typically client_id).")
    signing_key: str = Field(description="Private key for JWT signing (PEM format).")
    signing_algorithm: str = Field(default="RS256", description="Algorithm for signing JWT assertions.")
    lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
    additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.")

    def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]:
        """Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider.

        Returns:
            An async callback that takes the audience (authorization server issuer URL)
            and returns a signed JWT assertion.
        """

        async def provider(audience: str) -> str:
            # A fresh JWT is signed on every call; iat/exp are computed at
            # call time and jti gets a new UUID so assertions are single-use.
            now = int(time.time())
            claims: dict[str, Any] = {
                "iss": self.issuer,
                "sub": self.subject,
                "aud": audience,
                "exp": now + self.lifetime_seconds,
                "iat": now,
                "jti": str(uuid4()),
            }
            if self.additional_claims:
                # User-supplied claims may override the standard ones above.
                claims.update(self.additional_claims)
            return jwt.encode(claims, self.signing_key, algorithm=self.signing_algorithm)

        return provider
class PrivateKeyJWTOAuthProvider(OAuthClientProvider):
    """OAuth provider for client_credentials grant with private_key_jwt authentication.

    Uses RFC 7523 Section 2.2 for client authentication via JWT assertion.

    The JWT assertion's audience MUST be the authorization server's issuer identifier
    (per RFC 7523bis security updates). The `assertion_provider` callback receives
    this audience value and must return a JWT with that audience.

    **Option 1: Pre-built JWT via Workload Identity Federation**

    In production scenarios, the JWT assertion is typically obtained from a workload
    identity provider (e.g., GCP, AWS IAM, Azure AD):

    ```python
    async def get_workload_identity_token(audience: str) -> str:
        # Fetch JWT from your identity provider
        # The JWT's audience must match the provided audience parameter
        return await fetch_token_from_identity_provider(audience=audience)

    provider = PrivateKeyJWTOAuthProvider(
        server_url="https://api.example.com",
        storage=my_token_storage,
        client_id="my-client-id",
        assertion_provider=get_workload_identity_token,
    )
    ```

    **Option 2: Static pre-built JWT**

    If you have a static JWT that doesn't need the audience parameter:

    ```python
    provider = PrivateKeyJWTOAuthProvider(
        server_url="https://api.example.com",
        storage=my_token_storage,
        client_id="my-client-id",
        assertion_provider=static_assertion_provider(my_prebuilt_jwt),
    )
    ```

    **Option 3: SDK-signed JWT (for testing/simple setups)**

    For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`:

    ```python
    jwt_params = SignedJWTParameters(
        issuer="my-client-id",
        subject="my-client-id",
        signing_key=private_key_pem,
    )
    provider = PrivateKeyJWTOAuthProvider(
        server_url="https://api.example.com",
        storage=my_token_storage,
        client_id="my-client-id",
        assertion_provider=jwt_params.create_assertion_provider(),
    )
    ```
    """

    def __init__(
        self,
        server_url: str,
        storage: TokenStorage,
        client_id: str,
        assertion_provider: Callable[[str], Awaitable[str]],
        scopes: str | None = None,
    ) -> None:
        """Initialize private_key_jwt OAuth provider.

        Args:
            server_url: The MCP server URL.
            storage: Token storage implementation.
            client_id: The OAuth client ID.
            assertion_provider: Async callback that takes the audience (authorization
                server's issuer identifier) and returns a JWT assertion. Use
                `SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs,
                `static_assertion_provider()` for pre-built JWTs, or provide your own
                callback for workload identity federation.
            scopes: Optional space-separated list of scopes to request.
        """
        # Build minimal client_metadata for the base class
        client_metadata = OAuthClientMetadata(
            redirect_uris=None,
            grant_types=["client_credentials"],
            token_endpoint_auth_method="private_key_jwt",
            scope=scopes,
        )
        super().__init__(server_url, client_metadata, storage, None, None, 300.0)
        self._assertion_provider = assertion_provider
        # Store client_info to be set during _initialize - no dynamic registration needed.
        # Note: no client_secret here — authentication happens via the JWT assertion.
        self._fixed_client_info = OAuthClientInformationFull(
            redirect_uris=None,
            client_id=client_id,
            grant_types=["client_credentials"],
            token_endpoint_auth_method="private_key_jwt",
            scope=scopes,
        )

    async def _initialize(self) -> None:
        """Load stored tokens and set pre-configured client_info."""
        self.context.current_tokens = await self.context.storage.get_tokens()
        self.context.client_info = self._fixed_client_info
        self._initialized = True

    async def _perform_authorization(self) -> httpx.Request:
        """Perform client_credentials authorization with private_key_jwt."""
        # Non-interactive grant: skip redirects, build the token request directly.
        return await self._exchange_token_client_credentials()

    async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None:
        """Add JWT assertion for client authentication to token endpoint parameters."""
        if not self.context.oauth_metadata:
            raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow")  # pragma: no cover
        # Audience MUST be the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01
        audience = str(self.context.oauth_metadata.issuer)
        assertion = await self._assertion_provider(audience)
        # RFC 7523 Section 2.2: client authentication via JWT
        token_data["client_assertion"] = assertion
        token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

    async def _exchange_token_client_credentials(self) -> httpx.Request:
        """Build token exchange request for client_credentials grant with private_key_jwt."""
        token_data: dict[str, Any] = {
            "grant_type": "client_credentials",
        }
        headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
        # Add JWT client authentication (RFC 7523 Section 2.2)
        await self._add_client_authentication_jwt(token_data=token_data)
        if self.context.should_include_resource_param(self.context.protocol_version):
            token_data["resource"] = self.context.get_resource_url()
        if self.context.client_metadata.scope:
            token_data["scope"] = self.context.client_metadata.scope
        token_url = self._get_token_endpoint()
        return httpx.Request("POST", token_url, data=token_data, headers=headers)
class JWTParameters(BaseModel):
    """Parameters controlling how a JWT assertion is obtained or signed."""

    assertion: str | None = Field(
        default=None,
        description="JWT assertion for JWT authentication. "
        "Will be used instead of generating a new assertion if provided.",
    )
    issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
    subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
    audience: str | None = Field(default=None, description="Audience for JWT assertions.")
    claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
    jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
    jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
    jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")

    def to_assertion(self, with_audience_fallback: str | None = None) -> str:
        """Return a JWT assertion, preferring a prebuilt one over signing a new one.

        Args:
            with_audience_fallback: Audience to use when no explicit audience
                was configured (typically the authorization server's issuer).

        Raises:
            OAuthFlowError: If a required signing input is missing.
        """
        if self.assertion is not None:
            # A JWT acquired out-of-band takes precedence over local signing.
            return self.assertion
        if not self.jwt_signing_key:
            raise OAuthFlowError("Missing signing key for JWT bearer grant")  # pragma: no cover
        if not self.issuer:
            raise OAuthFlowError("Missing issuer for JWT bearer grant")  # pragma: no cover
        if not self.subject:
            raise OAuthFlowError("Missing subject for JWT bearer grant")  # pragma: no cover
        audience = self.audience or with_audience_fallback
        if not audience:
            raise OAuthFlowError("Missing audience for JWT bearer grant")  # pragma: no cover

        issued_at = int(time.time())
        payload: dict[str, Any] = {
            "iss": self.issuer,
            "sub": self.subject,
            "aud": audience,
            "exp": issued_at + self.jwt_lifetime_seconds,
            "iat": issued_at,
            "jti": str(uuid4()),
        }
        # Caller-supplied claims may extend or override the registered ones.
        payload.update(self.claims or {})
        return jwt.encode(
            payload,
            self.jwt_signing_key,
            algorithm=self.jwt_signing_algorithm or "RS256",
        )
class RFC7523OAuthClientProvider(OAuthClientProvider):
    """OAuth client provider for RFC 7523 jwt-bearer grant.

    .. deprecated::
        Use :class:`ClientCredentialsOAuthProvider` for client_credentials with
        client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for
        client_credentials with private_key_jwt authentication instead.

    This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1)
    where the JWT itself is the authorization grant.
    """

    def __init__(
        self,
        server_url: str,
        client_metadata: OAuthClientMetadata,
        storage: TokenStorage,
        redirect_handler: Callable[[str], Awaitable[None]] | None = None,
        callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
        timeout: float = 300.0,
        jwt_parameters: JWTParameters | None = None,
    ) -> None:
        """Initialize the deprecated RFC 7523 provider.

        Args:
            server_url: The MCP server URL.
            client_metadata: OAuth client metadata for registration.
            storage: Token storage implementation.
            redirect_handler: Handler for authorization redirects (interactive flows only).
            callback_handler: Handler for authorization callbacks (interactive flows only).
            timeout: Timeout for the OAuth flow in seconds.
            jwt_parameters: Parameters used to build JWT assertions for the
                jwt-bearer grant and/or private_key_jwt client authentication.
        """
        # Local import keeps the deprecation machinery off the module import path.
        import warnings

        warnings.warn(
            "RFC7523OAuthClientProvider is deprecated. Use ClientCredentialsOAuthProvider "
            "or PrivateKeyJWTOAuthProvider instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
        self.jwt_parameters = jwt_parameters

    async def _exchange_token_authorization_code(
        self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
    ) -> httpx.Request:  # pragma: no cover
        """Build token exchange request for authorization_code flow.

        Adds JWT client authentication when the client registered with
        private_key_jwt, then defers to the base-class implementation.
        """
        token_data = token_data or {}
        if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
            # Synchronous helper (unlike PrivateKeyJWTOAuthProvider's async one).
            self._add_client_authentication_jwt(token_data=token_data)
        return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)

    async def _perform_authorization(self) -> httpx.Request:  # pragma: no cover
        """Perform the authorization flow.

        The jwt-bearer grant has no interactive step: the JWT assertion itself
        authorizes the token request, so we go straight to the token endpoint.
        """
        if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
            token_request = await self._exchange_token_jwt_bearer()
            return token_request
        else:
            return await super()._perform_authorization()

    def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):  # pragma: no cover
        """Add JWT assertion for client authentication to token endpoint parameters.

        Raises:
            OAuthTokenError: If JWT parameters or OAuth metadata are missing.
        """
        if not self.jwt_parameters:
            raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
        if not self.context.oauth_metadata:
            raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
        # We need to set the audience to the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
        issuer = str(self.context.oauth_metadata.issuer)
        assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
        # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
        token_data["client_assertion"] = assertion
        token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
        # We need to set the audience to the resource server; this audience is different
        # from the one inside the JWT claims — it identifies the resource server that
        # will validate the issued token.
        token_data["audience"] = self.context.get_resource_url()

    async def _exchange_token_jwt_bearer(self) -> httpx.Request:
        """Build token exchange request for JWT bearer grant (RFC 7523 Section 2.1).

        Raises:
            OAuthFlowError: If client info or JWT parameters are missing.
            OAuthTokenError: If OAuth metadata is missing.
        """
        if not self.context.client_info:
            raise OAuthFlowError("Missing client info")  # pragma: no cover
        if not self.jwt_parameters:
            raise OAuthFlowError("Missing JWT parameters")  # pragma: no cover
        if not self.context.oauth_metadata:
            raise OAuthTokenError("Missing OAuth metadata")  # pragma: no cover
        # We need to set the audience to the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
        issuer = str(self.context.oauth_metadata.issuer)
        assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
        # The assertion IS the grant here (not merely client authentication).
        token_data = {
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "assertion": assertion,
        }
        if self.context.should_include_resource_param(self.context.protocol_version):  # pragma: no branch
            token_data["resource"] = self.context.get_resource_url()
        if self.context.client_metadata.scope:  # pragma: no branch
            token_data["scope"] = self.context.client_metadata.scope
        token_url = self._get_token_endpoint()
        return httpx.Request(
            "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
        )

View File

@@ -0,0 +1,616 @@
"""
OAuth2 Authentication implementation for HTTPX.
Implements authorization code flow with PKCE and automatic token refresh.
"""
import base64
import hashlib
import logging
import secrets
import string
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
import anyio
import httpx
from pydantic import BaseModel, Field, ValidationError
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
create_client_info_from_metadata_url,
create_client_registration_request,
create_oauth_metadata_request,
extract_field_from_www_auth,
extract_resource_metadata_from_www_auth,
extract_scope_from_www_auth,
get_client_metadata_scopes,
handle_auth_metadata_response,
handle_protected_resource_response,
handle_registration_response,
handle_token_response_scopes,
is_valid_client_metadata_url,
should_use_client_metadata_url,
)
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.shared.auth_utils import (
calculate_token_expiry,
check_resource_allowed,
resource_url_from_server_url,
)
logger = logging.getLogger(__name__)
class PKCEParameters(BaseModel):
    """PKCE (RFC 7636) verifier/challenge pair for the authorization code flow."""

    code_verifier: str = Field(..., min_length=43, max_length=128)
    code_challenge: str = Field(..., min_length=43, max_length=128)

    @classmethod
    def generate(cls) -> "PKCEParameters":
        """Create a fresh random verifier and its S256 challenge."""
        # Unreserved characters permitted by RFC 7636 Section 4.1.
        alphabet = string.ascii_letters + string.digits + "-._~"
        verifier = "".join(secrets.choice(alphabet) for _ in range(128))
        # S256 method: BASE64URL(SHA256(verifier)) with padding stripped.
        challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=")
        return cls(code_verifier=verifier, code_challenge=challenge)
class TokenStorage(Protocol):
    """Protocol for token storage implementations.

    Implementations persist OAuth tokens and client registration info between
    runs (e.g. in memory, on disk, or in a system keychain) so the client does
    not have to re-authenticate or re-register every time.
    """

    async def get_tokens(self) -> OAuthToken | None:
        """Get stored tokens, or None if none have been stored yet."""
        ...

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Store tokens, replacing any previously stored ones."""
        ...

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Get stored client information, or None if the client is not registered."""
        ...

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store client information obtained from registration."""
        ...
@dataclass
class OAuthContext:
    """Mutable state accumulated across one client's OAuth flow.

    Holds the static configuration plus everything the httpx auth hook
    discovers along the way: server metadata, registered client info, and
    the currently held tokens.
    """

    # Static configuration supplied at construction time.
    server_url: str
    client_metadata: OAuthClientMetadata
    storage: TokenStorage
    # Interactive handlers; None for non-interactive (e.g. client_credentials) flows.
    redirect_handler: Callable[[str], Awaitable[None]] | None
    callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
    timeout: float = 300.0
    client_metadata_url: str | None = None

    # Discovered metadata
    protected_resource_metadata: ProtectedResourceMetadata | None = None
    oauth_metadata: OAuthMetadata | None = None
    auth_server_url: str | None = None
    protocol_version: str | None = None

    # Client registration
    client_info: OAuthClientInformationFull | None = None

    # Token management
    current_tokens: OAuthToken | None = None
    token_expiry_time: float | None = None

    # State
    lock: anyio.Lock = field(default_factory=anyio.Lock)

    def get_authorization_base_url(self, server_url: str) -> str:
        """Extract base URL by removing path component (scheme://host[:port])."""
        parsed = urlparse(server_url)
        return f"{parsed.scheme}://{parsed.netloc}"

    def update_token_expiry(self, token: OAuthToken) -> None:
        """Update token expiry time using shared util function."""
        self.token_expiry_time = calculate_token_expiry(token.expires_in)

    def is_token_valid(self) -> bool:
        """Check if current token is valid (present and not past its expiry).

        A token with no recorded expiry time is treated as still valid.
        """
        return bool(
            self.current_tokens
            and self.current_tokens.access_token
            and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
        )

    def can_refresh_token(self) -> bool:
        """Check if token can be refreshed (refresh token and client info are both on hand)."""
        return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)

    def clear_tokens(self) -> None:
        """Clear current tokens, forcing re-authentication on the next request."""
        self.current_tokens = None
        self.token_expiry_time = None

    def get_resource_url(self) -> str:
        """Get resource URL for RFC 8707.

        Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
        """
        resource = resource_url_from_server_url(self.server_url)

        # If PRM provides a resource that's a valid parent, use it
        if self.protected_resource_metadata and self.protected_resource_metadata.resource:
            prm_resource = str(self.protected_resource_metadata.resource)
            if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
                resource = prm_resource

        return resource

    def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
        """Determine if the resource parameter should be included in OAuth requests.

        Returns True if:
        - Protected resource metadata is available, OR
        - MCP-Protocol-Version header is 2025-06-18 or later
        """
        # If we have protected resource metadata, include the resource param
        if self.protected_resource_metadata is not None:
            return True

        # If no protocol version provided, don't include resource param
        if not protocol_version:
            return False

        # Check if protocol version is 2025-06-18 or later
        # Version format is YYYY-MM-DD, so string comparison works
        return protocol_version >= "2025-06-18"

    def prepare_token_auth(
        self, data: dict[str, str], headers: dict[str, str] | None = None
    ) -> tuple[dict[str, str], dict[str, str]]:
        """Prepare authentication for token requests.

        Applies the client's registered token_endpoint_auth_method:
        ``client_secret_basic`` moves the secret into an Authorization header,
        ``client_secret_post`` keeps it in the form body, and ``none`` sends
        no secret at all.

        Args:
            data: The form data to send
            headers: Optional headers dict to update

        Returns:
            Tuple of (updated_data, updated_headers)
        """
        if headers is None:
            headers = {}  # pragma: no cover

        if not self.client_info:
            return data, headers  # pragma: no cover

        auth_method = self.client_info.token_endpoint_auth_method

        if auth_method == "client_secret_basic" and self.client_info.client_id and self.client_info.client_secret:
            # URL-encode client ID and secret per RFC 6749 Section 2.3.1
            encoded_id = quote(self.client_info.client_id, safe="")
            encoded_secret = quote(self.client_info.client_secret, safe="")
            credentials = f"{encoded_id}:{encoded_secret}"
            encoded_credentials = base64.b64encode(credentials.encode()).decode()
            headers["Authorization"] = f"Basic {encoded_credentials}"
            # Don't include client_secret in body for basic auth
            data = {k: v for k, v in data.items() if k != "client_secret"}
        elif auth_method == "client_secret_post" and self.client_info.client_secret:
            # Include client_secret in request body
            data["client_secret"] = self.client_info.client_secret
        # For auth_method == "none", don't add any client_secret

        return data, headers
class OAuthClientProvider(httpx.Auth):
    """
    OAuth2 authentication for httpx.
    Handles OAuth flow with automatic client registration and token storage.
    """

    requires_response_body = True

    def __init__(
        self,
        server_url: str,
        client_metadata: OAuthClientMetadata,
        storage: TokenStorage,
        redirect_handler: Callable[[str], Awaitable[None]] | None = None,
        callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
        timeout: float = 300.0,
        client_metadata_url: str | None = None,
    ):
        """Initialize OAuth2 authentication.

        Args:
            server_url: The MCP server URL.
            client_metadata: OAuth client metadata for registration.
            storage: Token storage implementation.
            redirect_handler: Handler for authorization redirects.
            callback_handler: Handler for authorization callbacks.
            timeout: Timeout for the OAuth flow.
            client_metadata_url: URL-based client ID. When provided and the server
                advertises client_id_metadata_document_supported=true, this URL will be
                used as the client_id instead of performing dynamic client registration.
                Must be a valid HTTPS URL with a non-root pathname.

        Raises:
            ValueError: If client_metadata_url is provided but not a valid HTTPS URL
                with a non-root pathname.
        """
        # Validate client_metadata_url if provided
        if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url):
            raise ValueError(
                f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}"
            )

        self.context = OAuthContext(
            server_url=server_url,
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=redirect_handler,
            callback_handler=callback_handler,
            timeout=timeout,
            client_metadata_url=client_metadata_url,
        )
        self._initialized = False

    async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
        """
        Handle protected resource metadata discovery response.

        Per SEP-985, supports fallback when discovery fails at one URL.

        Returns:
            True if metadata was successfully discovered, False if we should try next URL
        """
        if response.status_code == 200:
            try:
                content = await response.aread()
                metadata = ProtectedResourceMetadata.model_validate_json(content)
                self.context.protected_resource_metadata = metadata
                if metadata.authorization_servers:  # pragma: no branch
                    self.context.auth_server_url = str(metadata.authorization_servers[0])
                return True
            except ValidationError:  # pragma: no cover
                # Invalid metadata - try next URL
                logger.warning(f"Invalid protected resource metadata at {response.request.url}")
                return False
        elif response.status_code == 404:  # pragma: no cover
            # Not found - try next URL in fallback chain
            logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
            return False
        else:
            # Other error - fail immediately
            raise OAuthFlowError(
                f"Protected Resource Metadata request failed: {response.status_code}"
            )  # pragma: no cover

    async def _perform_authorization(self) -> httpx.Request:
        """Perform the authorization flow."""
        auth_code, code_verifier = await self._perform_authorization_code_grant()
        token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
        return token_request

    async def _perform_authorization_code_grant(self) -> tuple[str, str]:
        """Perform the authorization redirect and get auth code."""
        if self.context.client_metadata.redirect_uris is None:
            raise OAuthFlowError("No redirect URIs provided for authorization code grant")  # pragma: no cover
        if not self.context.redirect_handler:
            raise OAuthFlowError("No redirect handler provided for authorization code grant")  # pragma: no cover
        if not self.context.callback_handler:
            raise OAuthFlowError("No callback handler provided for authorization code grant")  # pragma: no cover

        if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
            auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)  # pragma: no cover
        else:
            # No discovered metadata: fall back to the conventional /authorize path.
            auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
            auth_endpoint = urljoin(auth_base_url, "/authorize")

        if not self.context.client_info:
            raise OAuthFlowError("No client info available for authorization")  # pragma: no cover

        # Generate PKCE parameters
        pkce_params = PKCEParameters.generate()
        state = secrets.token_urlsafe(32)

        auth_params = {
            "response_type": "code",
            "client_id": self.context.client_info.client_id,
            "redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
            "state": state,
            "code_challenge": pkce_params.code_challenge,
            "code_challenge_method": "S256",
        }

        # Only include resource param if conditions are met
        if self.context.should_include_resource_param(self.context.protocol_version):
            auth_params["resource"] = self.context.get_resource_url()  # RFC 8707  # pragma: no cover

        if self.context.client_metadata.scope:  # pragma: no branch
            auth_params["scope"] = self.context.client_metadata.scope

        authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
        await self.context.redirect_handler(authorization_url)

        # Wait for callback
        auth_code, returned_state = await self.context.callback_handler()

        # Constant-time comparison guards the state check against timing attacks.
        if returned_state is None or not secrets.compare_digest(returned_state, state):
            raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}")  # pragma: no cover

        if not auth_code:
            raise OAuthFlowError("No authorization code received")  # pragma: no cover

        # Return auth code and code verifier for token exchange
        return auth_code, pkce_params.code_verifier

    def _get_token_endpoint(self) -> str:
        """Return the token endpoint URL, falling back to the conventional /token path."""
        if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
            token_url = str(self.context.oauth_metadata.token_endpoint)
        else:
            auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
            token_url = urljoin(auth_base_url, "/token")
        return token_url

    async def _exchange_token_authorization_code(
        self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
    ) -> httpx.Request:
        """Build token exchange request for authorization_code flow.

        Args:
            auth_code: Authorization code received from the callback.
            code_verifier: PKCE code verifier matching the issued challenge.
            token_data: Optional extra form parameters to merge into the request.
                Defaults to None (NOT a mutable {} default) so that parameters
                added for one exchange can never leak into later calls.
        """
        if self.context.client_metadata.redirect_uris is None:
            raise OAuthFlowError("No redirect URIs provided for authorization code grant")  # pragma: no cover
        if not self.context.client_info:
            raise OAuthFlowError("Missing client info")  # pragma: no cover

        token_url = self._get_token_endpoint()
        token_data = token_data or {}
        token_data.update(
            {
                "grant_type": "authorization_code",
                "code": auth_code,
                "redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
                "client_id": self.context.client_info.client_id,
                "code_verifier": code_verifier,
            }
        )

        # Only include resource param if conditions are met
        if self.context.should_include_resource_param(self.context.protocol_version):
            token_data["resource"] = self.context.get_resource_url()  # RFC 8707

        # Prepare authentication based on preferred method
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        token_data, headers = self.context.prepare_token_auth(token_data, headers)

        return httpx.Request("POST", token_url, data=token_data, headers=headers)

    async def _handle_token_response(self, response: httpx.Response) -> None:
        """Handle token exchange response, storing tokens on success."""
        if response.status_code != 200:
            body = await response.aread()  # pragma: no cover
            body_text = body.decode("utf-8")  # pragma: no cover
            raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}")  # pragma: no cover

        # Parse and validate response with scope validation
        token_response = await handle_token_response_scopes(response)

        # Store tokens in context
        self.context.current_tokens = token_response
        self.context.update_token_expiry(token_response)
        await self.context.storage.set_tokens(token_response)

    async def _refresh_token(self) -> httpx.Request:
        """Build token refresh request."""
        if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
            raise OAuthTokenError("No refresh token available")  # pragma: no cover

        if not self.context.client_info or not self.context.client_info.client_id:
            raise OAuthTokenError("No client info available")  # pragma: no cover

        if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
            token_url = str(self.context.oauth_metadata.token_endpoint)  # pragma: no cover
        else:
            auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
            token_url = urljoin(auth_base_url, "/token")

        refresh_data: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": self.context.current_tokens.refresh_token,
            "client_id": self.context.client_info.client_id,
        }

        # Only include resource param if conditions are met
        if self.context.should_include_resource_param(self.context.protocol_version):
            refresh_data["resource"] = self.context.get_resource_url()  # RFC 8707

        # Prepare authentication based on preferred method
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

        return httpx.Request("POST", token_url, data=refresh_data, headers=headers)

    async def _handle_refresh_response(self, response: httpx.Response) -> bool:  # pragma: no cover
        """Handle token refresh response. Returns True if successful."""
        if response.status_code != 200:
            logger.warning(f"Token refresh failed: {response.status_code}")
            self.context.clear_tokens()
            return False

        try:
            content = await response.aread()
            token_response = OAuthToken.model_validate_json(content)

            self.context.current_tokens = token_response
            self.context.update_token_expiry(token_response)
            await self.context.storage.set_tokens(token_response)

            return True
        except ValidationError:
            logger.exception("Invalid refresh response")
            self.context.clear_tokens()
            return False

    async def _initialize(self) -> None:  # pragma: no cover
        """Load stored tokens and client info."""
        self.context.current_tokens = await self.context.storage.get_tokens()
        self.context.client_info = await self.context.storage.get_client_info()
        self._initialized = True

    def _add_auth_header(self, request: httpx.Request) -> None:
        """Add authorization header to request if we have valid tokens."""
        if self.context.current_tokens and self.context.current_tokens.access_token:  # pragma: no branch
            request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

    async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
        """Parse and store OAuth authorization server metadata from a discovery response."""
        content = await response.aread()
        metadata = OAuthMetadata.model_validate_json(content)
        self.context.oauth_metadata = metadata

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """HTTPX auth flow integration."""
        async with self.context.lock:
            if not self._initialized:
                await self._initialize()  # pragma: no cover

            # Capture protocol version from request headers
            self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

            if not self.context.is_token_valid() and self.context.can_refresh_token():
                # Try to refresh token
                refresh_request = await self._refresh_token()  # pragma: no cover
                refresh_response = yield refresh_request  # pragma: no cover
                if not await self._handle_refresh_response(refresh_response):  # pragma: no cover
                    # Refresh failed, need full re-authentication
                    self._initialized = False

            if self.context.is_token_valid():
                self._add_auth_header(request)

            response = yield request

            if response.status_code == 401:
                # Perform full OAuth flow
                try:
                    # OAuth flow must be inline due to generator constraints
                    www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)

                    # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
                    prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
                        www_auth_resource_metadata_url, self.context.server_url
                    )

                    for url in prm_discovery_urls:  # pragma: no branch
                        discovery_request = create_oauth_metadata_request(url)
                        discovery_response = yield discovery_request  # sending request

                        prm = await handle_protected_resource_response(discovery_response)
                        if prm:
                            self.context.protected_resource_metadata = prm
                            # todo: try all authorization_servers to find the OASM
                            assert (
                                len(prm.authorization_servers) > 0
                            )  # this is always true as authorization_servers has a min length of 1
                            self.context.auth_server_url = str(prm.authorization_servers[0])
                            break
                        else:
                            logger.debug(f"Protected resource metadata discovery failed: {url}")

                    asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
                        self.context.auth_server_url, self.context.server_url
                    )
                    # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
                    for url in asm_discovery_urls:  # pragma: no cover
                        oauth_metadata_request = create_oauth_metadata_request(url)
                        oauth_metadata_response = yield oauth_metadata_request
                        ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
                        if not ok:
                            break
                        if ok and asm:
                            self.context.oauth_metadata = asm
                            break
                        else:
                            logger.debug(f"OAuth metadata discovery failed: {url}")

                    # Step 3: Apply scope selection strategy
                    self.context.client_metadata.scope = get_client_metadata_scopes(
                        extract_scope_from_www_auth(response),
                        self.context.protected_resource_metadata,
                        self.context.oauth_metadata,
                    )

                    # Step 4: Register client or use URL-based client ID (CIMD)
                    if not self.context.client_info:
                        if should_use_client_metadata_url(
                            self.context.oauth_metadata, self.context.client_metadata_url
                        ):
                            # Use URL-based client ID (CIMD)
                            logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
                            client_information = create_client_info_from_metadata_url(
                                self.context.client_metadata_url,  # type: ignore[arg-type]
                                redirect_uris=self.context.client_metadata.redirect_uris,
                            )
                            self.context.client_info = client_information
                            await self.context.storage.set_client_info(client_information)
                        else:
                            # Fallback to Dynamic Client Registration
                            registration_request = create_client_registration_request(
                                self.context.oauth_metadata,
                                self.context.client_metadata,
                                self.context.get_authorization_base_url(self.context.server_url),
                            )
                            registration_response = yield registration_request
                            client_information = await handle_registration_response(registration_response)
                            self.context.client_info = client_information
                            await self.context.storage.set_client_info(client_information)

                    # Step 5: Perform authorization and complete token exchange
                    token_response = yield await self._perform_authorization()
                    await self._handle_token_response(token_response)
                except Exception:  # pragma: no cover
                    logger.exception("OAuth flow error")
                    raise

                # Retry with new tokens
                self._add_auth_header(request)
                yield request
            elif response.status_code == 403:
                # Step 1: Extract error field from WWW-Authenticate header
                error = extract_field_from_www_auth(response, "error")

                # Step 2: Check if we need to step-up authorization
                if error == "insufficient_scope":  # pragma: no branch
                    try:
                        # Step 2a: Update the required scopes
                        self.context.client_metadata.scope = get_client_metadata_scopes(
                            extract_scope_from_www_auth(response), self.context.protected_resource_metadata
                        )

                        # Step 2b: Perform (re-)authorization and token exchange
                        token_response = yield await self._perform_authorization()
                        await self._handle_token_response(token_response)
                    except Exception:  # pragma: no cover
                        logger.exception("OAuth flow error")
                        raise

                    # Retry with new tokens
                    self._add_auth_header(request)
                    yield request

View File

@@ -0,0 +1,336 @@
import logging
import re
from urllib.parse import urljoin, urlparse
from httpx import Request, Response
from pydantic import AnyUrl, ValidationError
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.types import LATEST_PROTOCOL_VERSION
logger = logging.getLogger(__name__)
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
    """
    Pull one named parameter's value out of the WWW-Authenticate header.

    Returns:
        The parameter value (unquoted) if present, None otherwise
    """
    header = response.headers.get("WWW-Authenticate")
    if not header:
        return None
    # Accept both quoted (name="value") and bare (name=value) parameter forms.
    match = re.search(rf'{field_name}=(?:"([^"]+)"|([^\s,]+))', header)
    if not match:
        return None
    quoted_value, bare_value = match.group(1), match.group(2)
    return quoted_value or bare_value
def extract_scope_from_www_auth(response: Response) -> str | None:
    """
    Read the RFC 6750 ``scope`` parameter from the WWW-Authenticate header.

    Returns:
        The scope string, or None when the header or parameter is absent
    """
    return extract_field_from_www_auth(response, "scope")
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
    """
    Read the RFC 9728 ``resource_metadata`` URL from the WWW-Authenticate header.

    Only a 401 response can legitimately carry this hint, so anything else
    yields None.

    Returns:
        Resource metadata URL if found in WWW-Authenticate header, None otherwise
    """
    is_unauthorized = bool(response) and response.status_code == 401
    if not is_unauthorized:
        return None  # pragma: no cover
    return extract_field_from_www_auth(response, "resource_metadata")
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
    """
    Build ordered list of URLs to try for protected resource metadata discovery.

    Per SEP-985, the client MUST:
    1. Try resource_metadata from WWW-Authenticate header (if present)
    2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
    3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource

    Args:
        www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
        server_url: server url

    Returns:
        Ordered list of URLs to try for discovery
    """
    candidates: list[str] = []

    # Highest priority: the URL the server itself advertised in the challenge.
    if www_auth_url:
        candidates.append(www_auth_url)

    parts = urlparse(server_url)
    origin = f"{parts.scheme}://{parts.netloc}"

    # Path-based well-known URI, but only when the server URL has a real path.
    if parts.path not in ("", "/"):
        candidates.append(urljoin(origin, f"/.well-known/oauth-protected-resource{parts.path}"))

    # Root-based well-known URI is always the last resort.
    candidates.append(urljoin(origin, "/.well-known/oauth-protected-resource"))
    return candidates
def get_client_metadata_scopes(
    www_authenticate_scope: str | None,
    protected_resource_metadata: ProtectedResourceMetadata | None,
    authorization_server_metadata: OAuthMetadata | None = None,
) -> str | None:
    """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.

    Priority order: the WWW-Authenticate header scope first, then the
    protected resource metadata's scopes_supported, then the authorization
    server metadata's scopes_supported; otherwise return None so the scope
    parameter is omitted entirely.
    """
    if www_authenticate_scope is not None:
        return www_authenticate_scope

    # Fall back to metadata-advertised scopes, PRM before AS metadata.
    for metadata in (protected_resource_metadata, authorization_server_metadata):
        if metadata is not None and metadata.scopes_supported is not None:
            return " ".join(metadata.scopes_supported)

    return None
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
    """
    Generate an ordered list of authorization server metadata discovery URLs.

    Args:
        auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
        server_url: URL for the MCP server, used as a fallback if auth_server_url is None

    Returns:
        Ordered list of metadata URL strings to try: RFC 8414 path-aware
        OAuth/OIDC variants when the issuer has a path component, otherwise
        the root-based OAuth and OIDC well-known URIs.
    """
    if not auth_server_url:
        # Legacy path using the 2025-03-26 spec:
        # link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
        parsed = urlparse(server_url)
        return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]

    urls: list[str] = []
    parsed = urlparse(auth_server_url)
    base_url = f"{parsed.scheme}://{parsed.netloc}"

    # RFC 8414: Path-aware OAuth discovery
    if parsed.path and parsed.path != "/":
        oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
        urls.append(urljoin(base_url, oauth_path))
        # RFC 8414 section 5: Path-aware OIDC discovery
        # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
        oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
        urls.append(urljoin(base_url, oidc_path))
        # OIDC 1.0 discovery (appends to the issuer path per the OIDC spec)
        # https://openid.net/specs/openid-connect-discovery-1_0.html
        oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
        urls.append(urljoin(base_url, oidc_path))
        return urls

    # OAuth root
    urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
    # OIDC 1.0 fallback (appends to full URL per OIDC spec)
    # https://openid.net/specs/openid-connect-discovery-1_0.html
    urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
    return urls
async def handle_protected_resource_response(
    response: Response,
) -> ProtectedResourceMetadata | None:
    """
    Handle protected resource metadata discovery response.

    Per SEP-985, supports fallback when discovery fails at one URL.

    Returns:
        Parsed ProtectedResourceMetadata on success, or None when this URL
        failed (non-200 status or invalid body) and the caller should try
        the next URL in the fallback chain.
    """
    if response.status_code != 200:
        # Not found - try next URL in fallback chain
        return None
    try:
        content = await response.aread()
        return ProtectedResourceMetadata.model_validate_json(content)
    except ValidationError:  # pragma: no cover
        # Invalid metadata - try next URL
        return None
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
    """Parse an authorization server metadata discovery response.

    Returns:
        Tuple of (keep_trying, metadata). ``metadata`` is the parsed
        OAuthMetadata on a valid 200 response (None otherwise);
        ``keep_trying`` is False only for non-4XX failures, telling the
        caller to stop walking the fallback URL list.
    """
    if response.status_code == 200:
        try:
            content = await response.aread()
            asm = OAuthMetadata.model_validate_json(content)
            return True, asm
        except ValidationError:  # pragma: no cover
            # Body did not validate; treat as a miss but keep trying other URLs.
            return True, None
    elif response.status_code < 400 or response.status_code >= 500:
        return False, None  # Non-4XX error, stop trying
    return True, None
def create_oauth_metadata_request(url: str) -> Request:
    """Build a GET request for a metadata URL, advertising the MCP protocol version header."""
    return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
def create_client_registration_request(
    auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
) -> Request:
    """Build a dynamic client registration POST request.

    Prefers the registration endpoint advertised in the authorization server
    metadata and falls back to ``/register`` under the auth base URL.
    """
    endpoint = (
        str(auth_server_metadata.registration_endpoint)
        if auth_server_metadata and auth_server_metadata.registration_endpoint
        else urljoin(auth_base_url, "/register")
    )
    payload = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
    return Request("POST", endpoint, json=payload, headers={"Content-Type": "application/json"})
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
    """Parse a dynamic client registration response.

    Returns:
        Validated OAuthClientInformationFull from the response body.

    Raises:
        OAuthRegistrationError: If the server rejected the registration or
            the body does not validate as client information.
    """
    if response.status_code not in (200, 201):
        # Read the body so response.text is populated for the error message.
        await response.aread()
        raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
    try:
        content = await response.aread()
        return OAuthClientInformationFull.model_validate_json(content)
    except ValidationError as e:  # pragma: no cover
        # Chain the pydantic error so the original validation detail survives.
        raise OAuthRegistrationError(f"Invalid registration response: {e}") from e
def is_valid_client_metadata_url(url: str | None) -> bool:
    """Validate that a URL is suitable for use as a client_id (CIMD).

    A usable CIMD URL must be HTTPS and carry a non-root path.

    Args:
        url: The URL to validate

    Returns:
        True if the URL is a valid HTTPS URL with a non-root pathname
    """
    if not url:
        return False
    try:
        parts = urlparse(url)
    except Exception:
        return False
    has_real_path = parts.path not in ("", "/")
    return parts.scheme == "https" and has_real_path
def should_use_client_metadata_url(
    oauth_metadata: OAuthMetadata | None,
    client_metadata_url: str | None,
) -> bool:
    """Decide between URL-based client IDs (CIMD) and dynamic registration (DCR).

    CIMD applies only when the client has a metadata URL configured AND the
    server metadata advertises client_id_metadata_document_supported=True.

    Args:
        oauth_metadata: OAuth authorization server metadata
        client_metadata_url: URL-based client ID (already validated)

    Returns:
        True if CIMD should be used, False if DCR should be used
    """
    if client_metadata_url and oauth_metadata:
        return oauth_metadata.client_id_metadata_document_supported is True
    return False
def create_client_info_from_metadata_url(
    client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None
) -> OAuthClientInformationFull:
    """Build client information for a URL-based client ID (CIMD).

    The metadata URL itself serves as the client_id. No client_secret is
    issued, so the token endpoint auth method is "none".

    Args:
        client_metadata_url: The URL to use as the client_id
        redirect_uris: The redirect URIs from the client metadata (passed through for
            compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata)

    Returns:
        OAuthClientInformationFull with the URL as client_id
    """
    client_info = OAuthClientInformationFull(
        client_id=client_metadata_url,
        token_endpoint_auth_method="none",
        redirect_uris=redirect_uris,
    )
    return client_info
async def handle_token_response_scopes(
    response: Response,
) -> OAuthToken:
    """Parse and validate token response with optional scope validation.

    Parses token response JSON. Callers should check response.status_code before calling.

    Args:
        response: HTTP response from token endpoint (status already checked by caller)

    Returns:
        Validated OAuthToken model

    Raises:
        OAuthTokenError: If response JSON is invalid
    """
    try:
        content = await response.aread()
        return OAuthToken.model_validate_json(content)
    except ValidationError as e:  # pragma: no cover
        # Chain the pydantic error so the original validation detail survives.
        raise OAuthTokenError(f"Invalid token response: {e}") from e

View File

@@ -0,0 +1,9 @@
"""
Experimental client features.
WARNING: These APIs are experimental and may change without notice.
"""
from mcp.client.experimental.tasks import ExperimentalClientFeatures
__all__ = ["ExperimentalClientFeatures"]

View File

@@ -0,0 +1,290 @@
"""
Experimental task handler protocols for server -> client requests.
This module provides Protocol types and default handlers for when servers
send task-related requests to clients (the reverse of normal client -> server flow).
WARNING: These APIs are experimental and may change without notice.
Use cases:
- Server sends task-augmented sampling/elicitation request to client
- Client creates a local task, spawns background work, returns CreateTaskResult
- Server polls client's task status via tasks/get, tasks/result, etc.
"""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol
from pydantic import TypeAdapter
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
if TYPE_CHECKING:
from mcp.client.session import ClientSession
class GetTaskHandlerFnT(Protocol):
    """Handler for tasks/get requests from server.

    Returns the task's current state as a GetTaskResult, or ErrorData on failure.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.GetTaskRequestParams,
    ) -> types.GetTaskResult | types.ErrorData: ...  # pragma: no branch
class GetTaskResultHandlerFnT(Protocol):
    """Handler for tasks/result requests from server.

    Returns the stored payload of a task as a GetTaskPayloadResult, or ErrorData on failure.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.GetTaskPayloadRequestParams,
    ) -> types.GetTaskPayloadResult | types.ErrorData: ...  # pragma: no branch
class ListTasksHandlerFnT(Protocol):
    """Handler for tasks/list requests from server.

    Receives optional pagination params; returns a ListTasksResult page, or ErrorData on failure.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.PaginatedRequestParams | None,
    ) -> types.ListTasksResult | types.ErrorData: ...  # pragma: no branch
class CancelTaskHandlerFnT(Protocol):
    """Handler for tasks/cancel requests from server.

    Returns the cancelled task's state as a CancelTaskResult, or ErrorData on failure.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.CancelTaskRequestParams,
    ) -> types.CancelTaskResult | types.ErrorData: ...  # pragma: no branch
class TaskAugmentedSamplingFnT(Protocol):
    """Handler for task-augmented sampling/createMessage requests from server.

    When server sends a CreateMessageRequest with task field, this callback
    is invoked. The callback should create a task, spawn background work,
    and return CreateTaskResult immediately.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.CreateMessageRequestParams,
        task_metadata: types.TaskMetadata,
    ) -> types.CreateTaskResult | types.ErrorData: ...  # pragma: no branch
class TaskAugmentedElicitationFnT(Protocol):
    """Handler for task-augmented elicitation/create requests from server.

    When server sends an ElicitRequest with task field, this callback
    is invoked. The callback should create a task, spawn background work,
    and return CreateTaskResult immediately.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.ElicitRequestParams,
        task_metadata: types.TaskMetadata,
    ) -> types.CreateTaskResult | types.ErrorData: ...  # pragma: no branch
async def default_get_task_handler(
    context: RequestContext["ClientSession", Any],
    params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData:
    """Default tasks/get handler: reports the method as unsupported."""
    return types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/get not supported",
    )
async def default_get_task_result_handler(
    context: RequestContext["ClientSession", Any],
    params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData:
    """Default tasks/result handler: reports the method as unsupported."""
    return types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/result not supported",
    )
async def default_list_tasks_handler(
    context: RequestContext["ClientSession", Any],
    params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData:
    """Default tasks/list handler: reports the method as unsupported."""
    return types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/list not supported",
    )
async def default_cancel_task_handler(
    context: RequestContext["ClientSession", Any],
    params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData:
    """Default tasks/cancel handler: reports the method as unsupported."""
    return types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/cancel not supported",
    )
async def default_task_augmented_sampling(
    context: RequestContext["ClientSession", Any],
    params: types.CreateMessageRequestParams,
    task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
    """Default handler: rejects task-augmented sampling as unsupported."""
    return types.ErrorData(
        code=types.INVALID_REQUEST,
        message="Task-augmented sampling not supported",
    )
async def default_task_augmented_elicitation(
    context: RequestContext["ClientSession", Any],
    params: types.ElicitRequestParams,
    task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
    """Default handler: rejects task-augmented elicitation as unsupported."""
    return types.ErrorData(
        code=types.INVALID_REQUEST,
        message="Task-augmented elicitation not supported",
    )
@dataclass
class ExperimentalTaskHandlers:
    """Container for experimental task handlers.

    Groups all task-related handlers that handle server -> client requests.
    This includes both pure task requests (get, list, cancel, result) and
    task-augmented request handlers (sampling, elicitation with task field).

    WARNING: These APIs are experimental and may change without notice.

    Example:
        handlers = ExperimentalTaskHandlers(
            get_task=my_get_task_handler,
            list_tasks=my_list_tasks_handler,
        )
        session = ClientSession(..., experimental_task_handlers=handlers)
    """

    # Pure task request handlers (tasks/get, tasks/result, tasks/list, tasks/cancel)
    get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
    get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
    list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
    cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
    # Task-augmented request handlers (sampling/elicitation carrying a task field)
    augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
    augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)

    def build_capability(self) -> types.ClientTasksCapability | None:
        """Build ClientTasksCapability from the configured handlers.

        Returns a capability object that reflects which handlers are configured
        (i.e., not using the default "not supported" handlers).

        Returns:
            ClientTasksCapability if any handlers are provided, None otherwise
        """
        # Identity comparison against the module-level defaults reveals
        # whether the user supplied a real handler.
        has_list = self.list_tasks is not default_list_tasks_handler
        has_cancel = self.cancel_task is not default_cancel_task_handler
        has_sampling = self.augmented_sampling is not default_task_augmented_sampling
        has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
        # NOTE(review): get_task/get_task_result are intentionally not checked
        # here — the capability schema only carries list/cancel/requests;
        # confirm this matches the spec's intent.
        # If no handlers are provided, return None
        if not any([has_list, has_cancel, has_sampling, has_elicitation]):
            return None
        # Build requests capability if any request handlers are provided
        requests_capability: types.ClientTasksRequestsCapability | None = None
        if has_sampling or has_elicitation:
            requests_capability = types.ClientTasksRequestsCapability(
                sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
                if has_sampling
                else None,
                elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
                if has_elicitation
                else None,
            )
        return types.ClientTasksCapability(
            list=types.TasksListCapability() if has_list else None,
            cancel=types.TasksCancelCapability() if has_cancel else None,
            requests=requests_capability,
        )

    @staticmethod
    def handles_request(request: types.ServerRequest) -> bool:
        """Check if this handler handles the given request type."""
        return isinstance(
            request.root,
            types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
        )

    async def handle_request(
        self,
        ctx: RequestContext["ClientSession", Any],
        responder: RequestResponder[types.ServerRequest, types.ClientResult],
    ) -> None:
        """Handle a task-related request from the server.

        Call handles_request() first to check if this handler can handle the request.

        Raises:
            ValueError: If the request is not one of the task request types.
        """
        # Handlers may return either a result model or ErrorData; validate
        # whichever comes back into the ClientResult | ErrorData union
        # before responding on the wire.
        client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
            types.ClientResult | types.ErrorData
        )
        match responder.request.root:
            case types.GetTaskRequest(params=params):
                response = await self.get_task(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)
            case types.GetTaskPayloadRequest(params=params):
                response = await self.get_task_result(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)
            case types.ListTasksRequest(params=params):
                response = await self.list_tasks(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)
            case types.CancelTaskRequest(params=params):
                response = await self.cancel_task(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)
            case _:  # pragma: no cover
                raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
# Backwards compatibility aliases: earlier releases exported these handlers
# under *_callback names, so keep both spellings importable.
default_task_augmented_sampling_callback = default_task_augmented_sampling
default_task_augmented_elicitation_callback = default_task_augmented_elicitation

View File

@@ -0,0 +1,224 @@
"""
Experimental client-side task support.
This module provides client methods for interacting with MCP tasks.
WARNING: These APIs are experimental and may change without notice.
Example:
# Call a tool as a task
result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"})
task_id = result.task.taskId
# Get task status
status = await session.experimental.get_task(task_id)
# Get task result when complete
if status.status == "completed":
result = await session.experimental.get_task_result(task_id, CallToolResult)
# List all tasks
tasks = await session.experimental.list_tasks()
# Cancel a task
await session.experimental.cancel_task(task_id)
"""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypeVar
import mcp.types as types
from mcp.shared.experimental.tasks.polling import poll_until_terminal
if TYPE_CHECKING:
from mcp.client.session import ClientSession
ResultT = TypeVar("ResultT", bound=types.Result)
class ExperimentalClientFeatures:
    """
    Experimental client features for tasks and other experimental APIs.

    WARNING: These APIs are experimental and may change without notice.

    Access via session.experimental:
        status = await session.experimental.get_task(task_id)
    """

    def __init__(self, session: "ClientSession") -> None:
        # All task operations are plain JSON-RPC requests sent through the
        # owning session.
        self._session = session

    async def call_tool_as_task(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
        *,
        ttl: int = 60000,
        meta: dict[str, Any] | None = None,
    ) -> types.CreateTaskResult:
        """Call a tool as a task, returning a CreateTaskResult for polling.

        This is a convenience method for calling tools that support task execution.
        The server will return a task reference instead of the immediate result,
        which can then be polled via `get_task()` and retrieved via `get_task_result()`.

        Args:
            name: The tool name
            arguments: Tool arguments
            ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute)
            meta: Optional metadata to include in the request

        Returns:
            CreateTaskResult containing the task reference

        Example:
            # Create task
            result = await session.experimental.call_tool_as_task(
                "long_running_tool", {"input": "data"}
            )
            task_id = result.task.taskId

            # Poll for completion
            while True:
                status = await session.experimental.get_task(task_id)
                if status.status == "completed":
                    break
                await asyncio.sleep(0.5)

            # Get result
            final = await session.experimental.get_task_result(task_id, CallToolResult)
        """
        # Wrap caller-provided metadata in the protocol's _meta envelope.
        _meta: types.RequestParams.Meta | None = None
        if meta is not None:
            _meta = types.RequestParams.Meta(**meta)
        return await self._session.send_request(
            types.ClientRequest(
                types.CallToolRequest(
                    params=types.CallToolRequestParams(
                        name=name,
                        arguments=arguments,
                        # Presence of the task field asks the server to run
                        # the tool as a task rather than inline.
                        task=types.TaskMetadata(ttl=ttl),
                        _meta=_meta,
                    ),
                )
            ),
            types.CreateTaskResult,
        )

    async def get_task(self, task_id: str) -> types.GetTaskResult:
        """
        Get the current status of a task.

        Args:
            task_id: The task identifier

        Returns:
            GetTaskResult containing the task status and metadata
        """
        return await self._session.send_request(
            types.ClientRequest(
                types.GetTaskRequest(
                    params=types.GetTaskRequestParams(taskId=task_id),
                )
            ),
            types.GetTaskResult,
        )

    async def get_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """
        Get the result of a completed task.

        The result type depends on the original request type:
        - tools/call tasks return CallToolResult
        - Other request types return their corresponding result type

        Args:
            task_id: The task identifier
            result_type: The expected result type (e.g., CallToolResult)

        Returns:
            The task result, validated against result_type
        """
        return await self._session.send_request(
            types.ClientRequest(
                types.GetTaskPayloadRequest(
                    params=types.GetTaskPayloadRequestParams(taskId=task_id),
                )
            ),
            result_type,
        )

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> types.ListTasksResult:
        """
        List all tasks.

        Args:
            cursor: Optional pagination cursor

        Returns:
            ListTasksResult containing tasks and optional next cursor
        """
        # Only attach params when a cursor was supplied.
        params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
        return await self._session.send_request(
            types.ClientRequest(
                types.ListTasksRequest(params=params),
            ),
            types.ListTasksResult,
        )

    async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
        """
        Cancel a running task.

        Args:
            task_id: The task identifier

        Returns:
            CancelTaskResult with the updated task state
        """
        return await self._session.send_request(
            types.ClientRequest(
                types.CancelTaskRequest(
                    params=types.CancelTaskRequestParams(taskId=task_id),
                )
            ),
            types.CancelTaskResult,
        )

    async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
        """
        Poll a task until it reaches a terminal status.

        Yields GetTaskResult for each poll, allowing the caller to react to
        status changes (e.g., handle input_required). Exits when task reaches
        a terminal status (completed, failed, cancelled).

        Respects the pollInterval hint from the server.

        Args:
            task_id: The task identifier

        Yields:
            GetTaskResult for each poll

        Example:
            async for status in session.experimental.poll_task(task_id):
                print(f"Status: {status.status}")
                if status.status == "input_required":
                    # Handle elicitation request via tasks/result
                    pass

            # Task is now terminal, get the result
            result = await session.experimental.get_task_result(task_id, CallToolResult)
        """
        # Delegate the polling loop (including interval handling) to the
        # shared helper; this method only re-yields each status.
        async for status in poll_until_terminal(self.get_task, task_id):
            yield status

View File

@@ -0,0 +1,615 @@
import logging
from datetime import timedelta
from typing import Any, Protocol, overload
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
from typing_extensions import deprecated
import mcp.types as types
from mcp.client.experimental import ExperimentalClientFeatures
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
logger = logging.getLogger("client")
class SamplingFnT(Protocol):
    """Callback handling server-initiated sampling/createMessage requests."""

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ...  # pragma: no branch
class ElicitationFnT(Protocol):
    """Callback handling server-initiated elicitation/create requests."""

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.ElicitRequestParams,
    ) -> types.ElicitResult | types.ErrorData: ...  # pragma: no branch
class ListRootsFnT(Protocol):
    """Callback handling server-initiated roots/list requests."""

    async def __call__(
        self, context: RequestContext["ClientSession", Any]
    ) -> types.ListRootsResult | types.ErrorData: ...  # pragma: no branch
class LoggingFnT(Protocol):
    """Callback receiving notifications/message log events from the server."""

    async def __call__(
        self,
        params: types.LoggingMessageNotificationParams,
    ) -> None: ...  # pragma: no branch
class MessageHandlerFnT(Protocol):
    """Low-level hook invoked for every incoming server message.

    Receives a responder for server requests, a server notification, or an
    Exception raised while reading the stream.
    """

    async def __call__(
        self,
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ) -> None: ...  # pragma: no branch
async def _default_message_handler(
    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Default message handler: ignore the message but yield to the event loop."""
    await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
    context: RequestContext["ClientSession", Any],
    params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
    """Default sampling callback: reject the request as unsupported."""
    return types.ErrorData(
        code=types.INVALID_REQUEST,
        message="Sampling not supported",
    )
async def _default_elicitation_callback(
    context: RequestContext["ClientSession", Any],
    params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
    """Default elicitation callback: reject the request as unsupported."""
    return types.ErrorData(  # pragma: no cover
        code=types.INVALID_REQUEST,
        message="Elicitation not supported",
    )
async def _default_list_roots_callback(
    context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
    """Default roots/list callback: reject the request as unsupported."""
    return types.ErrorData(
        code=types.INVALID_REQUEST,
        message="List roots not supported",
    )
async def _default_logging_callback(
    params: types.LoggingMessageNotificationParams,
) -> None:
    """Default logging callback: silently drop server log messages."""
    pass
# Module-level TypeAdapter reused to validate callback return values into the
# ClientResult | ErrorData wire union.
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
    BaseSession[
        types.ClientRequest,
        types.ClientNotification,
        types.ClientResult,
        types.ServerRequest,
        types.ServerNotification,
    ]
):
    """Client side of an MCP session.

    Sends client requests/notifications to the server and dispatches
    server-initiated requests (sampling, elicitation, roots, tasks) and
    notifications to the configured callbacks.
    """
    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        read_timeout_seconds: timedelta | None = None,
        sampling_callback: SamplingFnT | None = None,
        elicitation_callback: ElicitationFnT | None = None,
        list_roots_callback: ListRootsFnT | None = None,
        logging_callback: LoggingFnT | None = None,
        message_handler: MessageHandlerFnT | None = None,
        client_info: types.Implementation | None = None,
        *,
        sampling_capabilities: types.SamplingCapability | None = None,
        experimental_task_handlers: ExperimentalTaskHandlers | None = None,
    ) -> None:
        """Create a client session over the given message streams.

        Args:
            read_stream: Incoming messages (or read exceptions) from the server.
            write_stream: Outgoing messages to the server.
            read_timeout_seconds: Optional read timeout applied to requests.
            sampling_callback: Handles server-initiated sampling/createMessage.
            elicitation_callback: Handles server-initiated elicitation/create.
            list_roots_callback: Handles server-initiated roots/list.
            logging_callback: Receives notifications/message log events.
            message_handler: Low-level hook invoked for every incoming message.
            client_info: Implementation info advertised during initialize.
            sampling_capabilities: Overrides the sampling capability advertised
                when a sampling callback is configured.
            experimental_task_handlers: Experimental task handlers (may change
                without notice).
        """
        super().__init__(
            read_stream,
            write_stream,
            types.ServerRequest,
            types.ServerNotification,
            read_timeout_seconds=read_timeout_seconds,
        )
        self._client_info = client_info or DEFAULT_CLIENT_INFO
        # Each optional callback falls back to a module-level default that
        # reports the feature as unsupported (or no-ops for logging/messages).
        self._sampling_callback = sampling_callback or _default_sampling_callback
        self._sampling_capabilities = sampling_capabilities
        self._elicitation_callback = elicitation_callback or _default_elicitation_callback
        self._list_roots_callback = list_roots_callback or _default_list_roots_callback
        self._logging_callback = logging_callback or _default_logging_callback
        self._message_handler = message_handler or _default_message_handler
        # Cache of tool name -> output schema; populated outside this view.
        self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
        # Filled in by initialize() once the handshake completes.
        self._server_capabilities: types.ServerCapabilities | None = None
        # Lazily constructed facade for experimental APIs.
        self._experimental_features: ExperimentalClientFeatures | None = None
        # Experimental: Task handlers (use defaults if not provided)
        self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
    async def initialize(self) -> types.InitializeResult:
        """Run the MCP initialize handshake and return the server's result.

        Capabilities are advertised only for features whose callbacks were
        supplied (checked by identity against the module-level defaults).
        Records server capabilities and sends the initialized notification.

        Raises:
            RuntimeError: If the server responds with a protocol version this
                client does not support.
        """
        sampling = (
            (self._sampling_capabilities or types.SamplingCapability())
            if self._sampling_callback is not _default_sampling_callback
            else None
        )
        elicitation = (
            types.ElicitationCapability(
                form=types.FormElicitationCapability(),
                url=types.UrlElicitationCapability(),
            )
            if self._elicitation_callback is not _default_elicitation_callback
            else None
        )
        roots = (
            # TODO: Should this be based on whether we
            # _will_ send notifications, or only whether
            # they're supported?
            types.RootsCapability(listChanged=True)
            if self._list_roots_callback is not _default_list_roots_callback
            else None
        )
        result = await self.send_request(
            types.ClientRequest(
                types.InitializeRequest(
                    params=types.InitializeRequestParams(
                        protocolVersion=types.LATEST_PROTOCOL_VERSION,
                        capabilities=types.ClientCapabilities(
                            sampling=sampling,
                            elicitation=elicitation,
                            experimental=None,
                            roots=roots,
                            tasks=self._task_handlers.build_capability(),
                        ),
                        clientInfo=self._client_info,
                    ),
                )
            ),
            types.InitializeResult,
        )
        if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
            raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
        # Remember server capabilities so get_server_capabilities() can serve them.
        self._server_capabilities = result.capabilities
        await self.send_notification(types.ClientNotification(types.InitializedNotification()))
        return result
def get_server_capabilities(self) -> types.ServerCapabilities | None:
"""Return the server capabilities received during initialization.
Returns None if the session has not been initialized yet.
"""
return self._server_capabilities
@property
def experimental(self) -> ExperimentalClientFeatures:
"""Experimental APIs for tasks and other features.
WARNING: These APIs are experimental and may change without notice.
Example:
status = await session.experimental.get_task(task_id)
result = await session.experimental.get_task_result(task_id, CallToolResult)
"""
if self._experimental_features is None:
self._experimental_features = ExperimentalClientFeatures(self)
return self._experimental_features
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
types.ClientRequest(types.PingRequest()),
types.EmptyResult,
)
async def send_progress_notification(
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
types.ClientNotification(
types.ProgressNotification(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
message=message,
),
),
)
)
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SetLevelRequest(
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
@overload
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self) -> types.ListResourcesResult: ...
async def list_resources(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourcesResult:
"""Send a resources/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
types.ListResourcesResult,
)
@overload
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(
self, *, params: types.PaginatedRequestParams | None
) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...
async def list_resource_templates(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)),
types.ListResourceTemplatesResult,
)
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SubscribeRequest(
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.UnsubscribeRequest(
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""
_meta: types.RequestParams.Meta | None = None
if meta is not None:
_meta = types.RequestParams.Meta(**meta)
result = await self.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)
if not result.isError:
await self._validate_tool_result(name, result)
return result
    async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
        """Validate the structured content of a tool result against its output schema.

        Args:
            name: Tool name as returned by the server's tools/list.
            result: A (successful) tool result whose structuredContent is checked.

        Raises:
            RuntimeError: If the tool declares an output schema but returned no
                structured content, or the content fails schema validation.
        """
        if name not in self._tool_output_schemas:
            # refresh output schema cache (list_tools repopulates it as a side effect)
            await self.list_tools()
        output_schema = None
        if name in self._tool_output_schemas:
            output_schema = self._tool_output_schemas.get(name)
        else:
            # Unknown tool: nothing to validate against, warn and skip.
            logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
        if output_schema is not None:
            # Imported lazily so jsonschema is only needed when validation runs.
            from jsonschema import SchemaError, ValidationError, validate
            if result.structuredContent is None:
                raise RuntimeError(
                    f"Tool {name} has an output schema but did not return structured content"
                ) # pragma: no cover
            try:
                validate(result.structuredContent, output_schema)
            except ValidationError as e:
                raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover
            except SchemaError as e: # pragma: no cover
                raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover
@overload
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self) -> types.ListPromptsResult: ...
async def list_prompts(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListPromptsResult:
"""Send a prompts/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListPromptsRequest(params=request_params)),
types.ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return await self.send_request(
types.ClientRequest(
types.GetPromptRequest(
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
async def complete(
self,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
context_arguments: dict[str, str] | None = None,
) -> types.CompleteResult:
"""Send a completion/complete request."""
context = None
if context_arguments is not None:
context = types.CompletionContext(arguments=context_arguments)
return await self.send_request(
types.ClientRequest(
types.CompleteRequest(
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
context=context,
),
)
),
types.CompleteResult,
)
@overload
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self) -> types.ListToolsResult: ...
async def list_tools(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
"""Send a tools/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
result = await self.send_request(
types.ClientRequest(types.ListToolsRequest(params=request_params)),
types.ListToolsResult,
)
# Cache tool output schemas for future validation
# Note: don't clear the cache, as we may be using a cursor
for tool in result.tools:
self._tool_output_schemas[tool.name] = tool.outputSchema
return result
async def send_roots_list_changed(self) -> None: # pragma: no cover
"""Send a roots/list_changed notification."""
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
    async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
        """Dispatch a server-initiated request to the matching client callback.

        The responder is entered as a context manager so the request lifecycle
        (completion/cancellation) is always finalized, and the callback's
        result is validated into a ClientResult before responding.
        """
        ctx = RequestContext[ClientSession, Any](
            request_id=responder.request_id,
            meta=responder.request_meta,
            session=self,
            lifespan_context=None,
        )
        # Delegate to experimental task handler if applicable
        if self._task_handlers.handles_request(responder.request):
            with responder:
                await self._task_handlers.handle_request(ctx, responder)
            return None
        # Core request handling
        match responder.request.root:
            case types.CreateMessageRequest(params=params):
                with responder:
                    # Check if this is a task-augmented request
                    if params.task is not None:
                        response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
                    else:
                        response = await self._sampling_callback(ctx, params)
                    client_response = ClientResponse.validate_python(response)
                    await responder.respond(client_response)
            case types.ElicitRequest(params=params):
                with responder:
                    # Check if this is a task-augmented request
                    if params.task is not None:
                        response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
                    else:
                        response = await self._elicitation_callback(ctx, params)
                    client_response = ClientResponse.validate_python(response)
                    await responder.respond(client_response)
            case types.ListRootsRequest():
                with responder:
                    response = await self._list_roots_callback(ctx)
                    client_response = ClientResponse.validate_python(response)
                    await responder.respond(client_response)
            case types.PingRequest(): # pragma: no cover
                with responder:
                    # Ping needs no callback; answer with an empty result directly.
                    return await responder.respond(types.ClientResult(root=types.EmptyResult()))
            case _: # pragma: no cover
                pass # Task requests handled above by _task_handlers
        return None
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
    async def _received_notification(self, notification: types.ServerNotification) -> None:
        """Handle notifications from the server.

        Only logging messages are acted on here; all other notification types
        are intentionally ignored by the core session.
        """
        # Process specific notification types
        match notification.root:
            case types.LoggingMessageNotification(params=params):
                await self._logging_callback(params)
            case types.ElicitCompleteNotification(params=params):
                # Handle elicitation completion notification
                # Clients MAY use this to retry requests or update UI
                # The notification contains the elicitationId of the completed elicitation
                # (currently a no-op; `params` is intentionally unused)
                pass
            case _:
                pass

View File

@@ -0,0 +1,447 @@
"""
SessionGroup concurrently manages multiple MCP session connections.
Tools, resources, and prompts are aggregated across servers. Servers may
be connected to or disconnected from at any point after initialization.
This abstraction can handle naming collisions using a custom, user-provided
hook.
"""
import contextlib
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias, overload
import anyio
import httpx
from pydantic import BaseModel
from typing_extensions import Self, deprecated
import mcp
from mcp import types
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamable_http_client
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.exceptions import McpError
from mcp.shared.session import ProgressFnT
class SseServerParameters(BaseModel):
    """Parameters for initializing a sse_client."""
    # The endpoint URL.
    url: str
    # Optional headers to include in requests.
    headers: dict[str, Any] | None = None
    # HTTP timeout for regular operations.
    timeout: float = 5
    # Timeout for SSE read operations.
    sse_read_timeout: float = 60 * 5
class StreamableHttpParameters(BaseModel):
    """Parameters for initializing a streamable_http_client."""
    # The endpoint URL.
    url: str
    # Optional headers to include in requests.
    headers: dict[str, Any] | None = None
    # HTTP timeout for regular operations.
    timeout: timedelta = timedelta(seconds=30)
    # Timeout for SSE read operations.
    sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
    # Close the client session when the transport closes.
    terminate_on_close: bool = True
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
# Use dataclass instead of pydantic BaseModel
# because pydantic BaseModel cannot handle Protocol fields.
@dataclass
class ClientSessionParameters:
    """Parameters for establishing a client session to an MCP server.

    All fields are forwarded verbatim to the ClientSession constructor.
    """
    # Per-request read timeout applied to the session.
    read_timeout_seconds: timedelta | None = None
    # Callback servicing sampling (createMessage) requests from the server.
    sampling_callback: SamplingFnT | None = None
    # Callback servicing elicitation requests from the server.
    elicitation_callback: ElicitationFnT | None = None
    # Callback servicing roots/list requests from the server.
    list_roots_callback: ListRootsFnT | None = None
    # Callback invoked for server log-message notifications.
    logging_callback: LoggingFnT | None = None
    # Handler invoked for every incoming message.
    message_handler: MessageHandlerFnT | None = None
    # Client implementation info sent during initialization.
    client_info: types.Implementation | None = None
class ClientSessionGroup:
    """Client for managing connections to multiple MCP servers.
    This class is responsible for encapsulating management of server connections.
    It aggregates tools, resources, and prompts from all connected servers.
    For auxiliary handlers, such as resource subscription, this is delegated to
    the client and can be accessed via the session.
    Example Usage:
        name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
        async with ClientSessionGroup(component_name_hook=name_fn) as group:
            for server_param in server_params:
                await group.connect_to_server(server_param)
            ...
    """
    class _ComponentNames(BaseModel):
        """Used for reverse index to find components."""
        prompts: set[str] = set()
        resources: set[str] = set()
        tools: set[str] = set()
    # Standard MCP components, keyed by their (possibly hook-renamed) names.
    _prompts: dict[str, types.Prompt]
    _resources: dict[str, types.Resource]
    _tools: dict[str, types.Tool]
    # Client-server connection management.
    _sessions: dict[mcp.ClientSession, _ComponentNames]
    _tool_to_session: dict[str, mcp.ClientSession]
    _exit_stack: contextlib.AsyncExitStack
    _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
    # Optional fn consuming (component_name, serverInfo) for custom names.
    # This provides a means to mitigate naming conflicts across servers.
    # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
    _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
    _component_name_hook: _ComponentNameHook | None
    def __init__(
        self,
        exit_stack: contextlib.AsyncExitStack | None = None,
        component_name_hook: _ComponentNameHook | None = None,
    ) -> None:
        """Initializes the MCP client.

        Args:
            exit_stack: Optional externally-owned exit stack. When omitted, the
                group creates its own and is responsible for closing it.
            component_name_hook: Optional naming hook applied to every
                aggregated component to avoid cross-server name collisions.
        """
        self._tools = {}
        self._resources = {}
        self._prompts = {}
        self._sessions = {}
        self._tool_to_session = {}
        # Track ownership so __aenter__/__aexit__ only manage a stack we created.
        if exit_stack is None:
            self._exit_stack = contextlib.AsyncExitStack()
            self._owns_exit_stack = True
        else:
            self._exit_stack = exit_stack
            self._owns_exit_stack = False
        self._session_exit_stacks = {}
        self._component_name_hook = component_name_hook
    async def __aenter__(self) -> Self: # pragma: no cover
        # Enter the exit stack only if we created it ourselves
        if self._owns_exit_stack:
            await self._exit_stack.__aenter__()
        return self
    async def __aexit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_val: BaseException | None,
        _exc_tb: TracebackType | None,
    ) -> bool | None: # pragma: no cover
        """Closes session exit stacks and main exit stack upon completion."""
        # Only close the main exit stack if we created it
        if self._owns_exit_stack:
            await self._exit_stack.aclose()
        # Concurrently close session stacks.
        async with anyio.create_task_group() as tg:
            for exit_stack in self._session_exit_stacks.values():
                tg.start_soon(exit_stack.aclose)
    @property
    def sessions(self) -> list[mcp.ClientSession]:
        """Returns the list of sessions being managed."""
        return list(self._sessions.keys()) # pragma: no cover
    @property
    def prompts(self) -> dict[str, types.Prompt]:
        """Returns the prompts as a dictionary of names to prompts."""
        return self._prompts
    @property
    def resources(self) -> dict[str, types.Resource]:
        """Returns the resources as a dictionary of names to resources."""
        return self._resources
    @property
    def tools(self) -> dict[str, types.Tool]:
        """Returns the tools as a dictionary of names to tools."""
        return self._tools
    @overload
    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any],
        read_timeout_seconds: timedelta | None = None,
        progress_callback: ProgressFnT | None = None,
        *,
        meta: dict[str, Any] | None = None,
    ) -> types.CallToolResult: ...
    @overload
    @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
    async def call_tool(
        self,
        name: str,
        *,
        args: dict[str, Any],
        read_timeout_seconds: timedelta | None = None,
        progress_callback: ProgressFnT | None = None,
        meta: dict[str, Any] | None = None,
    ) -> types.CallToolResult: ...
    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
        read_timeout_seconds: timedelta | None = None,
        progress_callback: ProgressFnT | None = None,
        *,
        meta: dict[str, Any] | None = None,
        args: dict[str, Any] | None = None,
    ) -> types.CallToolResult:
        """Executes a tool given its name and arguments.

        `name` is the aggregate (possibly hook-renamed) name; the original
        server-side tool name is looked up before dispatching to the owning
        session. The deprecated `args` keyword, when provided, takes
        precedence over `arguments`.
        """
        session = self._tool_to_session[name]
        session_tool_name = self.tools[name].name
        return await session.call_tool(
            session_tool_name,
            arguments if args is None else args,
            read_timeout_seconds=read_timeout_seconds,
            progress_callback=progress_callback,
            meta=meta,
        )
    async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
        """Disconnects from a single MCP server.

        Raises:
            McpError: If the session is not managed or already disconnected.
        """
        session_known_for_components = session in self._sessions
        session_known_for_stack = session in self._session_exit_stacks
        if not session_known_for_components and not session_known_for_stack:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message="Provided session is not managed or already disconnected.",
                )
            )
        if session_known_for_components: # pragma: no cover
            component_names = self._sessions.pop(session) # Pop from _sessions tracking
            # Remove prompts associated with the session.
            for name in component_names.prompts:
                if name in self._prompts:
                    del self._prompts[name]
            # Remove resources associated with the session.
            for name in component_names.resources:
                if name in self._resources:
                    del self._resources[name]
            # Remove tools associated with the session.
            for name in component_names.tools:
                if name in self._tools:
                    del self._tools[name]
                if name in self._tool_to_session:
                    del self._tool_to_session[name]
        # Clean up the session's resources via its dedicated exit stack
        if session_known_for_stack:
            session_stack_to_close = self._session_exit_stacks.pop(session) # pragma: no cover
            await session_stack_to_close.aclose() # pragma: no cover
    async def connect_with_session(
        self, server_info: types.Implementation, session: mcp.ClientSession
    ) -> mcp.ClientSession:
        """Registers an already-established session and aggregates its components."""
        await self._aggregate_components(server_info, session)
        return session
    async def connect_to_server(
        self,
        server_params: ServerParameters,
        session_params: ClientSessionParameters | None = None,
    ) -> mcp.ClientSession:
        """Connects to a single MCP server."""
        server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
        return await self.connect_with_session(server_info, session)
    async def _establish_session(
        self,
        server_params: ServerParameters,
        session_params: ClientSessionParameters,
    ) -> tuple[types.Implementation, mcp.ClientSession]:
        """Establish a client session to an MCP server.

        The transport is chosen from the concrete `server_params` type (stdio,
        SSE, or streamable HTTP). All transport resources live on a dedicated
        per-session exit stack, which is closed on any setup failure.
        """
        session_stack = contextlib.AsyncExitStack()
        try:
            # Create read and write streams that facilitate io with the server.
            if isinstance(server_params, StdioServerParameters):
                client = mcp.stdio_client(server_params)
                read, write = await session_stack.enter_async_context(client)
            elif isinstance(server_params, SseServerParameters):
                client = sse_client(
                    url=server_params.url,
                    headers=server_params.headers,
                    timeout=server_params.timeout,
                    sse_read_timeout=server_params.sse_read_timeout,
                )
                read, write = await session_stack.enter_async_context(client)
            else:
                httpx_client = create_mcp_http_client(
                    headers=server_params.headers,
                    timeout=httpx.Timeout(
                        server_params.timeout.total_seconds(),
                        read=server_params.sse_read_timeout.total_seconds(),
                    ),
                )
                await session_stack.enter_async_context(httpx_client)
                client = streamable_http_client(
                    url=server_params.url,
                    http_client=httpx_client,
                    terminate_on_close=server_params.terminate_on_close,
                )
                read, write, _ = await session_stack.enter_async_context(client)
            session = await session_stack.enter_async_context(
                mcp.ClientSession(
                    read,
                    write,
                    read_timeout_seconds=session_params.read_timeout_seconds,
                    sampling_callback=session_params.sampling_callback,
                    elicitation_callback=session_params.elicitation_callback,
                    list_roots_callback=session_params.list_roots_callback,
                    logging_callback=session_params.logging_callback,
                    message_handler=session_params.message_handler,
                    client_info=session_params.client_info,
                )
            )
            result = await session.initialize()
            # Session successfully initialized.
            # Store its stack and register the stack with the main group stack.
            self._session_exit_stacks[session] = session_stack
            # session_stack itself becomes a resource managed by the
            # main _exit_stack.
            await self._exit_stack.enter_async_context(session_stack)
            return result.serverInfo, session
        except Exception: # pragma: no cover
            # If anything during this setup fails, ensure the session-specific
            # stack is closed.
            await session_stack.aclose()
            raise
    async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
        """Aggregates prompts, resources, and tools from a given session.

        Raises:
            McpError: If any fetched component name collides with one already
                registered in the group.
        """
        # Create a reverse index so we can find all prompts, resources, and
        # tools belonging to this session. Used for removing components from
        # the session group via self.disconnect_from_server.
        component_names = self._ComponentNames()
        # Temporary components dicts. We do not want to modify the aggregate
        # lists in case of an intermediate failure.
        prompts_temp: dict[str, types.Prompt] = {}
        resources_temp: dict[str, types.Resource] = {}
        tools_temp: dict[str, types.Tool] = {}
        tool_to_session_temp: dict[str, mcp.ClientSession] = {}
        # Query the server for its prompts and aggregate to list.
        try:
            prompts = (await session.list_prompts()).prompts
            for prompt in prompts:
                name = self._component_name(prompt.name, server_info)
                prompts_temp[name] = prompt
                component_names.prompts.add(name)
        except McpError as err: # pragma: no cover
            logging.warning(f"Could not fetch prompts: {err}")
        # Query the server for its resources and aggregate to list.
        try:
            resources = (await session.list_resources()).resources
            for resource in resources:
                name = self._component_name(resource.name, server_info)
                resources_temp[name] = resource
                component_names.resources.add(name)
        except McpError as err: # pragma: no cover
            logging.warning(f"Could not fetch resources: {err}")
        # Query the server for its tools and aggregate to list.
        try:
            tools = (await session.list_tools()).tools
            for tool in tools:
                name = self._component_name(tool.name, server_info)
                tools_temp[name] = tool
                tool_to_session_temp[name] = session
                component_names.tools.add(name)
        except McpError as err: # pragma: no cover
            logging.warning(f"Could not fetch tools: {err}")
        # Clean up exit stack for session if we couldn't retrieve anything
        # from the server.
        if not any((prompts_temp, resources_temp, tools_temp)):
            del self._session_exit_stacks[session] # pragma: no cover
        # Check for duplicates.
        matching_prompts = prompts_temp.keys() & self._prompts.keys()
        if matching_prompts:
            raise McpError( # pragma: no cover
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_prompts} already exist in group prompts.",
                )
            )
        matching_resources = resources_temp.keys() & self._resources.keys()
        if matching_resources:
            raise McpError( # pragma: no cover
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_resources} already exist in group resources.",
                )
            )
        matching_tools = tools_temp.keys() & self._tools.keys()
        if matching_tools:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_tools} already exist in group tools.",
                )
            )
        # Aggregate components.
        self._sessions[session] = component_names
        self._prompts.update(prompts_temp)
        self._resources.update(resources_temp)
        self._tools.update(tools_temp)
        self._tool_to_session.update(tool_to_session_temp)
    def _component_name(self, name: str, server_info: types.Implementation) -> str:
        # Apply the user-provided naming hook when configured; otherwise keep
        # the server's original component name.
        if self._component_name_hook:
            return self._component_name_hook(name, server_info)
        return name

View File

@@ -0,0 +1,164 @@
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError
import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
    """Strip the query string and fragment from `url`, keeping scheme/host/path."""
    parsed_path = urlparse(url).path
    return urljoin(url, parsed_path)
def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
query_params = parse_qs(urlparse(endpoint_url).query)
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
@asynccontextmanager
async def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5,
    sse_read_timeout: float = 60 * 5,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
    on_session_created: Callable[[str], None] | None = None,
):
    """
    Client transport for SSE.
    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.
    Args:
        url: The SSE endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations.
        sse_read_timeout: Timeout for SSE read operations.
        httpx_client_factory: Factory used to build the underlying HTTPX client.
        auth: Optional HTTPX authentication handler.
        on_session_created: Optional callback invoked with the session ID when received.
    Yields:
        A (read_stream, write_stream) pair of in-memory channels carrying
        SessionMessage objects (the read side may also carry exceptions).
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
    # Zero-capacity streams: senders block until a consumer is ready.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
            async with httpx_client_factory(
                headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
            ) as client:
                async with aconnect_sse(
                    client,
                    "GET",
                    url,
                ) as event_source:
                    event_source.response.raise_for_status()
                    logger.debug("SSE connection established")
                    async def sse_reader(
                        task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
                    ):
                        # Consumes server events; reports the POST endpoint URL
                        # via task_status once the server announces it.
                        try:
                            async for sse in event_source.aiter_sse(): # pragma: no branch
                                logger.debug(f"Received SSE event: {sse.event}")
                                match sse.event:
                                    case "endpoint":
                                        endpoint_url = urljoin(url, sse.data)
                                        logger.debug(f"Received endpoint URL: {endpoint_url}")
                                        # Reject endpoints on a different origin so messages
                                        # cannot be redirected to another host.
                                        url_parsed = urlparse(url)
                                        endpoint_parsed = urlparse(endpoint_url)
                                        if ( # pragma: no cover
                                            url_parsed.netloc != endpoint_parsed.netloc
                                            or url_parsed.scheme != endpoint_parsed.scheme
                                        ):
                                            error_msg = ( # pragma: no cover
                                                f"Endpoint origin does not match connection origin: {endpoint_url}"
                                            )
                                            logger.error(error_msg) # pragma: no cover
                                            raise ValueError(error_msg) # pragma: no cover
                                        if on_session_created:
                                            session_id = _extract_session_id_from_endpoint(endpoint_url)
                                            if session_id:
                                                on_session_created(session_id)
                                        task_status.started(endpoint_url)
                                    case "message":
                                        # Skip empty data (keep-alive pings)
                                        if not sse.data:
                                            continue
                                        try:
                                            message = types.JSONRPCMessage.model_validate_json( # noqa: E501
                                                sse.data
                                            )
                                            logger.debug(f"Received server message: {message}")
                                        except Exception as exc: # pragma: no cover
                                            logger.exception("Error parsing server message") # pragma: no cover
                                            await read_stream_writer.send(exc) # pragma: no cover
                                            continue # pragma: no cover
                                        session_message = SessionMessage(message)
                                        await read_stream_writer.send(session_message)
                                    case _: # pragma: no cover
                                        logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
                        except SSEError as sse_exc: # pragma: no cover
                            logger.exception("Encountered SSE exception") # pragma: no cover
                            raise sse_exc # pragma: no cover
                        except Exception as exc: # pragma: no cover
                            logger.exception("Error in sse_reader") # pragma: no cover
                            await read_stream_writer.send(exc) # pragma: no cover
                        finally:
                            await read_stream_writer.aclose()
                    async def post_writer(endpoint_url: str):
                        # Drains the write stream and POSTs each outgoing client
                        # message to the server-announced endpoint.
                        try:
                            async with write_stream_reader:
                                async for session_message in write_stream_reader:
                                    logger.debug(f"Sending client message: {session_message}")
                                    response = await client.post(
                                        endpoint_url,
                                        json=session_message.message.model_dump(
                                            by_alias=True,
                                            mode="json",
                                            exclude_none=True,
                                        ),
                                    )
                                    response.raise_for_status()
                                    logger.debug(f"Client message sent successfully: {response.status_code}")
                        except Exception: # pragma: no cover
                            logger.exception("Error in post_writer") # pragma: no cover
                        finally:
                            await write_stream.aclose()
                    # Blocks until sse_reader has received the "endpoint" event.
                    endpoint_url = await tg.start(sse_reader)
                    logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
                    tg.start_soon(post_writer, endpoint_url)
                    try:
                        yield read_stream, write_stream
                    finally:
                        tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()

View File

@@ -0,0 +1,278 @@
import logging
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal, TextIO
import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
import mcp.types as types
from mcp.os.posix.utilities import terminate_posix_process_tree
from mcp.os.win32.utilities import (
FallbackProcess,
create_windows_process,
get_windows_executable_command,
terminate_windows_process_tree,
)
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
# Environment variables to inherit by default.
# Windows requires a broader set for subprocesses to resolve executables and
# user/application paths; POSIX needs only the basic shell/user variables.
DEFAULT_INHERITED_ENV_VARS = (
    [
        "APPDATA",
        "HOMEDRIVE",
        "HOMEPATH",
        "LOCALAPPDATA",
        "PATH",
        "PATHEXT",
        "PROCESSOR_ARCHITECTURE",
        "SYSTEMDRIVE",
        "SYSTEMROOT",
        "TEMP",
        "USERNAME",
        "USERPROFILE",
    ]
    if sys.platform == "win32"
    else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)
# Timeout (seconds) for graceful process termination before falling back to force kill
PROCESS_TERMINATION_TIMEOUT = 2.0
def get_default_environment() -> dict[str, str]:
    """
    Returns a default environment object including only environment variables deemed
    safe to inherit.
    """
    safe_env: dict[str, str] = {}
    for key in DEFAULT_INHERITED_ENV_VARS:
        value = os.environ.get(key)
        # Values starting with "()" are exported shell functions — a known
        # security risk — so they are skipped along with unset variables.
        if value is not None and not value.startswith("()"):
            safe_env[key] = value
    return safe_env
class StdioServerParameters(BaseModel):
    """Configuration for launching an MCP server as a subprocess over stdio."""
    command: str
    """The executable to run to start the server."""
    args: list[str] = Field(default_factory=list)
    """Command line arguments to pass to the executable."""
    env: dict[str, str] | None = None
    """
    The environment to use when spawning the process.
    If not specified, the result of get_default_environment() will be used.
    """
    cwd: str | Path | None = None
    """The working directory to use when spawning the process."""
    encoding: str = "utf-8"
    """
    The text encoding used when sending/receiving messages to the server
    defaults to utf-8
    """
    encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
    """
    The text encoding error handler.
    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values
    """
@asynccontextmanager
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Yields a (read_stream, write_stream) pair. Messages written to
    write_stream are serialized as newline-delimited JSON-RPC to the child's
    stdin; lines read from the child's stdout are parsed and delivered on
    read_stream (parse failures are delivered as Exception objects).
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams: senders block until a receiver is ready.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        command = _get_executable_command(server.command)

        # Open process with stderr piped for capture
        process = await _create_platform_compatible_process(
            command=command,
            args=server.args,
            env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
            errlog=errlog,
            cwd=server.cwd,
        )
    except OSError:
        # Clean up streams if process creation fails
        await read_stream.aclose()
        await write_stream.aclose()
        await read_stream_writer.aclose()
        await write_stream_reader.aclose()
        raise

    async def stdout_reader():
        # Pump child stdout -> read_stream, splitting on newlines and parsing
        # each complete line as a JSON-RPC message.
        assert process.stdout, "Opened process is missing stdout"

        try:
            async with read_stream_writer:
                buffer = ""
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    lines = (buffer + chunk).split("\n")
                    # Last element is an incomplete line; keep it for the next chunk.
                    buffer = lines.pop()

                    for line in lines:
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:  # pragma: no cover
                            logger.exception("Failed to parse JSONRPC message from server")
                            await read_stream_writer.send(exc)
                            continue

                        session_message = SessionMessage(message)
                        await read_stream_writer.send(session_message)
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        # Pump write_stream -> child stdin as newline-delimited JSON.
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
                    await process.stdin.send(
                        (json + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async with (
        anyio.create_task_group() as tg,
        process,
    ):
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        try:
            yield read_stream, write_stream
        finally:
            # MCP spec: stdio shutdown sequence
            # 1. Close input stream to server
            # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
            # 3. Send SIGKILL if still not exited
            if process.stdin:  # pragma: no branch
                try:
                    await process.stdin.aclose()
                except Exception:  # pragma: no cover
                    # stdin might already be closed, which is fine
                    pass

            try:
                # Give the process time to exit gracefully after stdin closes
                with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
                    await process.wait()
            except TimeoutError:
                # Process didn't exit from stdin closure, use platform-specific termination
                # which handles SIGTERM -> SIGKILL escalation
                await _terminate_process_tree(process)
            except ProcessLookupError:  # pragma: no cover
                # Process already exited, which is fine
                pass

            await read_stream.aclose()
            await write_stream.aclose()
            await read_stream_writer.aclose()
            await write_stream_reader.aclose()
def _get_executable_command(command: str) -> str:
"""
Get the correct executable command normalized for the current platform.
Args:
command: Base command (e.g., 'uvx', 'npx')
Returns:
str: Platform-appropriate command
"""
if sys.platform == "win32": # pragma: no cover
return get_windows_executable_command(command)
else:
return command # pragma: no cover
async def _create_platform_compatible_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO = sys.stderr,
    cwd: Path | str | None = None,
):
    """
    Creates a subprocess in a platform-compatible way.

    Unix: the child is started in a new session/process group so the whole
    tree can later be signalled via killpg.
    Windows: the child is placed in a Job Object for reliable termination of
    the full child-process tree.
    """
    if sys.platform == "win32":  # pragma: no cover
        return await create_windows_process(command, args, env, errlog, cwd)

    return await anyio.open_process(
        [command, *args],
        env=env,
        stderr=errlog,
        cwd=cwd,
        start_new_session=True,
    )  # pragma: no cover
async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
    """
    Terminate a process and all its children using platform-specific methods.

    Unix: Uses os.killpg() for atomic process group termination
    Windows: Uses Job Objects via pywin32 for reliable child process cleanup

    Args:
        process: The process to terminate
        timeout_seconds: Timeout in seconds before force killing (default: 2.0)
    """
    if sys.platform == "win32":  # pragma: no cover
        await terminate_windows_process_tree(process, timeout_seconds)
    else:  # pragma: no cover
        # FallbackProcess should only be used for Windows compatibility
        assert isinstance(process, Process)
        await terminate_posix_process_tree(process, timeout_seconds)

View File

@@ -0,0 +1,722 @@
"""
StreamableHTTP Client Transport Module
This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
import contextlib
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, overload
from warnings import warn
import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from typing_extensions import deprecated
from mcp.shared._httpx_utils import (
McpHttpClientFactory,
create_mcp_http_client,
)
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
)
logger = logging.getLogger(__name__)

SessionMessageOrError = SessionMessage | Exception
# Stream the transport writes inbound messages (or parse errors) to.
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
# Stream the transport reads outbound messages from.
StreamReader = MemoryObjectReceiveStream[SessionMessage]
# Callback the caller can use to query the negotiated HTTP session id.
GetSessionIdCallback = Callable[[], str | None]

# Header names defined by the StreamableHTTP transport spec.
MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"

# Reconnection defaults
DEFAULT_RECONNECTION_DELAY_MS = 1000  # 1 second fallback when server doesn't provide retry
MAX_RECONNECTION_ATTEMPTS = 2  # Max retry attempts before giving up

CONTENT_TYPE = "content-type"
ACCEPT = "accept"
JSON = "application/json"
SSE = "text/event-stream"

# Sentinel value for detecting unset optional parameters
_UNSET = object()
class StreamableHTTPError(Exception):
    """Base exception for StreamableHTTP transport errors."""


class ResumptionError(StreamableHTTPError):
    """Raised when resumption request is invalid."""
@dataclass
class RequestContext:
    """Context for a single client->server request operation."""

    client: httpx.AsyncClient  # HTTP client used to issue the request
    session_id: str | None  # negotiated session id at the time of sending, if any
    session_message: SessionMessage  # the outbound message being sent
    metadata: ClientMessageMetadata | None  # resumption token/callback, if any
    read_stream_writer: StreamWriter  # where server responses are delivered
    headers: dict[str, str] | None = None  # Deprecated - no longer used
    sse_read_timeout: float | None = None  # Deprecated - no longer used
class StreamableHTTPTransport:
    """StreamableHTTP client transport implementation.

    Sends JSON-RPC messages via HTTP POST and consumes either JSON or SSE
    responses, tracking the server-assigned session id and negotiated
    protocol version, with resumption/reconnection support for SSE streams.
    """

    @overload
    def __init__(self, url: str) -> None: ...

    @overload
    @deprecated(
        "Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
        "Configure these on the httpx.AsyncClient instead."
    )
    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float | timedelta = 30,
        sse_read_timeout: float | timedelta = 60 * 5,
        auth: httpx.Auth | None = None,
    ) -> None: ...

    def __init__(
        self,
        url: str,
        headers: Any = _UNSET,
        timeout: Any = _UNSET,
        sse_read_timeout: Any = _UNSET,
        auth: Any = _UNSET,
    ) -> None:
        """Initialize the StreamableHTTP transport.

        Args:
            url: The endpoint URL.
            headers: Optional headers to include in requests (deprecated, ignored).
            timeout: HTTP timeout for regular operations (deprecated, ignored).
            sse_read_timeout: Timeout for SSE read operations (deprecated, ignored).
            auth: Optional HTTPX authentication handler (deprecated, ignored).
        """
        # Check for deprecated parameters and issue runtime warning
        deprecated_params: list[str] = []
        if headers is not _UNSET:
            deprecated_params.append("headers")
        if timeout is not _UNSET:
            deprecated_params.append("timeout")
        if sse_read_timeout is not _UNSET:
            deprecated_params.append("sse_read_timeout")
        if auth is not _UNSET:
            deprecated_params.append("auth")

        if deprecated_params:
            warn(
                f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
                "Configure these on the httpx.AsyncClient instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        self.url = url
        # Populated from response headers / initialize result as they arrive.
        self.session_id = None
        self.protocol_version = None

    def _prepare_headers(self) -> dict[str, str]:
        """Build MCP-specific request headers.

        These headers will be merged with the httpx.AsyncClient's default headers,
        with these MCP-specific headers taking precedence.
        """
        headers: dict[str, str] = {}

        # Add MCP protocol headers
        headers[ACCEPT] = f"{JSON}, {SSE}"
        headers[CONTENT_TYPE] = JSON

        # Add session headers if available
        if self.session_id:
            headers[MCP_SESSION_ID] = self.session_id
        if self.protocol_version:
            headers[MCP_PROTOCOL_VERSION] = self.protocol_version

        return headers

    def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
        """Check if the message is an initialization request."""
        return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"

    def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
        """Check if the message is an initialized notification."""
        return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"

    def _maybe_extract_session_id_from_response(
        self,
        response: httpx.Response,
    ) -> None:
        """Extract and store session ID from response headers."""
        new_session_id = response.headers.get(MCP_SESSION_ID)
        if new_session_id:
            self.session_id = new_session_id
            logger.info(f"Received session ID: {self.session_id}")

    def _maybe_extract_protocol_version_from_message(
        self,
        message: JSONRPCMessage,
    ) -> None:
        """Extract protocol version from initialization response message."""
        if isinstance(message.root, JSONRPCResponse) and message.root.result:  # pragma: no branch
            try:
                # Parse the result as InitializeResult for type safety
                init_result = InitializeResult.model_validate(message.root.result)
                self.protocol_version = str(init_result.protocolVersion)
                logger.info(f"Negotiated protocol version: {self.protocol_version}")
            except Exception as exc:  # pragma: no cover
                logger.warning(
                    f"Failed to parse initialization response as InitializeResult: {exc}"
                )  # pragma: no cover
                logger.warning(f"Raw result: {message.root.result}")

    async def _handle_sse_event(
        self,
        sse: ServerSentEvent,
        read_stream_writer: StreamWriter,
        original_request_id: RequestId | None = None,
        resumption_callback: Callable[[str], Awaitable[None]] | None = None,
        is_initialization: bool = False,
    ) -> bool:
        """Handle an SSE event, returning True if the response is complete."""
        if sse.event == "message":
            # Handle priming events (empty data with ID) for resumability
            if not sse.data:
                # Call resumption callback for priming events that have an ID
                if sse.id and resumption_callback:
                    await resumption_callback(sse.id)
                return False

            try:
                message = JSONRPCMessage.model_validate_json(sse.data)
                logger.debug(f"SSE message: {message}")

                # Extract protocol version from initialization response
                if is_initialization:
                    self._maybe_extract_protocol_version_from_message(message)

                # If this is a response and we have original_request_id, replace it
                # (responses on a resumed stream carry the replayed request's id)
                if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
                    message.root.id = original_request_id

                session_message = SessionMessage(message)
                await read_stream_writer.send(session_message)

                # Call resumption token callback if we have an ID
                if sse.id and resumption_callback:
                    await resumption_callback(sse.id)

                # If this is a response or error return True indicating completion
                # Otherwise, return False to continue listening
                return isinstance(message.root, JSONRPCResponse | JSONRPCError)
            except Exception as exc:  # pragma: no cover
                logger.exception("Error parsing SSE message")
                await read_stream_writer.send(exc)
                return False
        else:  # pragma: no cover
            logger.warning(f"Unknown SSE event: {sse.event}")
            return False

    async def handle_get_stream(
        self,
        client: httpx.AsyncClient,
        read_stream_writer: StreamWriter,
    ) -> None:
        """Handle GET stream for server-initiated messages with auto-reconnect."""
        last_event_id: str | None = None
        retry_interval_ms: int | None = None
        attempt: int = 0

        while attempt < MAX_RECONNECTION_ATTEMPTS:  # pragma: no branch
            try:
                # A GET stream only makes sense once a session has been assigned.
                if not self.session_id:
                    return

                headers = self._prepare_headers()
                if last_event_id:
                    headers[LAST_EVENT_ID] = last_event_id  # pragma: no cover

                async with aconnect_sse(
                    client,
                    "GET",
                    self.url,
                    headers=headers,
                ) as event_source:
                    event_source.response.raise_for_status()
                    logger.debug("GET SSE connection established")

                    async for sse in event_source.aiter_sse():
                        # Track last event ID for reconnection
                        if sse.id:
                            last_event_id = sse.id  # pragma: no cover
                        # Track retry interval from server
                        if sse.retry is not None:
                            retry_interval_ms = sse.retry  # pragma: no cover
                        await self._handle_sse_event(sse, read_stream_writer)

                # Stream ended normally (server closed) - reset attempt counter
                attempt = 0
            except Exception as exc:  # pragma: no cover
                logger.debug(f"GET stream error: {exc}")
                attempt += 1
                if attempt >= MAX_RECONNECTION_ATTEMPTS:  # pragma: no cover
                    logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
                    return
                # Wait before reconnecting
                delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
                logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...")
                await anyio.sleep(delay_ms / 1000.0)

    async def _handle_resumption_request(self, ctx: RequestContext) -> None:
        """Handle a resumption request using GET with SSE."""
        headers = self._prepare_headers()
        if ctx.metadata and ctx.metadata.resumption_token:
            headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
        else:
            raise ResumptionError("Resumption request requires a resumption token")  # pragma: no cover

        # Extract original request ID to map responses
        original_request_id = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):  # pragma: no branch
            original_request_id = ctx.session_message.message.root.id

        async with aconnect_sse(
            ctx.client,
            "GET",
            self.url,
            headers=headers,
        ) as event_source:
            event_source.response.raise_for_status()
            logger.debug("Resumption GET SSE connection established")

            async for sse in event_source.aiter_sse():  # pragma: no branch
                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    original_request_id,
                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                )
                if is_complete:
                    await event_source.response.aclose()
                    break

    async def _handle_post_request(self, ctx: RequestContext) -> None:
        """Handle a POST request with response processing."""
        headers = self._prepare_headers()
        message = ctx.session_message.message
        is_initialization = self._is_initialization_request(message)

        async with ctx.client.stream(
            "POST",
            self.url,
            json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
            headers=headers,
        ) as response:
            if response.status_code == 202:
                logger.debug("Received 202 Accepted")
                return

            if response.status_code == 404:  # pragma: no branch
                # Server no longer recognizes our session id.
                if isinstance(message.root, JSONRPCRequest):
                    await self._send_session_terminated_error(  # pragma: no cover
                        ctx.read_stream_writer,  # pragma: no cover
                        message.root.id,  # pragma: no cover
                    )  # pragma: no cover
                return  # pragma: no cover

            response.raise_for_status()
            if is_initialization:
                self._maybe_extract_session_id_from_response(response)

            # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
            # The server MUST NOT send a response to notifications.
            if isinstance(message.root, JSONRPCRequest):
                content_type = response.headers.get(CONTENT_TYPE, "").lower()

                if content_type.startswith(JSON):
                    await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
                elif content_type.startswith(SSE):
                    await self._handle_sse_response(response, ctx, is_initialization)
                else:
                    await self._handle_unexpected_content_type(  # pragma: no cover
                        content_type,  # pragma: no cover
                        ctx.read_stream_writer,  # pragma: no cover
                    )  # pragma: no cover

    async def _handle_json_response(
        self,
        response: httpx.Response,
        read_stream_writer: StreamWriter,
        is_initialization: bool = False,
    ) -> None:
        """Handle JSON response from the server."""
        try:
            content = await response.aread()
            message = JSONRPCMessage.model_validate_json(content)

            # Extract protocol version from initialization response
            if is_initialization:
                self._maybe_extract_protocol_version_from_message(message)

            session_message = SessionMessage(message)
            await read_stream_writer.send(session_message)
        except Exception as exc:  # pragma: no cover
            logger.exception("Error parsing JSON response")
            await read_stream_writer.send(exc)

    async def _handle_sse_response(
        self,
        response: httpx.Response,
        ctx: RequestContext,
        is_initialization: bool = False,
    ) -> None:
        """Handle SSE response from the server."""
        last_event_id: str | None = None
        retry_interval_ms: int | None = None

        try:
            event_source = EventSource(response)
            async for sse in event_source.aiter_sse():  # pragma: no branch
                # Track last event ID for potential reconnection
                if sse.id:
                    last_event_id = sse.id
                # Track retry interval from server
                if sse.retry is not None:
                    retry_interval_ms = sse.retry

                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
                    is_initialization=is_initialization,
                )
                # If the SSE event indicates completion, like returning response/error
                # break the loop
                if is_complete:
                    await response.aclose()
                    return  # Normal completion, no reconnect needed
        except Exception as e:  # pragma: no cover
            logger.debug(f"SSE stream ended: {e}")

        # Stream ended without response - reconnect if we received an event with ID
        if last_event_id is not None:  # pragma: no branch
            logger.info("SSE stream disconnected, reconnecting...")
            await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)

    async def _handle_reconnection(
        self,
        ctx: RequestContext,
        last_event_id: str,
        retry_interval_ms: int | None = None,
        attempt: int = 0,
    ) -> None:
        """Reconnect with Last-Event-ID to resume stream after server disconnect."""
        # Bail if max retries exceeded
        if attempt >= MAX_RECONNECTION_ATTEMPTS:  # pragma: no cover
            logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
            return

        # Always wait - use server value or default
        delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
        await anyio.sleep(delay_ms / 1000.0)

        headers = self._prepare_headers()
        headers[LAST_EVENT_ID] = last_event_id

        # Extract original request ID to map responses
        original_request_id = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):  # pragma: no branch
            original_request_id = ctx.session_message.message.root.id

        try:
            async with aconnect_sse(
                ctx.client,
                "GET",
                self.url,
                headers=headers,
            ) as event_source:
                event_source.response.raise_for_status()
                logger.info("Reconnected to SSE stream")

                # Track for potential further reconnection
                reconnect_last_event_id: str = last_event_id
                reconnect_retry_ms = retry_interval_ms

                async for sse in event_source.aiter_sse():
                    if sse.id:  # pragma: no branch
                        reconnect_last_event_id = sse.id
                    if sse.retry is not None:
                        reconnect_retry_ms = sse.retry

                    is_complete = await self._handle_sse_event(
                        sse,
                        ctx.read_stream_writer,
                        original_request_id,
                        ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                    )
                    if is_complete:
                        await event_source.response.aclose()
                        return

                # Stream ended again without response - reconnect again (reset attempt counter)
                logger.info("SSE stream disconnected, reconnecting...")
                await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
        except Exception as e:  # pragma: no cover
            logger.debug(f"Reconnection failed: {e}")
            # Try to reconnect again if we still have an event ID
            await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

    async def _handle_unexpected_content_type(
        self,
        content_type: str,
        read_stream_writer: StreamWriter,
    ) -> None:  # pragma: no cover
        """Handle unexpected content type in response."""
        error_msg = f"Unexpected content type: {content_type}"  # pragma: no cover
        logger.error(error_msg)  # pragma: no cover
        await read_stream_writer.send(ValueError(error_msg))  # pragma: no cover

    async def _send_session_terminated_error(
        self,
        read_stream_writer: StreamWriter,
        request_id: RequestId,
    ) -> None:
        """Send a session terminated error response."""
        # NOTE(review): JSON-RPC defines "Invalid Request" as -32600; the
        # positive 32600 here looks like a dropped sign — confirm against the
        # server/TS SDK before changing, as callers may match on this value.
        jsonrpc_error = JSONRPCError(
            jsonrpc="2.0",
            id=request_id,
            error=ErrorData(code=32600, message="Session terminated"),
        )
        session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
        await read_stream_writer.send(session_message)

    async def post_writer(
        self,
        client: httpx.AsyncClient,
        write_stream_reader: StreamReader,
        read_stream_writer: StreamWriter,
        write_stream: MemoryObjectSendStream[SessionMessage],
        start_get_stream: Callable[[], None],
        tg: TaskGroup,
    ) -> None:
        """Handle writing requests to the server."""
        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    message = session_message.message
                    metadata = (
                        session_message.metadata
                        if isinstance(session_message.metadata, ClientMessageMetadata)
                        else None
                    )

                    # Check if this is a resumption request
                    is_resumption = bool(metadata and metadata.resumption_token)

                    logger.debug(f"Sending client message: {message}")

                    # Handle initialized notification
                    if self._is_initialized_notification(message):
                        start_get_stream()

                    ctx = RequestContext(
                        client=client,
                        session_id=self.session_id,
                        session_message=session_message,
                        metadata=metadata,
                        read_stream_writer=read_stream_writer,
                    )

                    # NOTE(review): this closure reads the loop-local
                    # ctx/is_resumption late; it relies on the spawned task
                    # capturing them before the next iteration rebinds them —
                    # confirm scheduling guarantees if messages arrive rapidly.
                    async def handle_request_async():
                        if is_resumption:
                            await self._handle_resumption_request(ctx)
                        else:
                            await self._handle_post_request(ctx)

                    # If this is a request, start a new task to handle it
                    if isinstance(message.root, JSONRPCRequest):
                        tg.start_soon(handle_request_async)
                    else:
                        await handle_request_async()

        except Exception:
            logger.exception("Error in post_writer")  # pragma: no cover
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()

    async def terminate_session(self, client: httpx.AsyncClient) -> None:  # pragma: no cover
        """Terminate the session by sending a DELETE request."""
        if not self.session_id:
            return

        try:
            headers = self._prepare_headers()
            response = await client.delete(self.url, headers=headers)

            if response.status_code == 405:
                logger.debug("Server does not allow session termination")
            elif response.status_code not in (200, 204):
                logger.warning(f"Session termination failed: {response.status_code}")
        except Exception as exc:
            logger.warning(f"Session termination failed: {exc}")

    def get_session_id(self) -> str | None:
        """Get the current session ID."""
        return self.session_id
@asynccontextmanager
async def streamable_http_client(
    url: str,
    *,
    http_client: httpx.AsyncClient | None = None,
    terminate_on_close: bool = True,
) -> AsyncGenerator[
    tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
        GetSessionIdCallback,
    ],
    None,
]:
    """
    Client transport for StreamableHTTP.

    Args:
        url: The MCP server endpoint URL.
        http_client: Optional pre-configured httpx.AsyncClient. If None, a default
            client with recommended MCP timeouts will be created. To configure headers,
            authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
        terminate_on_close: If True, send a DELETE request to terminate the session
            when the context exits.

    Yields:
        Tuple containing:
            - read_stream: Stream for reading messages from the server
            - write_stream: Stream for sending messages to the server
            - get_session_id_callback: Function to retrieve the current session ID

    Example:
        See examples/snippets/clients/ for usage patterns.
    """
    read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

    # Determine if we need to create and manage the client
    client_provided = http_client is not None
    client = http_client
    if client is None:
        # Create default client with recommended MCP timeouts
        client = create_mcp_http_client()

    transport = StreamableHTTPTransport(url)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

            async with contextlib.AsyncExitStack() as stack:
                # Only manage client lifecycle if we created it
                if not client_provided:
                    await stack.enter_async_context(client)

                def start_get_stream() -> None:
                    # Lazily start the server->client GET stream; invoked once
                    # the initialized notification is sent (see post_writer).
                    tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

                tg.start_soon(
                    transport.post_writer,
                    client,
                    write_stream_reader,
                    read_stream_writer,
                    write_stream,
                    start_get_stream,
                    tg,
                )

                try:
                    yield (
                        read_stream,
                        write_stream,
                        transport.get_session_id,
                    )
                finally:
                    # Best-effort DELETE before cancelling the worker tasks.
                    if transport.session_id and terminate_on_close:
                        await transport.terminate_session(client)
                    tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
@asynccontextmanager
@deprecated("Use `streamable_http_client` instead.")
async def streamablehttp_client(
    url: str,
    headers: dict[str, str] | None = None,
    timeout: float | timedelta = 30,
    sse_read_timeout: float | timedelta = 60 * 5,
    terminate_on_close: bool = True,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
) -> AsyncGenerator[
    tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
        GetSessionIdCallback,
    ],
    None,
]:
    """Deprecated wrapper around ``streamable_http_client``.

    Builds an httpx client from the legacy keyword parameters and delegates to
    ``streamable_http_client``, owning the created client's lifecycle.
    """

    def _as_seconds(value: float | timedelta) -> float:
        # The legacy API accepted either plain seconds or a timedelta.
        return value.total_seconds() if isinstance(value, timedelta) else value

    # Create httpx client using the factory with old-style parameters
    legacy_client = httpx_client_factory(
        headers=headers,
        timeout=httpx.Timeout(_as_seconds(timeout), read=_as_seconds(sse_read_timeout)),
        auth=auth,
    )

    # We created the client, so we also close it when the context exits.
    async with legacy_client, streamable_http_client(
        url,
        http_client=legacy_client,
        terminate_on_close=terminate_on_close,
    ) as streams:
        yield streams

View File

@@ -0,0 +1,86 @@
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol
import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
@asynccontextmanager
async def websocket_client(
    url: str,
) -> AsyncGenerator[
    tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
    None,
]:
    """
    WebSocket client transport for MCP, symmetrical to the server version.

    Connects to 'url' using the 'mcp' subprotocol, then yields:
        (read_stream, write_stream)

    - read_stream: As you read from this stream, you'll receive either valid
      JSONRPCMessage objects or Exception objects (when validation fails).
    - write_stream: Write JSONRPCMessage objects to this stream to send them
      over the WebSocket to the server.
    """
    # Create two in-memory streams:
    # - One for incoming messages (read_stream, written by ws_reader)
    # - One for outgoing messages (write_stream, read by ws_writer)
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    # Connect using websockets, requesting the "mcp" subprotocol
    async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:

        async def ws_reader():
            """
            Reads text messages from the WebSocket, parses them as JSON-RPC messages,
            and sends them into read_stream_writer.
            """
            async with read_stream_writer:
                async for raw_text in ws:
                    try:
                        message = types.JSONRPCMessage.model_validate_json(raw_text)
                        session_message = SessionMessage(message)
                        await read_stream_writer.send(session_message)
                    except ValidationError as exc:  # pragma: no cover
                        # If JSON parse or model validation fails, send the exception
                        await read_stream_writer.send(exc)

        async def ws_writer():
            """
            Reads JSON-RPC messages from write_stream_reader and
            sends them to the server.
            """
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    # Convert to a dict, then to JSON
                    msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
                    await ws.send(json.dumps(msg_dict))

        async with anyio.create_task_group() as tg:
            # Start reader and writer tasks
            tg.start_soon(ws_reader)
            tg.start_soon(ws_writer)

            # Yield the receive/send streams
            yield (read_stream, write_stream)

            # Once the caller's 'async with' block exits, we shut down
            tg.cancel_scope.cancel()

View File

@@ -0,0 +1 @@
"""Platform-specific utilities for MCP."""

View File

@@ -0,0 +1 @@
"""POSIX-specific utilities for MCP."""

View File

@@ -0,0 +1,60 @@
"""
POSIX-specific functionality for stdio client operations.
"""
import logging
import os
import signal
import anyio
from anyio.abc import Process
logger = logging.getLogger(__name__)
async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None:
    """
    Terminate a process and all its children on POSIX systems.

    Uses os.killpg() for atomic process group termination: SIGTERM first,
    then SIGKILL if the group is still alive after *timeout_seconds*. Falls
    back to terminating just the single process if group signalling fails.

    Args:
        process: The process to terminate
        timeout_seconds: Timeout in seconds before force killing (default: 2.0)
    """
    # anyio's Process exposes .pid directly; FallbackProcess-style wrappers
    # expose it via the underlying .popen object.
    pid = getattr(process, "pid", None) or getattr(getattr(process, "popen", None), "pid", None)
    if not pid:
        # No PID means there's no process to terminate - it either never started,
        # already exited, or we have an invalid process object
        return

    try:
        pgid = os.getpgid(pid)
        os.killpg(pgid, signal.SIGTERM)

        # Poll until the group disappears or the timeout elapses.
        with anyio.move_on_after(timeout_seconds):
            while True:
                try:
                    # Check if process group still exists (signal 0 = check only)
                    os.killpg(pgid, 0)
                    await anyio.sleep(0.1)
                except ProcessLookupError:
                    return

        # Graceful shutdown timed out: escalate to SIGKILL.
        try:
            os.killpg(pgid, signal.SIGKILL)
        except ProcessLookupError:
            pass
    except (ProcessLookupError, PermissionError, OSError) as e:
        logger.warning(f"Process group termination failed for PID {pid}: {e}, falling back to simple terminate")
        try:
            process.terminate()
            with anyio.fail_after(timeout_seconds):
                await process.wait()
        except Exception:
            logger.warning(f"Process termination failed for PID {pid}, attempting force kill")
            try:
                process.kill()
            except Exception:
                logger.exception(f"Failed to kill process {pid}")

View File

@@ -0,0 +1 @@
"""Windows-specific utilities for MCP."""

View File

@@ -0,0 +1,338 @@
"""
Windows-specific functionality for stdio client operations.
"""
import logging
import shutil
import subprocess
import sys
from pathlib import Path
from typing import BinaryIO, TextIO, cast
import anyio
from anyio import to_thread
from anyio.abc import Process
from anyio.streams.file import FileReadStream, FileWriteStream
from typing_extensions import deprecated
logger = logging.getLogger("client.stdio.win32")
# Windows-specific imports for Job Objects
# Windows-specific imports for Job Objects (pywin32). Stubbed out on other
# platforms so this module can be imported anywhere.
if sys.platform == "win32":
    import pywintypes
    import win32api
    import win32con
    import win32job
else:
    # Type stubs for non-Windows platforms
    win32api = None
    win32con = None
    win32job = None
    pywintypes = None

# Alias for the handle value returned by the Job Object APIs.
JobHandle = int
def get_windows_executable_command(command: str) -> str:
    """
    Get the correct executable command normalized for Windows.

    On Windows, commands often exist only with a specific extension
    (.exe, .cmd, .bat, .ps1) that must be resolved before execution.

    Args:
        command: Base command (e.g., 'uvx', 'npx')
    Returns:
        str: Windows-appropriate command path, or the original command
        when no resolvable path is found.
    """
    try:
        # A direct PATH hit wins outright.
        direct_hit = shutil.which(command)
        if direct_hit:
            return direct_hit
        # Otherwise probe the common Windows launcher extensions in turn.
        for extension in (".cmd", ".bat", ".exe", ".ps1"):
            candidate = shutil.which(command + extension)
            if candidate:
                return candidate
    except OSError:
        # File-system errors during resolution (permissions, broken
        # symlinks, ...) are treated the same as "not found".
        return command
    # Nothing resolved; hand back the bare command unchanged.
    return command
class FallbackProcess:
    """
    A fallback process wrapper for Windows to handle async I/O
    when using subprocess.Popen, which provides sync-only FileIO objects.
    This wraps stdin and stdout into async-compatible
    streams (FileReadStream, FileWriteStream),
    so that MCP clients expecting async streams can work properly.
    """
    def __init__(self, popen_obj: subprocess.Popen[bytes]):
        self.popen: subprocess.Popen[bytes] = popen_obj
        # Keep the raw sync pipe handles so they can be closed explicitly on exit.
        self.stdin_raw = popen_obj.stdin  # type: ignore[assignment]
        self.stdout_raw = popen_obj.stdout  # type: ignore[assignment]
        self.stderr = popen_obj.stderr  # type: ignore[assignment]
        # Async adapters over the sync pipes; None when the pipe wasn't created.
        self.stdin = FileWriteStream(cast(BinaryIO, self.stdin_raw)) if self.stdin_raw else None
        self.stdout = FileReadStream(cast(BinaryIO, self.stdout_raw)) if self.stdout_raw else None
    async def __aenter__(self):
        """Support async context manager entry."""
        return self
    async def __aexit__(
        self,
        exc_type: BaseException | None,
        exc_val: BaseException | None,
        exc_tb: object | None,
    ) -> None:
        """Terminate and wait on process exit inside a thread."""
        self.popen.terminate()
        # Popen.wait() blocks, so run it in a worker thread to avoid
        # stalling the event loop.
        await to_thread.run_sync(self.popen.wait)
        # Close the file handles to prevent ResourceWarning
        if self.stdin:
            await self.stdin.aclose()
        if self.stdout:
            await self.stdout.aclose()
        if self.stdin_raw:
            self.stdin_raw.close()
        if self.stdout_raw:
            self.stdout_raw.close()
        if self.stderr:
            self.stderr.close()
    async def wait(self):
        """Async wait for process completion (runs Popen.wait in a thread)."""
        return await to_thread.run_sync(self.popen.wait)
    def terminate(self):
        """Terminate the subprocess immediately."""
        return self.popen.terminate()
    def kill(self) -> None:
        """Kill the subprocess immediately (alias for terminate)."""
        # NOTE(review): this sends the same signal as terminate(), not a
        # stronger one — confirm that is intended for this wrapper.
        self.terminate()
    @property
    def pid(self) -> int:
        """Return the process ID of the wrapped Popen object."""
        return self.popen.pid
# ------------------------
# Updated function
# ------------------------
async def create_windows_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO | None = sys.stderr,
    cwd: Path | str | None = None,
) -> Process | FallbackProcess:
    """
    Creates a subprocess in a Windows-compatible way with Job Object support.
    Attempt to use anyio's open_process for async subprocess creation.
    In some cases this will throw NotImplementedError on Windows, e.g.
    when using the SelectorEventLoop which does not support async subprocesses.
    In that case, we fall back to using subprocess.Popen.
    The process is automatically added to a Job Object to ensure all child
    processes are terminated when the parent is terminated.
    Args:
        command (str): The executable to run
        args (list[str]): List of command line arguments
        env (dict[str, str] | None): Environment variables
        errlog (TextIO | None): Where to send stderr output (defaults to sys.stderr)
        cwd (Path | str | None): Working directory for the subprocess
    Returns:
        Process | FallbackProcess: Async-compatible subprocess with stdin and stdout streams
    """
    # Create the job object up front so the spawned process can be attached
    # immediately after launch, minimizing the orphan window.
    job = _create_job_object()
    process = None
    try:
        # First try using anyio with Windows-specific flags to hide console window
        process = await anyio.open_process(
            [command, *args],
            env=env,
            # Ensure we don't create console windows for each process
            creationflags=subprocess.CREATE_NO_WINDOW  # type: ignore
            if hasattr(subprocess, "CREATE_NO_WINDOW")
            else 0,
            stderr=errlog,
            cwd=cwd,
        )
    except NotImplementedError:
        # If Windows doesn't support async subprocess creation, use fallback
        process = await _create_windows_fallback_process(command, args, env, errlog, cwd)
    except Exception:
        # Try again without creation flags
        # NOTE(review): if this retry raises too, the job handle from
        # _create_job_object() is never closed — confirm acceptable.
        process = await anyio.open_process(
            [command, *args],
            env=env,
            stderr=errlog,
            cwd=cwd,
        )
    # Attach to the job object (no-op off Windows or when job is None).
    _maybe_assign_process_to_job(process, job)
    return process
async def _create_windows_fallback_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO | None = sys.stderr,
    cwd: Path | str | None = None,
) -> FallbackProcess:
    """
    Create a subprocess using subprocess.Popen as a fallback when anyio fails.
    This function wraps the sync subprocess.Popen in an async-compatible interface.
    """
    try:
        # Try launching with creationflags to avoid opening a new console window
        popen_obj = subprocess.Popen(
            [command, *args],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=errlog,
            env=env,
            cwd=cwd,
            bufsize=0,  # Unbuffered output
            creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
        )
    except Exception:
        # If creationflags failed, fallback without them
        popen_obj = subprocess.Popen(
            [command, *args],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=errlog,
            env=env,
            cwd=cwd,
            bufsize=0,
        )
    # Wrap the sync Popen in the async-compatible adapter.
    return FallbackProcess(popen_obj)
def _create_job_object() -> int | None:
    """
    Create a Windows Job Object configured to terminate all processes when closed.

    Returns None on non-Windows platforms or when creation fails.
    """
    # Job objects are a Windows-only facility.
    if sys.platform != "win32" or not win32job:
        return None
    try:
        job = win32job.CreateJobObject(None, "")
        limits = win32job.QueryInformationJobObject(job, win32job.JobObjectExtendedLimitInformation)
        # KILL_ON_JOB_CLOSE makes closing the handle tear down every member process.
        limits["BasicLimitInformation"]["LimitFlags"] |= win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
        win32job.SetInformationJobObject(job, win32job.JobObjectExtendedLimitInformation, limits)
    except Exception as e:
        logger.warning(f"Failed to create Job Object for process tree management: {e}")
        return None
    return job
def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None:
    """
    Try to assign a process to a job object. If assignment fails
    for any reason, the job handle is closed.

    On success the job handle is stashed on the process object so that
    terminate_windows_process_tree() can find and terminate it later.
    """
    if not job:
        return
    if sys.platform != "win32" or not win32api or not win32con or not win32job:
        return
    try:
        # We need QUOTA (for job assignment) and TERMINATE rights on the target.
        process_handle = win32api.OpenProcess(
            win32con.PROCESS_SET_QUOTA | win32con.PROCESS_TERMINATE, False, process.pid
        )
        if not process_handle:
            raise Exception("Failed to open process handle")
        try:
            win32job.AssignProcessToJobObject(job, process_handle)
            # Keep the job handle alive on the process for later tree termination.
            process._job_object = job
        finally:
            win32api.CloseHandle(process_handle)
    except Exception as e:
        logger.warning(f"Failed to assign process {process.pid} to Job Object: {e}")
        # Nothing references the job now, so close it to avoid leaking the handle.
        if win32api:
            win32api.CloseHandle(job)
async def terminate_windows_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
    """
    Terminate a process and all its children on Windows.
    If the process has an associated job object, it will be terminated.
    Otherwise, falls back to basic process termination.
    Args:
        process: The process to terminate
        timeout_seconds: Timeout in seconds before force killing (default: 2.0)
    """
    if sys.platform != "win32":
        return
    # Job handle stashed by _maybe_assign_process_to_job(), if assignment succeeded.
    job = getattr(process, "_job_object", None)
    if job and win32job:
        try:
            win32job.TerminateJobObject(job, 1)
        except Exception:
            # Job might already be terminated
            pass
        finally:
            # Close the handle regardless; the job was configured with
            # KILL_ON_JOB_CLOSE so closing also tears down stragglers.
            if win32api:
                try:
                    win32api.CloseHandle(job)
                except Exception:
                    pass
    # Always try to terminate the process itself as well
    try:
        process.terminate()
    except Exception:
        pass
@deprecated(
    "terminate_windows_process is deprecated and will be removed in a future version. "
    "Process termination is now handled internally by the stdio_client context manager."
)
async def terminate_windows_process(process: Process | FallbackProcess):
    """
    Terminate a Windows process.
    Note: On Windows, terminating a process with process.terminate() doesn't
    always guarantee immediate process termination.
    So we give it 2s to exit, or we call process.kill()
    which sends a SIGKILL equivalent signal.
    Args:
        process: The process to terminate
    """
    try:
        process.terminate()
        # Bounded wait: anyio.fail_after raises TimeoutError when exceeded.
        with anyio.fail_after(2.0):
            await process.wait()
    except TimeoutError:
        # Force kill if it doesn't terminate
        process.kill()

View File

@@ -0,0 +1,5 @@
from .fastmcp import FastMCP
from .lowlevel import NotificationOptions, Server
from .models import InitializationOptions
__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"]

View File

@@ -0,0 +1,50 @@
import importlib.metadata
import logging
import sys
import anyio
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.types import ServerCapabilities
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server")
async def receive_loop(session: ServerSession):
    """Log every message arriving on the session until the stream closes."""
    logger.info("Starting receive loop")
    async for message in session.incoming_messages:
        if isinstance(message, Exception):
            # Transport-level failures are delivered in-band as Exception objects.
            logger.error("Error: %s", message)
        else:
            logger.info("Received message from client: %s", message)
async def main():
    """Run a minimal MCP server over stdio that only logs incoming messages."""
    version = importlib.metadata.version("mcp")
    async with stdio_server() as (read_stream, write_stream):
        async with (
            ServerSession(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="mcp",
                    server_version=version,
                    capabilities=ServerCapabilities(),
                ),
            ) as session,
            # write_stream is entered here too so it is closed when the
            # session block exits.
            write_stream,
        ):
            await receive_loop(session)
if __name__ == "__main__":
anyio.run(main, backend="trio")

View File

@@ -0,0 +1,3 @@
"""
MCP OAuth server authorization components.
"""

View File

@@ -0,0 +1,5 @@
from pydantic import ValidationError
def stringify_pydantic_error(validation_error: ValidationError) -> str:
    """Render a pydantic ValidationError as one '<dotted.loc>: <msg>' line per error."""
    rendered: list[str] = []
    for err in validation_error.errors():
        location = ".".join(str(part) for part in err["loc"])
        rendered.append(f"{location}: {err['msg']}")
    return "\n".join(rendered)

View File

@@ -0,0 +1,3 @@
"""
Request handlers for MCP authorization endpoints.
"""

View File

@@ -0,0 +1,224 @@
import logging
from dataclasses import dataclass
from typing import Any, Literal
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import (
AuthorizationErrorCode,
AuthorizationParams,
AuthorizeError,
OAuthAuthorizationServerProvider,
construct_redirect_uri,
)
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
logger = logging.getLogger(__name__)
class AuthorizationRequest(BaseModel):
    """Query/form parameters accepted by the /authorize endpoint."""

    # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
    client_id: str = Field(..., description="The client ID")
    redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
    # see OAuthClientMetadata; we only support `code`
    response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
    code_challenge: str = Field(..., description="PKCE code challenge")
    code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
    state: str | None = Field(None, description="Optional state parameter")
    scope: str | None = Field(
        None,
        description="Optional scope; if specified, should be a space-separated list of scope strings",
    )
    resource: str | None = Field(
        None,
        description="RFC 8707 resource indicator - the MCP server this token will be used with",
    )
class AuthorizationErrorResponse(BaseModel):
    """Error payload for /authorize failures (RFC 6749 section 4.1.2.1)."""

    error: AuthorizationErrorCode
    error_description: str | None
    error_uri: AnyUrl | None = None
    # must be set if provided in the request
    state: str | None = None
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
    """Return params[key] when it is present and a plain string, else None."""
    if params is None:  # pragma: no cover
        return None
    candidate = params.get(key)
    # Form values can also be uploads/other objects; only accept strings.
    return candidate if isinstance(candidate, str) else None
class AnyUrlModel(RootModel[AnyUrl]):
    """RootModel wrapper used to validate a bare string as an AnyUrl."""

    root: AnyUrl
@dataclass
class AuthorizationHandler:
    """Starlette handler for the OAuth 2.0 authorization endpoint (RFC 6749 section 4.1.1)."""

    # Provider supplies client lookup and the actual authorization decision.
    provider: OAuthAuthorizationServerProvider[Any, Any, Any]

    async def handle(self, request: Request) -> Response:
        """Validate an /authorize request (GET query or POST form) and redirect
        to the provider's next step, or produce an RFC 6749 error response."""
        # implements authorization requests for grant_type=code;
        # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
        state = None
        redirect_uri = None
        client = None
        params = None
        async def error_response(
            error: AuthorizationErrorCode,
            error_description: str | None,
            attempt_load_client: bool = True,
        ):
            """Build an error response, redirecting to the client when possible."""
            # Error responses take two different formats:
            # 1. The request has a valid client ID & redirect_uri: we issue a redirect
            #    back to the redirect_uri with the error response fields as query
            #    parameters. This allows the client to be notified of the error.
            # 2. Otherwise, we return an error response directly to the end user;
            #    we choose to do so in JSON, but this is left undefined in the
            #    specification.
            # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
            #
            # This logic is a bit awkward to handle, because the error might be thrown
            # very early in request validation, before we've done the usual Pydantic
            # validation, loaded the client, etc. To handle this, error_response()
            # contains fallback logic which attempts to load the parameters directly
            # from the request.
            nonlocal client, redirect_uri, state
            if client is None and attempt_load_client:
                # make last-ditch attempt to load the client
                client_id = best_effort_extract_string("client_id", params)
                client = await self.provider.get_client(client_id) if client_id else None
            if redirect_uri is None and client:
                # make last-ditch effort to load the redirect uri
                try:
                    if params is not None and "redirect_uri" not in params:
                        raw_redirect_uri = None
                    else:
                        raw_redirect_uri = AnyUrlModel.model_validate(
                            best_effort_extract_string("redirect_uri", params)
                        ).root
                    redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
                except (ValidationError, InvalidRedirectUriError):
                    # if the redirect URI is invalid, ignore it & just return the
                    # initial error
                    pass
            # the error response MUST contain the state specified by the client, if any
            if state is None:  # pragma: no cover
                # make last-ditch effort to load state
                state = best_effort_extract_string("state", params)
            error_resp = AuthorizationErrorResponse(
                error=error,
                error_description=error_description,
                state=state,
            )
            if redirect_uri and client:
                return RedirectResponse(
                    url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
                    status_code=302,
                    headers={"Cache-Control": "no-store"},
                )
            else:
                return PydanticJSONResponse(
                    status_code=400,
                    content=error_resp,
                    headers={"Cache-Control": "no-store"},
                )
        try:
            # Parse request parameters
            if request.method == "GET":
                # Convert query_params to dict for pydantic validation
                params = request.query_params
            else:
                # Parse form data for POST requests
                params = await request.form()
            # Save state if it exists, even before validation
            state = best_effort_extract_string("state", params)
            try:
                auth_request = AuthorizationRequest.model_validate(params)
                state = auth_request.state  # Update with validated state
            except ValidationError as validation_error:
                error: AuthorizationErrorCode = "invalid_request"
                # Map a wrong response_type to the spec's dedicated error code.
                for e in validation_error.errors():
                    if e["loc"] == ("response_type",) and e["type"] == "literal_error":
                        error = "unsupported_response_type"
                        break
                return await error_response(error, stringify_pydantic_error(validation_error))
            # Get client information
            client = await self.provider.get_client(
                auth_request.client_id,
            )
            if not client:
                # For client_id validation errors, return direct error (no redirect)
                return await error_response(
                    error="invalid_request",
                    error_description=f"Client ID '{auth_request.client_id}' not found",
                    attempt_load_client=False,
                )
            # Validate redirect_uri against client's registered URIs
            try:
                redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
            except InvalidRedirectUriError as validation_error:
                # For redirect_uri validation errors, return direct error (no redirect)
                return await error_response(
                    error="invalid_request",
                    error_description=validation_error.message,
                )
            # Validate scope - for scope errors, we can redirect
            try:
                scopes = client.validate_scope(auth_request.scope)
            except InvalidScopeError as validation_error:
                # For scope errors, redirect with error parameters
                return await error_response(
                    error="invalid_scope",
                    error_description=validation_error.message,
                )
            # Setup authorization parameters
            auth_params = AuthorizationParams(
                state=state,
                scopes=scopes,
                code_challenge=auth_request.code_challenge,
                redirect_uri=redirect_uri,
                redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
                resource=auth_request.resource,  # RFC 8707
            )
            try:
                # Let the provider pick the next URI to redirect to
                return RedirectResponse(
                    url=await self.provider.authorize(
                        client,
                        auth_params,
                    ),
                    status_code=302,
                    headers={"Cache-Control": "no-store"},
                )
            except AuthorizeError as e:
                # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
                return await error_response(error=e.error, error_description=e.error_description)
        except Exception as validation_error:  # pragma: no cover
            # Catch-all for unexpected errors
            logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
            return await error_response(error="server_error", error_description="An unexpected error occurred")

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
@dataclass
class MetadataHandler:
    """Serves OAuth authorization-server metadata as cacheable JSON."""

    metadata: OAuthMetadata

    async def handle(self, request: Request) -> Response:
        """Return the static metadata document."""
        return PydanticJSONResponse(
            content=self.metadata,
            headers={"Cache-Control": "public, max-age=3600"},  # Cache for 1 hour
        )
@dataclass
class ProtectedResourceMetadataHandler:
    """Serves RFC 9728 protected-resource metadata as cacheable JSON."""

    metadata: ProtectedResourceMetadata

    async def handle(self, request: Request) -> Response:
        """Return the static metadata document."""
        return PydanticJSONResponse(
            content=self.metadata,
            headers={"Cache-Control": "public, max-age=3600"},  # Cache for 1 hour
        )

View File

@@ -0,0 +1,136 @@
import secrets
import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, RootModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
from mcp.server.auth.settings import ClientRegistrationOptions
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
class RegistrationRequest(RootModel[OAuthClientMetadata]):
    """HTTP-layer wrapper around OAuthClientMetadata for /register requests."""

    # this wrapper is a no-op; it's just to separate out the types exposed to the
    # provider from what we use in the HTTP handler
    root: OAuthClientMetadata
class RegistrationErrorResponse(BaseModel):
    """Error payload for /register failures (RFC 7591 section 3.2.2)."""

    error: RegistrationErrorCode
    error_description: str | None
@dataclass
class RegistrationHandler:
    """Starlette handler for OAuth 2.0 dynamic client registration (RFC 7591)."""

    # Provider persists the registered client.
    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    # Server policy: default/valid scopes and secret expiry.
    options: ClientRegistrationOptions

    async def handle(self, request: Request) -> Response:
        """Validate client metadata, mint credentials, and register the client."""
        # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
        try:
            # Parse request body as JSON
            body = await request.json()
            client_metadata = OAuthClientMetadata.model_validate(body)
            # Scope validation is handled below
        except ValidationError as validation_error:
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description=stringify_pydantic_error(validation_error),
                ),
                status_code=400,
            )
        client_id = str(uuid4())
        # If auth method is None, default to client_secret_post
        if client_metadata.token_endpoint_auth_method is None:
            client_metadata.token_endpoint_auth_method = "client_secret_post"
        client_secret = None
        if client_metadata.token_endpoint_auth_method != "none":  # pragma: no branch
            # cryptographically secure random 32-byte hex string
            client_secret = secrets.token_hex(32)
        # Apply server scope policy: fill in defaults, or reject scopes
        # outside the configured valid set.
        if client_metadata.scope is None and self.options.default_scopes is not None:
            client_metadata.scope = " ".join(self.options.default_scopes)
        elif client_metadata.scope is not None and self.options.valid_scopes is not None:
            requested_scopes = set(client_metadata.scope.split())
            valid_scopes = set(self.options.valid_scopes)
            if not requested_scopes.issubset(valid_scopes):  # pragma: no branch
                return PydanticJSONResponse(
                    content=RegistrationErrorResponse(
                        error="invalid_client_metadata",
                        error_description="Requested scopes are not valid: "
                        f"{', '.join(requested_scopes - valid_scopes)}",
                    ),
                    status_code=400,
                )
        # Both grants are required (the server issues refresh tokens alongside codes).
        if not {"authorization_code", "refresh_token"}.issubset(set(client_metadata.grant_types)):
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description="grant_types must be authorization_code and refresh_token",
                ),
                status_code=400,
            )
        # The MCP spec requires servers to use the authorization `code` flow
        # with PKCE
        if "code" not in client_metadata.response_types:
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description="response_types must include 'code' for authorization_code grant",
                ),
                status_code=400,
            )
        client_id_issued_at = int(time.time())
        client_secret_expires_at = (
            client_id_issued_at + self.options.client_secret_expiry_seconds
            if self.options.client_secret_expiry_seconds is not None
            else None
        )
        client_info = OAuthClientInformationFull(
            client_id=client_id,
            client_id_issued_at=client_id_issued_at,
            client_secret=client_secret,
            client_secret_expires_at=client_secret_expires_at,
            # passthrough information from the client request
            redirect_uris=client_metadata.redirect_uris,
            token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
            grant_types=client_metadata.grant_types,
            response_types=client_metadata.response_types,
            client_name=client_metadata.client_name,
            client_uri=client_metadata.client_uri,
            logo_uri=client_metadata.logo_uri,
            scope=client_metadata.scope,
            contacts=client_metadata.contacts,
            tos_uri=client_metadata.tos_uri,
            policy_uri=client_metadata.policy_uri,
            jwks_uri=client_metadata.jwks_uri,
            jwks=client_metadata.jwks,
            software_id=client_metadata.software_id,
            software_version=client_metadata.software_version,
        )
        try:
            # Register client
            await self.provider.register_client(client_info)
            # Return client information
            return PydanticJSONResponse(content=client_info, status_code=201)
        except RegistrationError as e:
            # Handle registration errors as defined in RFC 7591 Section 3.2.2
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
                status_code=400,
            )

View File

@@ -0,0 +1,91 @@
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal
from pydantic import BaseModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.errors import (
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
class RevocationRequest(BaseModel):
    """
    Form parameters for token revocation.
    # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
    """
    token: str
    token_type_hint: Literal["access_token", "refresh_token"] | None = None
    client_id: str
    client_secret: str | None
class RevocationErrorResponse(BaseModel):
    """Error payload for /revoke failures."""

    error: Literal["invalid_request", "unauthorized_client"]
    error_description: str | None = None
@dataclass
class RevocationHandler:
    """Starlette handler for OAuth 2.0 token revocation (RFC 7009)."""

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    client_authenticator: ClientAuthenticator

    async def handle(self, request: Request) -> Response:
        """
        Handler for the OAuth 2.0 Token Revocation endpoint.
        """
        try:
            client = await self.client_authenticator.authenticate_request(request)
        except AuthenticationError as e:  # pragma: no cover
            return PydanticJSONResponse(
                status_code=401,
                content=RevocationErrorResponse(
                    error="unauthorized_client",
                    error_description=e.message,
                ),
            )
        try:
            form_data = await request.form()
            revocation_request = RevocationRequest.model_validate(dict(form_data))
        except ValidationError as e:
            return PydanticJSONResponse(
                status_code=400,
                content=RevocationErrorResponse(
                    error="invalid_request",
                    error_description=stringify_pydantic_error(e),
                ),
            )
        # Try both token stores; the hint only changes the order we try them in.
        loaders = [
            self.provider.load_access_token,
            partial(self.provider.load_refresh_token, client),
        ]
        if revocation_request.token_type_hint == "refresh_token":  # pragma: no cover
            loaders = reversed(loaders)
        token: None | AccessToken | RefreshToken = None
        for loader in loaders:
            token = await loader(revocation_request.token)
            if token is not None:
                break
        # if token is not found, just return HTTP 200 per the RFC
        # Only revoke when the token actually belongs to the authenticated client.
        if token and token.client_id == client.client_id:
            # Revoke token; provider is not meant to be able to do validation
            # at this point that would result in an error
            await self.provider.revoke_token(token)
        # Return successful empty response
        return Response(
            status_code=200,
            headers={
                "Cache-Control": "no-store",
                "Pragma": "no-cache",
            },
        )

View File

@@ -0,0 +1,241 @@
import base64
import hashlib
import time
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
from mcp.shared.auth import OAuthToken
class AuthorizationCodeRequest(BaseModel):
    """Token request for grant_type=authorization_code."""

    # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
    grant_type: Literal["authorization_code"]
    code: str = Field(..., description="The authorization code")
    redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
    code_verifier: str = Field(..., description="PKCE code verifier")
    # RFC 8707 resource indicator
    resource: str | None = Field(None, description="Resource indicator for the token")
class RefreshTokenRequest(BaseModel):
    """Token request for grant_type=refresh_token."""

    # See https://datatracker.ietf.org/doc/html/rfc6749#section-6
    grant_type: Literal["refresh_token"]
    refresh_token: str = Field(..., description="The refresh token")
    scope: str | None = Field(None, description="Optional scope parameter")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # RFC 8707 resource indicator
    resource: str | None = Field(None, description="Resource indicator for the token")
class TokenRequest(
    RootModel[
        Annotated[
            AuthorizationCodeRequest | RefreshTokenRequest,
            Field(discriminator="grant_type"),
        ]
    ]
):
    """Union of token requests, discriminated by the grant_type field."""

    root: Annotated[
        AuthorizationCodeRequest | RefreshTokenRequest,
        Field(discriminator="grant_type"),
    ]
class TokenErrorResponse(BaseModel):
    """
    Error payload for /token failures.
    See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
    """
    error: TokenErrorCode
    error_description: str | None = None
    error_uri: AnyHttpUrl | None = None
class TokenSuccessResponse(RootModel[OAuthToken]):
    """HTTP-layer wrapper around the provider's OAuthToken."""

    # this is just a wrapper over OAuthToken; the only reason we do this
    # is to have some separation between the HTTP response type, and the
    # type returned by the provider
    root: OAuthToken
@dataclass
class TokenHandler:
    """Starlette handler for the OAuth 2.0 token endpoint (RFC 6749 section 3.2).

    Supports the authorization_code (with PKCE) and refresh_token grants.
    """

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    client_authenticator: ClientAuthenticator

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Wrap a success/error payload in a no-store JSON response (200 or 400)."""
        status_code = 200
        if isinstance(obj, TokenErrorResponse):
            status_code = 400
        return PydanticJSONResponse(
            content=obj,
            status_code=status_code,
            headers={
                "Cache-Control": "no-store",
                "Pragma": "no-cache",
            },
        )

    async def handle(self, request: Request):
        """Authenticate the client, validate the grant, and issue tokens."""
        try:
            client_info = await self.client_authenticator.authenticate_request(request)
        except AuthenticationError as e:
            # Authentication failures should return 401
            return PydanticJSONResponse(
                content=TokenErrorResponse(
                    error="unauthorized_client",
                    error_description=e.message,
                ),
                status_code=401,
                headers={
                    "Cache-Control": "no-store",
                    "Pragma": "no-cache",
                },
            )
        try:
            form_data = await request.form()
            token_request = TokenRequest.model_validate(dict(form_data)).root
        except ValidationError as validation_error:  # pragma: no cover
            return self.response(
                TokenErrorResponse(
                    error="invalid_request",
                    error_description=stringify_pydantic_error(validation_error),
                )
            )
        if token_request.grant_type not in client_info.grant_types:  # pragma: no cover
            return self.response(
                TokenErrorResponse(
                    error="unsupported_grant_type",
                    error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
                )
            )
        tokens: OAuthToken
        match token_request:
            case AuthorizationCodeRequest():
                auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
                if auth_code is None or auth_code.client_id != token_request.client_id:
                    # if code belongs to different client, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="authorization code does not exist",
                        )
                    )
                # make auth codes expire after a deadline
                # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
                if auth_code.expires_at < time.time():
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="authorization code has expired",
                        )
                    )
                # verify redirect_uri doesn't change between /authorize and /tokens
                # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
                if auth_code.redirect_uri_provided_explicitly:
                    authorize_request_redirect_uri = auth_code.redirect_uri
                else:  # pragma: no cover
                    authorize_request_redirect_uri = None
                # Convert both sides to strings for comparison to handle AnyUrl vs string issues
                token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
                auth_redirect_str = (
                    str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
                )
                if token_redirect_str != auth_redirect_str:
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_request",
                            error_description=("redirect_uri did not match the one used when creating auth code"),
                        )
                    )
                # Verify PKCE code verifier
                # S256: challenge == BASE64URL(SHA256(verifier)) without padding.
                sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
                hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
                if hashed_code_verifier != auth_code.code_challenge:
                    # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="incorrect code_verifier",
                        )
                    )
                try:
                    # Exchange authorization code for tokens
                    tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
                except TokenError as e:
                    return self.response(
                        TokenErrorResponse(
                            error=e.error,
                            error_description=e.error_description,
                        )
                    )
            case RefreshTokenRequest():  # pragma: no cover
                refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
                if refresh_token is None or refresh_token.client_id != token_request.client_id:
                    # if token belongs to different client, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="refresh token does not exist",
                        )
                    )
                if refresh_token.expires_at and refresh_token.expires_at < time.time():
                    # if the refresh token has expired, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="refresh token has expired",
                        )
                    )
                # Parse scopes if provided
                scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
                # Requested scopes must not exceed what the refresh token granted.
                for scope in scopes:
                    if scope not in refresh_token.scopes:
                        return self.response(
                            TokenErrorResponse(
                                error="invalid_scope",
                                error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
                            )
                        )
                try:
                    # Exchange refresh token for new tokens
                    tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
                except TokenError as e:
                    return self.response(
                        TokenErrorResponse(
                            error=e.error,
                            error_description=e.error_description,
                        )
                    )
        return self.response(TokenSuccessResponse(root=tokens))

View File

@@ -0,0 +1,10 @@
from typing import Any
from starlette.responses import JSONResponse
class PydanticJSONResponse(JSONResponse):
    # use pydantic json serialization instead of the stock `json.dumps`,
    # so that we can handle serializing pydantic models like AnyHttpUrl
    def render(self, content: Any) -> bytes:
        """Serialize *content* with pydantic's JSON encoder.

        *content* must be a pydantic model (it must expose ``model_dump_json``);
        fields whose value is ``None`` are omitted from the output.
        """
        return content.model_dump_json(exclude_none=True).encode("utf-8")

View File

@@ -0,0 +1,3 @@
"""
Middleware for MCP authorization.
"""

View File

@@ -0,0 +1,48 @@
import contextvars
from starlette.types import ASGIApp, Receive, Scope, Send
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
# Context-local slot holding the authenticated user for the current request.
# The default is None, indicating no authenticated user is present.
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
def get_access_token() -> AccessToken | None:
    """
    Get the access token from the current context.

    Returns:
        The access token if an authenticated user is available, None otherwise.
    """
    user = auth_context_var.get()
    if not user:
        return None
    return user.access_token
class AuthContextMiddleware:
    """
    ASGI middleware that publishes the authenticated user to a contextvar so it
    can be read anywhere in the request lifecycle via get_access_token().

    Add it after AuthenticationMiddleware so scope["user"] is already populated
    by the time this middleware runs.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        user = scope.get("user")
        if not isinstance(user, AuthenticatedUser):
            # No authenticated user, just process the request
            await self.app(scope, receive, send)
            return
        # Publish the user for the duration of this request only; reset()
        # guarantees the contextvar never leaks into another request.
        reset_token = auth_context_var.set(user)
        try:
            await self.app(scope, receive, send)
        finally:
            auth_context_var.reset(reset_token)

View File

@@ -0,0 +1,128 @@
import json
import time
from typing import Any
from pydantic import AnyHttpUrl
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
from starlette.requests import HTTPConnection
from starlette.types import Receive, Scope, Send
from mcp.server.auth.provider import AccessToken, TokenVerifier
class AuthenticatedUser(SimpleUser):
    """User with authentication info."""

    def __init__(self, auth_info: AccessToken):
        # SimpleUser's display name is set to the OAuth client_id of the token.
        super().__init__(auth_info.client_id)
        # Keep the full token so downstream middleware/handlers can inspect it.
        self.access_token = auth_info
        self.scopes = auth_info.scopes
class BearerAuthBackend(AuthenticationBackend):
    """
    Authentication backend that validates Bearer tokens using a TokenVerifier.
    """

    def __init__(self, token_verifier: TokenVerifier):
        self.token_verifier = token_verifier

    async def authenticate(self, conn: HTTPConnection):
        # Case-insensitive scan for the Authorization header.
        header = next(
            (conn.headers.get(name) for name in conn.headers if name.lower() == "authorization"),
            None,
        )
        if not header or not header.lower().startswith("bearer "):
            return None
        token = header[7:]  # strip the "Bearer " prefix
        # Delegate validation to the configured verifier.
        auth_info = await self.token_verifier.verify_token(token)
        if not auth_info:
            return None
        # Reject tokens whose recorded expiry is already in the past.
        if auth_info.expires_at and auth_info.expires_at < int(time.time()):
            return None
        return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
class RequireAuthMiddleware:
    """
    Middleware that requires a valid Bearer token in the Authorization header.
    This will validate the token with the auth provider and store the resulting
    auth info in the request state.
    """

    def __init__(
        self,
        app: Any,
        required_scopes: list[str],
        resource_metadata_url: AnyHttpUrl | None = None,
    ):
        """
        Initialize the middleware.

        Args:
            app: ASGI application
            required_scopes: List of scopes that the token must have
            resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header
        """
        self.app = app
        self.required_scopes = required_scopes
        self.resource_metadata_url = resource_metadata_url

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        user = scope.get("user")
        if not isinstance(user, AuthenticatedUser):
            await self._send_auth_error(
                send, status_code=401, error="invalid_token", description="Authentication required"
            )
            return
        credentials = scope.get("auth")
        for needed in self.required_scopes:
            # credentials should always be provided; the None check is just paranoia
            if credentials is None or needed not in credentials.scopes:
                await self._send_auth_error(
                    send, status_code=403, error="insufficient_scope", description=f"Required scope: {needed}"
                )
                return
        await self.app(scope, receive, send)

    async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None:
        """Send an authentication error response with WWW-Authenticate header."""
        # Assemble the RFC 6750 challenge parameters.
        challenge_params = [f'error="{error}"', f'error_description="{description}"']
        if self.resource_metadata_url:  # pragma: no cover
            challenge_params.append(f'resource_metadata="{self.resource_metadata_url}"')
        challenge = f"Bearer {', '.join(challenge_params)}"
        payload = json.dumps({"error": error, "error_description": description}).encode()
        headers = [
            (b"content-type", b"application/json"),
            (b"content-length", str(len(payload)).encode()),
            (b"www-authenticate", challenge.encode()),
        ]
        await send({"type": "http.response.start", "status": status_code, "headers": headers})
        await send({"type": "http.response.body", "body": payload})

View File

@@ -0,0 +1,115 @@
import base64
import binascii
import hmac
import time
from typing import Any
from urllib.parse import unquote
from starlette.requests import Request
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.shared.auth import OAuthClientInformationFull
class AuthenticationError(Exception):
    """Raised when client authentication fails during a /token (or /revoke) call."""

    def __init__(self, message: str):
        # Pass the message to Exception so str(exc), repr(exc), and tracebacks
        # carry the failure reason; the original code only set self.message,
        # which left str(exc) empty.
        super().__init__(message)
        self.message = message
class ClientAuthenticator:
    """
    ClientAuthenticator is a callable which validates requests from a client
    application, used to verify /token calls.

    If, during registration, the client requested to be issued a secret, the
    authenticator asserts that /token calls must be authenticated with
    that same token.

    NOTE: clients can opt for no authentication during registration, in which case this
    logic is skipped.
    """

    def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
        """
        Initialize the dependency.

        Args:
            provider: Provider to look up client information
        """
        self.provider = provider

    async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
        """
        Authenticate a client from an HTTP request.

        Extracts client credentials from the appropriate location based on the
        client's registered authentication method and validates them.

        Args:
            request: The HTTP request containing client credentials

        Returns:
            The authenticated client information

        Raises:
            AuthenticationError: If authentication fails
        """
        form_data = await request.form()
        client_id = form_data.get("client_id")
        if not client_id:
            raise AuthenticationError("Missing client_id")
        client = await self.provider.get_client(str(client_id))
        if not client:
            raise AuthenticationError("Invalid client_id")  # pragma: no cover

        request_client_secret: str | None = None
        auth_header = request.headers.get("Authorization", "")

        if client.token_endpoint_auth_method == "client_secret_basic":
            if not auth_header.startswith("Basic "):
                raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
            try:
                encoded_credentials = auth_header[6:]  # Remove "Basic " prefix
                decoded = base64.b64decode(encoded_credentials).decode("utf-8")
                if ":" not in decoded:
                    raise ValueError("Invalid Basic auth format")
                basic_client_id, request_client_secret = decoded.split(":", 1)
                # URL-decode both parts per RFC 6749 Section 2.3.1
                basic_client_id = unquote(basic_client_id)
                request_client_secret = unquote(request_client_secret)
                if basic_client_id != client_id:
                    raise AuthenticationError("Client ID mismatch in Basic auth")
            except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
                # Chain the underlying decode error (PEP 3134) so logs show why
                # the header was rejected; the original raise discarded it.
                raise AuthenticationError("Invalid Basic authentication header") from exc
        elif client.token_endpoint_auth_method == "client_secret_post":
            raw_form_data = form_data.get("client_secret")
            # form_data.get() can return an UploadFile or None, so only accept a str
            # (the redundant str() wrapper on an already-str value was dropped).
            if isinstance(raw_form_data, str):
                request_client_secret = raw_form_data
        elif client.token_endpoint_auth_method == "none":
            request_client_secret = None
        else:
            raise AuthenticationError(  # pragma: no cover
                f"Unsupported auth method: {client.token_endpoint_auth_method}"
            )

        # If client from the store expects a secret, validate that the request provides
        # that secret
        if client.client_secret:  # pragma: no branch
            if not request_client_secret:
                raise AuthenticationError("Client secret is required")  # pragma: no cover
            # hmac.compare_digest requires that both arguments are either bytes or a `str`
            # containing only ASCII characters. Since we do not control
            # `request_client_secret`, we encode both arguments to bytes.
            if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
                raise AuthenticationError("Invalid client_secret")  # pragma: no cover
            if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
                raise AuthenticationError("Client secret has expired")  # pragma: no cover
        return client

View File

@@ -0,0 +1,301 @@
from dataclasses import dataclass
from typing import Generic, Literal, Protocol, TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from pydantic import AnyUrl, BaseModel
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
class AuthorizationParams(BaseModel):
    """Validated parameters of an /authorize request, handed to the provider."""

    state: str | None
    scopes: list[str] | None
    code_challenge: str
    redirect_uri: AnyUrl
    # Whether the request itself carried a redirect_uri parameter.
    redirect_uri_provided_explicitly: bool
    resource: str | None = None  # RFC 8707 resource indicator


class AuthorizationCode(BaseModel):
    """An issued authorization code plus the context needed to redeem it."""

    code: str
    scopes: list[str]
    # Unix timestamp; the token handler rejects codes past this instant.
    expires_at: float
    client_id: str
    code_challenge: str
    redirect_uri: AnyUrl
    redirect_uri_provided_explicitly: bool
    resource: str | None = None  # RFC 8707 resource indicator


class RefreshToken(BaseModel):
    """A stored refresh token."""

    token: str
    client_id: str
    scopes: list[str]
    # Unix timestamp; None skips the expiry check in the token handler.
    expires_at: int | None = None


class AccessToken(BaseModel):
    """A stored access token."""

    token: str
    client_id: str
    scopes: list[str]
    # Unix timestamp; None skips the expiry check in the bearer-auth backend.
    expires_at: int | None = None
    resource: str | None = None  # RFC 8707 resource indicator
# Error codes for dynamic client registration responses (RFC 7591 §3.2.2).
RegistrationErrorCode = Literal[
    "invalid_redirect_uri",
    "invalid_client_metadata",
    "invalid_software_statement",
    "unapproved_software_statement",
]


@dataclass(frozen=True)
class RegistrationError(Exception):
    """Raised by providers to reject a dynamic client registration request."""

    error: RegistrationErrorCode
    error_description: str | None = None


# Error codes for the authorization endpoint (RFC 6749 §4.1.2.1).
AuthorizationErrorCode = Literal[
    "invalid_request",
    "unauthorized_client",
    "access_denied",
    "unsupported_response_type",
    "invalid_scope",
    "server_error",
    "temporarily_unavailable",
]


@dataclass(frozen=True)
class AuthorizeError(Exception):
    """Raised by providers to reject an /authorize request."""

    error: AuthorizationErrorCode
    error_description: str | None = None


# Error codes for the token endpoint (RFC 6749 §5.2).
TokenErrorCode = Literal[
    "invalid_request",
    "invalid_client",
    "invalid_grant",
    "unauthorized_client",
    "unsupported_grant_type",
    "invalid_scope",
]


@dataclass(frozen=True)
class TokenError(Exception):
    """Raised by providers to reject a /token exchange."""

    error: TokenErrorCode
    error_description: str | None = None
class TokenVerifier(Protocol):
    """Protocol for verifying bearer tokens."""

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify a bearer token and return access info if valid."""


# NOTE: FastMCP doesn't render any of these types in the user response, so it's
# OK to add fields to subclasses which should not be exposed externally.
# Type variables so a provider can plug in its own subclasses of the token/code
# models while keeping the protocol below type-safe.
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
    """Pluggable backend for the OAuth 2.0 authorization server.

    Implementations supply client storage, the authorization-code flow, and
    token issuance/revocation; the route handlers call these hooks.
    """

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """
        Retrieves client information by client ID.

        Implementors MAY raise NotImplementedError if dynamic client registration is
        disabled in ClientRegistrationOptions.

        Args:
            client_id: The ID of the client to retrieve.

        Returns:
            The client information, or None if the client does not exist.
        """

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """
        Saves client information as part of registering it.

        Implementors MAY raise NotImplementedError if dynamic client registration is
        disabled in ClientRegistrationOptions.

        Args:
            client_info: The client metadata to register.

        Raises:
            RegistrationError: If the client metadata is invalid.
        """

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """
        Called as part of the /authorize endpoint, and returns a URL that the client
        will be redirected to.

        Many MCP implementations will redirect to a third-party provider to perform
        a second OAuth exchange with that provider. In this sort of setup, the client
        has an OAuth connection with the MCP server, and the MCP server has an OAuth
        connection with the 3rd-party provider. At the end of this flow, the client
        should be redirected to the redirect_uri from params.redirect_uri.

        +--------+     +------------+     +-------------------+
        |        |     |            |     |                   |
        | Client | --> | MCP Server | --> |  3rd Party OAuth  |
        |        |     |            |     |      Server       |
        +--------+     +------------+     +-------------------+
                            |   ^                  |
        +------------+      |   |                  |
        |            |      |   |    Redirect      |
        |redirect_uri|<-----+   +------------------+
        |            |
        +------------+

        Implementations will need to define another handler on the MCP server return
        flow to perform the second redirect, and generate and store an authorization
        code as part of completing the OAuth authorization step.

        Implementations SHOULD generate an authorization code with at least 160 bits of
        entropy,
        and MUST generate an authorization code with at least 128 bits of entropy.
        See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10.

        Args:
            client: The client requesting authorization.
            params: The parameters of the authorization request.

        Returns:
            A URL to redirect the client to for authorization.

        Raises:
            AuthorizeError: If the authorization request is invalid.
        """
        ...

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCodeT | None:
        """
        Loads an AuthorizationCode by its code.

        Args:
            client: The client that requested the authorization code.
            authorization_code: The authorization code to get the challenge for.

        Returns:
            The AuthorizationCode, or None if not found
        """
        ...

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
    ) -> OAuthToken:
        """
        Exchanges an authorization code for an access token and refresh token.

        Args:
            client: The client exchanging the authorization code.
            authorization_code: The authorization code to exchange.

        Returns:
            The OAuth token, containing access and refresh tokens.

        Raises:
            TokenError: If the request is invalid
        """
        ...

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
        """
        Loads a RefreshToken by its token string.

        Args:
            client: The client that is requesting to load the refresh token.
            refresh_token: The refresh token string to load.

        Returns:
            The RefreshToken object if found, or None if not found.
        """
        ...

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshTokenT,
        scopes: list[str],
    ) -> OAuthToken:
        """
        Exchanges a refresh token for an access token and refresh token.

        Implementations SHOULD rotate both the access token and refresh token.

        Args:
            client: The client exchanging the refresh token.
            refresh_token: The refresh token to exchange.
            scopes: Optional scopes to request with the new access token.

        Returns:
            The OAuth token, containing access and refresh tokens.

        Raises:
            TokenError: If the request is invalid
        """
        ...

    async def load_access_token(self, token: str) -> AccessTokenT | None:
        """
        Loads an access token by its token.

        Args:
            token: The access token to verify.

        Returns:
            The AuthInfo, or None if the token is invalid.
        """

    async def revoke_token(
        self,
        token: AccessTokenT | RefreshTokenT,
    ) -> None:
        """
        Revokes an access or refresh token.

        If the given token is invalid or already revoked, this method should do nothing.

        Implementations SHOULD revoke both the access token and its corresponding
        refresh token, regardless of which of the access token or refresh token is
        provided.

        Args:
            token: the token to revoke
        """
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
parsed_uri = urlparse(redirect_uri_base)
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs]
for k, v in params.items():
if v is not None:
query_params.append((k, v))
redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params)))
return redirect_uri
class ProviderTokenVerifier(TokenVerifier):
    """Token verifier that uses an OAuthAuthorizationServerProvider.

    This is provided for backwards compatibility with existing auth_server_provider
    configurations. For new implementations using AS/RS separation, consider using
    the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier.
    """

    def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"):
        # Stored so verify_token can delegate the lookup to the provider.
        self.provider = provider

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify token using the provider's load_access_token method."""
        return await self.provider.load_access_token(token)

View File

@@ -0,0 +1,253 @@
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import urlparse
from pydantic import AnyHttpUrl
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route, request_response # type: ignore
from starlette.types import ASGIApp
from mcp.server.auth.handlers.authorize import AuthorizationHandler
from mcp.server.auth.handlers.metadata import MetadataHandler
from mcp.server.auth.handlers.register import RegistrationHandler
from mcp.server.auth.handlers.revoke import RevocationHandler
from mcp.server.auth.handlers.token import TokenHandler
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.auth import OAuthMetadata
def validate_issuer_url(url: AnyHttpUrl):
    """
    Validate that the issuer URL meets OAuth 2.0 requirements.

    Args:
        url: The issuer URL to validate

    Raises:
        ValueError: If the issuer URL is invalid
    """
    # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing.
    # (A missing host is also exempted, matching the original condition.)
    host = url.host
    loopback = host is None or host == "localhost" or host.startswith("127.0.0.1")
    if url.scheme != "https" and not loopback:
        raise ValueError("Issuer URL must be HTTPS")  # pragma: no cover

    # No fragments or query parameters allowed
    if url.fragment:
        raise ValueError("Issuer URL must not have a fragment")  # pragma: no cover
    if url.query:
        raise ValueError("Issuer URL must not have a query string")  # pragma: no cover
# Endpoint paths mounted by create_auth_routes, relative to the issuer URL.
AUTHORIZATION_PATH = "/authorize"
TOKEN_PATH = "/token"
REGISTRATION_PATH = "/register"
REVOCATION_PATH = "/revoke"
def cors_middleware(
    handler: Callable[[Request], Response | Awaitable[Response]],
    allow_methods: list[str],
) -> ASGIApp:
    """Wrap *handler* in a permissive CORS layer so browser-based OAuth clients can call it."""
    return CORSMiddleware(
        app=request_response(handler),
        allow_origins="*",
        allow_methods=allow_methods,
        allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
    )
def create_auth_routes(
    provider: OAuthAuthorizationServerProvider[Any, Any, Any],
    issuer_url: AnyHttpUrl,
    service_documentation_url: AnyHttpUrl | None = None,
    client_registration_options: ClientRegistrationOptions | None = None,
    revocation_options: RevocationOptions | None = None,
) -> list[Route]:
    """Build the Starlette routes for an OAuth 2.0 authorization server.

    Args:
        provider: Backend implementing client storage and token issuance.
        issuer_url: Issuer URL; must be HTTPS (localhost/127.x excepted) with
            no query string or fragment.
        service_documentation_url: Optional service_documentation metadata value.
        client_registration_options: Enables/configures the /register endpoint.
        revocation_options: Enables/configures the /revoke endpoint.

    Returns:
        Routes for metadata, /authorize, /token, and (when enabled) /register
        and /revoke.

    Raises:
        ValueError: If issuer_url fails validation.
    """
    validate_issuer_url(issuer_url)

    client_registration_options = client_registration_options or ClientRegistrationOptions()
    revocation_options = revocation_options or RevocationOptions()
    metadata = build_metadata(
        issuer_url,
        service_documentation_url,
        client_registration_options,
        revocation_options,
    )
    client_authenticator = ClientAuthenticator(provider)

    # Create routes
    # Allow CORS requests for endpoints meant to be hit by the OAuth client
    # (with the client secret). This is intended to support things like MCP Inspector,
    # where the client runs in a web browser.
    routes = [
        Route(
            "/.well-known/oauth-authorization-server",
            endpoint=cors_middleware(
                MetadataHandler(metadata).handle,
                ["GET", "OPTIONS"],
            ),
            methods=["GET", "OPTIONS"],
        ),
        Route(
            AUTHORIZATION_PATH,
            # do not allow CORS for authorization endpoint;
            # clients should just redirect to this
            endpoint=AuthorizationHandler(provider).handle,
            methods=["GET", "POST"],
        ),
        Route(
            TOKEN_PATH,
            endpoint=cors_middleware(
                TokenHandler(provider, client_authenticator).handle,
                ["POST", "OPTIONS"],
            ),
            methods=["POST", "OPTIONS"],
        ),
    ]

    if client_registration_options.enabled:  # pragma: no branch
        registration_handler = RegistrationHandler(
            provider,
            options=client_registration_options,
        )
        routes.append(
            Route(
                REGISTRATION_PATH,
                endpoint=cors_middleware(
                    registration_handler.handle,
                    ["POST", "OPTIONS"],
                ),
                methods=["POST", "OPTIONS"],
            )
        )

    if revocation_options.enabled:  # pragma: no branch
        revocation_handler = RevocationHandler(provider, client_authenticator)
        routes.append(
            Route(
                REVOCATION_PATH,
                endpoint=cors_middleware(
                    revocation_handler.handle,
                    ["POST", "OPTIONS"],
                ),
                methods=["POST", "OPTIONS"],
            )
        )

    return routes
def build_metadata(
    issuer_url: AnyHttpUrl,
    service_documentation_url: AnyHttpUrl | None,
    client_registration_options: ClientRegistrationOptions,
    revocation_options: RevocationOptions,
) -> OAuthMetadata:
    """Assemble the RFC 8414 authorization-server metadata document."""
    # All endpoints hang off the issuer URL; normalize away a trailing slash once.
    base = str(issuer_url).rstrip("/")

    metadata = OAuthMetadata(
        issuer=issuer_url,
        authorization_endpoint=AnyHttpUrl(base + AUTHORIZATION_PATH),
        token_endpoint=AnyHttpUrl(base + TOKEN_PATH),
        scopes_supported=client_registration_options.valid_scopes,
        response_types_supported=["code"],
        response_modes_supported=None,
        grant_types_supported=["authorization_code", "refresh_token"],
        token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
        token_endpoint_auth_signing_alg_values_supported=None,
        service_documentation=service_documentation_url,
        ui_locales_supported=None,
        op_policy_uri=None,
        op_tos_uri=None,
        introspection_endpoint=None,
        code_challenge_methods_supported=["S256"],
    )

    # Advertise the registration endpoint if supported
    if client_registration_options.enabled:  # pragma: no branch
        metadata.registration_endpoint = AnyHttpUrl(base + REGISTRATION_PATH)

    # Advertise the revocation endpoint if supported
    if revocation_options.enabled:  # pragma: no branch
        metadata.revocation_endpoint = AnyHttpUrl(base + REVOCATION_PATH)
        metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]

    return metadata
def build_resource_metadata_url(resource_server_url: AnyHttpUrl) -> AnyHttpUrl:
    """
    Build RFC 9728 compliant protected resource metadata URL.

    Inserts /.well-known/oauth-protected-resource between host and resource path
    as specified in RFC 9728 §3.1.

    Args:
        resource_server_url: The resource server URL (e.g., https://example.com/mcp)

    Returns:
        The metadata URL (e.g., https://example.com/.well-known/oauth-protected-resource/mcp)
    """
    parts = urlparse(str(resource_server_url))
    # A bare trailing slash is treated as "no path" so the well-known URL stays clean.
    suffix = "" if parts.path == "/" else parts.path
    return AnyHttpUrl(f"{parts.scheme}://{parts.netloc}/.well-known/oauth-protected-resource{suffix}")
def create_protected_resource_routes(
    resource_url: AnyHttpUrl,
    authorization_servers: list[AnyHttpUrl],
    scopes_supported: list[str] | None = None,
    resource_name: str | None = None,
    resource_documentation: AnyHttpUrl | None = None,
) -> list[Route]:
    """
    Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728).

    Args:
        resource_url: The URL of this resource server
        authorization_servers: List of authorization servers that can issue tokens
        scopes_supported: Optional list of scopes supported by this resource
        resource_name: Optional human-readable name for the resource
        resource_documentation: Optional URL of documentation for the resource

    Returns:
        List of Starlette routes for protected resource metadata
    """
    # Imported locally to avoid import cycles between the routes and handler modules
    # — presumably; confirm against the package layout.
    from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler
    from mcp.shared.auth import ProtectedResourceMetadata

    metadata = ProtectedResourceMetadata(
        resource=resource_url,
        authorization_servers=authorization_servers,
        scopes_supported=scopes_supported,
        resource_name=resource_name,
        resource_documentation=resource_documentation,
        # bearer_methods_supported defaults to ["header"] in the model
    )
    handler = ProtectedResourceMetadataHandler(metadata)

    # RFC 9728 §3.1: Register route at /.well-known/oauth-protected-resource + resource path
    metadata_url = build_resource_metadata_url(resource_url)
    # Extract just the path part for route registration
    parsed = urlparse(str(metadata_url))
    well_known_path = parsed.path

    return [
        Route(
            well_known_path,
            endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]),
            methods=["GET", "OPTIONS"],
        )
    ]

View File

@@ -0,0 +1,30 @@
from pydantic import AnyHttpUrl, BaseModel, Field
class ClientRegistrationOptions(BaseModel):
    """Options for RFC 7591 dynamic client registration support."""

    # Whether the /register endpoint is exposed at all.
    enabled: bool = False
    # Lifetime of issued client secrets in seconds; None presumably means no
    # expiry — confirm against the registration handler.
    client_secret_expiry_seconds: int | None = None
    # Scopes clients may request; also advertised as scopes_supported in metadata.
    valid_scopes: list[str] | None = None
    # Scopes applied when a client registers without requesting any — TODO confirm
    # against the registration handler.
    default_scopes: list[str] | None = None


class RevocationOptions(BaseModel):
    """Options for RFC 7009 token revocation support."""

    # Whether the /revoke endpoint is exposed.
    enabled: bool = False


class AuthSettings(BaseModel):
    """Authorization configuration for an MCP server."""

    issuer_url: AnyHttpUrl = Field(
        ...,
        description="OAuth authorization server URL that issues tokens for this resource server.",
    )
    service_documentation_url: AnyHttpUrl | None = None
    client_registration_options: ClientRegistrationOptions | None = None
    revocation_options: RevocationOptions | None = None
    required_scopes: list[str] | None = None

    # Resource Server settings (when operating as RS only)
    # NOTE: Field(...) makes this field *required* even though its type admits
    # None — callers must pass it explicitly (possibly as None).
    resource_server_url: AnyHttpUrl | None = Field(
        ...,
        description="The URL of the MCP server to be used as the resource identifier "
        "and base route to look up OAuth Protected Resource Metadata.",
    )

Some files were not shown because too many files have changed in this diff Show More