Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Middleware for MCP authorization.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
|
||||
import contextvars
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
# Create a contextvar to store the authenticated user
|
||||
# The default is None, indicating no authenticated user is present
|
||||
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
|
||||
|
||||
|
||||
def get_access_token() -> AccessToken | None:
|
||||
"""
|
||||
Get the access token from the current context.
|
||||
|
||||
Returns:
|
||||
The access token if an authenticated user is available, None otherwise.
|
||||
"""
|
||||
auth_user = auth_context_var.get()
|
||||
return auth_user.access_token if auth_user else None
|
||||
|
||||
|
||||
class AuthContextMiddleware:
|
||||
"""
|
||||
Middleware that extracts the authenticated user from the request
|
||||
and sets it in a contextvar for easy access throughout the request lifecycle.
|
||||
|
||||
This middleware should be added after the AuthenticationMiddleware in the
|
||||
middleware stack to ensure that the user is properly authenticated before
|
||||
being stored in the context.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
user = scope.get("user")
|
||||
if isinstance(user, AuthenticatedUser):
|
||||
# Set the authenticated user in the contextvar
|
||||
token = auth_context_var.set(user)
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
auth_context_var.reset(token)
|
||||
else:
|
||||
# No authenticated user, just process the request
|
||||
await self.app(scope, receive, send)
|
||||
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||
|
||||
|
||||
class AuthenticatedUser(SimpleUser):
|
||||
"""User with authentication info."""
|
||||
|
||||
def __init__(self, auth_info: AccessToken):
|
||||
super().__init__(auth_info.client_id)
|
||||
self.access_token = auth_info
|
||||
self.scopes = auth_info.scopes
|
||||
|
||||
|
||||
class BearerAuthBackend(AuthenticationBackend):
|
||||
"""
|
||||
Authentication backend that validates Bearer tokens using a TokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, token_verifier: TokenVerifier):
|
||||
self.token_verifier = token_verifier
|
||||
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
auth_header = next(
|
||||
(conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
|
||||
None,
|
||||
)
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Validate the token with the verifier
|
||||
auth_info = await self.token_verifier.verify_token(token)
|
||||
|
||||
if not auth_info:
|
||||
return None
|
||||
|
||||
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
|
||||
return None
|
||||
|
||||
return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
|
||||
|
||||
|
||||
class RequireAuthMiddleware:
|
||||
"""
|
||||
Middleware that requires a valid Bearer token in the Authorization header.
|
||||
|
||||
This will validate the token with the auth provider and store the resulting
|
||||
auth info in the request state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: Any,
|
||||
required_scopes: list[str],
|
||||
resource_metadata_url: AnyHttpUrl | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
required_scopes: List of scopes that the token must have
|
||||
resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header
|
||||
"""
|
||||
self.app = app
|
||||
self.required_scopes = required_scopes
|
||||
self.resource_metadata_url = resource_metadata_url
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
auth_user = scope.get("user")
|
||||
if not isinstance(auth_user, AuthenticatedUser):
|
||||
await self._send_auth_error(
|
||||
send, status_code=401, error="invalid_token", description="Authentication required"
|
||||
)
|
||||
return
|
||||
|
||||
auth_credentials = scope.get("auth")
|
||||
|
||||
for required_scope in self.required_scopes:
|
||||
# auth_credentials should always be provided; this is just paranoia
|
||||
if auth_credentials is None or required_scope not in auth_credentials.scopes:
|
||||
await self._send_auth_error(
|
||||
send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}"
|
||||
)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None:
|
||||
"""Send an authentication error response with WWW-Authenticate header."""
|
||||
# Build WWW-Authenticate header value
|
||||
www_auth_parts = [f'error="{error}"', f'error_description="{description}"']
|
||||
if self.resource_metadata_url: # pragma: no cover
|
||||
www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"')
|
||||
|
||||
www_authenticate = f"Bearer {', '.join(www_auth_parts)}"
|
||||
|
||||
# Send response
|
||||
body = {"error": error, "error_description": description}
|
||||
body_bytes = json.dumps(body).encode()
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"content-length", str(len(body_bytes)).encode()),
|
||||
(b"www-authenticate", www_authenticate.encode()),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": body_bytes,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,115 @@
|
||||
import base64
|
||||
import binascii
|
||||
import hmac
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message # pragma: no cover
|
||||
|
||||
|
||||
class ClientAuthenticator:
|
||||
"""
|
||||
ClientAuthenticator is a callable which validates requests from a client
|
||||
application, used to verify /token calls.
|
||||
If, during registration, the client requested to be issued a secret, the
|
||||
authenticator asserts that /token calls must be authenticated with
|
||||
that same token.
|
||||
NOTE: clients can opt for no authentication during registration, in which case this
|
||||
logic is skipped.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""
|
||||
Initialize the dependency.
|
||||
|
||||
Args:
|
||||
provider: Provider to look up client information
|
||||
"""
|
||||
self.provider = provider
|
||||
|
||||
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
|
||||
"""
|
||||
Authenticate a client from an HTTP request.
|
||||
|
||||
Extracts client credentials from the appropriate location based on the
|
||||
client's registered authentication method and validates them.
|
||||
|
||||
Args:
|
||||
request: The HTTP request containing client credentials
|
||||
|
||||
Returns:
|
||||
The authenticated client information
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
form_data = await request.form()
|
||||
client_id = form_data.get("client_id")
|
||||
if not client_id:
|
||||
raise AuthenticationError("Missing client_id")
|
||||
|
||||
client = await self.provider.get_client(str(client_id))
|
||||
if not client:
|
||||
raise AuthenticationError("Invalid client_id") # pragma: no cover
|
||||
|
||||
request_client_secret: str | None = None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
if client.token_endpoint_auth_method == "client_secret_basic":
|
||||
if not auth_header.startswith("Basic "):
|
||||
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
|
||||
|
||||
try:
|
||||
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
|
||||
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
|
||||
if ":" not in decoded:
|
||||
raise ValueError("Invalid Basic auth format")
|
||||
basic_client_id, request_client_secret = decoded.split(":", 1)
|
||||
|
||||
# URL-decode both parts per RFC 6749 Section 2.3.1
|
||||
basic_client_id = unquote(basic_client_id)
|
||||
request_client_secret = unquote(request_client_secret)
|
||||
|
||||
if basic_client_id != client_id:
|
||||
raise AuthenticationError("Client ID mismatch in Basic auth")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise AuthenticationError("Invalid Basic authentication header")
|
||||
|
||||
elif client.token_endpoint_auth_method == "client_secret_post":
|
||||
raw_form_data = form_data.get("client_secret")
|
||||
# form_data.get() can return a UploadFile or None, so we need to check if it's a string
|
||||
if isinstance(raw_form_data, str):
|
||||
request_client_secret = str(raw_form_data)
|
||||
|
||||
elif client.token_endpoint_auth_method == "none":
|
||||
request_client_secret = None
|
||||
else:
|
||||
raise AuthenticationError( # pragma: no cover
|
||||
f"Unsupported auth method: {client.token_endpoint_auth_method}"
|
||||
)
|
||||
|
||||
# If client from the store expects a secret, validate that the request provides
|
||||
# that secret
|
||||
if client.client_secret: # pragma: no branch
|
||||
if not request_client_secret:
|
||||
raise AuthenticationError("Client secret is required") # pragma: no cover
|
||||
|
||||
# hmac.compare_digest requires that both arguments are either bytes or a `str` containing
|
||||
# only ASCII characters. Since we do not control `request_client_secret`, we encode both
|
||||
# arguments to bytes.
|
||||
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
|
||||
raise AuthenticationError("Invalid client_secret") # pragma: no cover
|
||||
|
||||
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
|
||||
raise AuthenticationError("Client secret has expired") # pragma: no cover
|
||||
|
||||
return client
|
||||
Reference in New Issue
Block a user