Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,41 @@
from .shared import PYDANTIC_V2 as PYDANTIC_V2
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import is_pydantic_v1_model_class as is_pydantic_v1_model_class
from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance
from .shared import (
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
)
from .shared import (
is_uploadfile_sequence_annotation as is_uploadfile_sequence_annotation,
)
from .shared import lenient_issubclass as lenient_issubclass
from .shared import sequence_types as sequence_types
from .shared import value_is_sequence as value_is_sequence
from .v2 import BaseConfig as BaseConfig
from .v2 import ModelField as ModelField
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import Validator as Validator
from .v2 import _regenerate_error_with_loc as _regenerate_error_with_loc
from .v2 import copy_field_info as copy_field_info
from .v2 import create_body_model as create_body_model
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_cached_model_fields as get_cached_model_fields
from .v2 import get_compat_model_name_map as get_compat_model_name_map
from .v2 import get_definitions as get_definitions
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import get_schema_from_model_field as get_schema_from_model_field
from .v2 import is_bytes_field as is_bytes_field
from .v2 import is_bytes_sequence_field as is_bytes_sequence_field
from .v2 import is_scalar_field as is_scalar_field
from .v2 import is_scalar_sequence_field as is_scalar_sequence_field
from .v2 import is_sequence_field as is_sequence_field
from .v2 import serialize_sequence_value as serialize_sequence_value
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)

View File

@@ -0,0 +1,206 @@
import sys
import types
import typing
import warnings
from collections import deque
from collections.abc import Mapping, Sequence
from dataclasses import is_dataclass
from typing import (
Annotated,
Any,
Union,
)
from fastapi.types import UnionType
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import get_args, get_origin
# Copy from Pydantic v2, compatible with v1
if sys.version_info < (3, 10):
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # type: ignore[attr-defined]
else:
WithArgsTypes: tuple[Any, ...] = (
typing._GenericAlias, # type: ignore[attr-defined]
types.GenericAlias,
types.UnionType,
) # pyright: ignore[reportAttributeAccessIssue]
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = {
Sequence: list,
list: list,
tuple: tuple,
set: set,
frozenset: frozenset,
deque: deque,
}
sequence_types = tuple(sequence_annotation_to_type.keys())
Url: type[Any]
# Copy of Pydantic v2, compatible with v1
def lenient_issubclass(
cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...], None]
) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError: # pragma: no cover
if isinstance(cls, WithArgsTypes):
return False
raise # pragma: no cover
def _annotation_is_sequence(annotation: Union[type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types)
def field_annotation_is_sequence(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if field_annotation_is_sequence(arg):
return True
return False
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
get_origin(annotation)
)
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))
def _annotation_is_complex(annotation: Union[type[Any], None]) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
def field_annotation_is_complex(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
if origin is Annotated:
return field_annotation_is_complex(get_args(annotation)[0])
return (
_annotation_is_complex(annotation)
or _annotation_is_complex(origin)
or hasattr(origin, "__pydantic_core_schema__")
or hasattr(origin, "__get_pydantic_core_schema__")
)
def field_annotation_is_scalar(annotation: Any) -> bool:
# handle Ellipsis here to make tuple[int, ...] work nicely
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
for arg in get_args(annotation):
if field_annotation_is_scalar_sequence(arg):
at_least_one_scalar_sequence = True
continue
elif not field_annotation_is_scalar(arg):
return False
return at_least_one_scalar_sequence
return field_annotation_is_sequence(annotation) and all(
field_annotation_is_scalar(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, bytes):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, bytes):
return True
return False
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
if lenient_issubclass(annotation, UploadFile):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, UploadFile):
return True
return False
def is_bytes_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_bytes_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_bytes_or_nonable_bytes_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one = False
for arg in get_args(annotation):
if is_uploadfile_sequence_annotation(arg):
at_least_one = True
continue
return at_least_one
return field_annotation_is_sequence(annotation) and all(
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
def is_pydantic_v1_model_instance(obj: Any) -> bool:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
return isinstance(obj, v1.BaseModel)
def is_pydantic_v1_model_class(cls: Any) -> bool:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
return lenient_issubclass(cls, v1.BaseModel)
def annotation_is_pydantic_v1(annotation: Any) -> bool:
if is_pydantic_v1_model_class(annotation):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if is_pydantic_v1_model_class(arg):
return True
if field_annotation_is_sequence(annotation):
for sub_annotation in get_args(annotation):
if annotation_is_pydantic_v1(sub_annotation):
return True
return False

View File

@@ -0,0 +1,568 @@
import re
import warnings
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
Annotated,
Any,
Union,
cast,
)
from fastapi._compat import shared
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
from typing_extensions import Literal, get_args, get_origin
try:
from pydantic_core.core_schema import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
except ImportError: # pragma: no cover
from pydantic_core.core_schema import (
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
)
RequiredParam = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
Validator = Any
# TODO: remove when dropping support for Pydantic < v2.12.3
_Attrs = {
"default": ...,
"default_factory": None,
"alias": None,
"alias_priority": None,
"validation_alias": None,
"serialization_alias": None,
"title": None,
"field_title_generator": None,
"description": None,
"examples": None,
"exclude": None,
"exclude_if": None,
"discriminator": None,
"deprecated": None,
"json_schema_extra": None,
"frozen": None,
"validate_default": None,
"repr": True,
"init": None,
"init_var": None,
"kw_only": None,
}
# TODO: remove when dropping support for Pydantic < v2.12.3
def asdict(field_info: FieldInfo) -> dict[str, Any]:
attributes = {}
for attr in _Attrs:
value = getattr(field_info, attr, Undefined)
if value is not Undefined:
attributes[attr] = value
return {
"annotation": field_info.annotation,
"metadata": field_info.metadata,
"attributes": attributes,
}
class BaseConfig:
pass
class ErrorWrapper(Exception):
pass
@dataclass
class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
config: Union[ConfigDict, None] = None
@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name
@property
def validation_alias(self) -> Union[str, None]:
va = self.field_info.validation_alias
if isinstance(va, str) and va:
return va
return None
@property
def serialization_alias(self) -> Union[str, None]:
sa = self.field_info.serialization_alias
return sa or None
@property
def required(self) -> bool:
return self.field_info.is_required()
@property
def default(self) -> Any:
return self.get_default()
@property
def type_(self) -> Any:
return self.field_info.annotation
def __post_init__(self) -> None:
with warnings.catch_warnings():
# Pydantic >= 2.12.0 warns about field specific metadata that is unused
# (e.g. `TypeAdapter(Annotated[int, Field(alias='b')])`). In some cases, we
# end up building the type adapter from a model field annotation so we
# need to ignore the warning:
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 12):
from pydantic.warnings import UnsupportedFieldAttributeWarning
warnings.simplefilter(
"ignore", category=UnsupportedFieldAttributeWarning
)
# TODO: remove after dropping support for Python 3.8 and
# setting the min Pydantic to v2.12.3 that adds asdict()
field_dict = asdict(self.field_info)
annotated_args = (
field_dict["annotation"],
*field_dict["metadata"],
# this FieldInfo needs to be created again so that it doesn't include
# the old field info metadata and only the rest of the attributes
Field(**field_dict["attributes"]),
)
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[annotated_args],
config=self.config,
)
def get_default(self) -> Any:
if self.field_info.is_required():
return Undefined
return self.field_info.get_default(call_default_factory=True)
def validate(
self,
value: Any,
values: dict[str, Any] = {}, # noqa: B006
*,
loc: tuple[Union[int, str], ...] = (),
) -> tuple[Any, Union[list[dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
# What calls this code passes a value that already called
# self._type_adapter.validate_python(value)
return self._type_adapter.dump_python(
value,
mode=mode,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from
# ModelField to its JSON Schema.
return id(self)
def _has_computed_fields(field: ModelField) -> bool:
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
"computed_fields", []
)
return len(computed_fields) > 0
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
override_mode: Union[Literal["validation"], None] = (
None
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
)
field_alias = (
(field.validation_alias or field.alias)
if field.mode == "validation"
else (field.serialization_alias or field.alias)
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema["title"] = field.field_info.title or field_alias.title().replace(
"_", " "
)
return json_schema
def get_definitions(
*,
fields: Sequence[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, dict[str, Any]],
]:
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
validation_fields = [field for field in fields if field.mode == "validation"]
serialization_fields = [field for field in fields if field.mode == "serialization"]
flat_validation_models = get_flat_models_from_fields(
validation_fields, known_models=set()
)
flat_serialization_models = get_flat_models_from_fields(
serialization_fields, known_models=set()
)
flat_validation_model_fields = [
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="validation",
)
for model in flat_validation_models
]
flat_serialization_model_fields = [
ModelField(
field_info=FieldInfo(annotation=model),
name=model.__name__,
mode="serialization",
)
for model in flat_serialization_models
]
flat_model_fields = flat_validation_model_fields + flat_serialization_model_fields
input_types = {f.type_ for f in fields}
unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types
}
inputs = [
(
field,
(
field.mode
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
),
field._type_adapter.core_schema,
)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(dict[str, dict[str, Any]], definitions).values():
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
new_mapping, new_definitions = _remap_definitions_and_field_mappings(
model_name_map=model_name_map,
definitions=definitions, # type: ignore[arg-type]
field_mapping=field_mapping,
)
return new_mapping, new_definitions
def _replace_refs(
*,
schema: dict[str, Any],
old_name_to_new_name_map: dict[str, str],
) -> dict[str, Any]:
new_schema = deepcopy(schema)
for key, value in new_schema.items():
if key == "$ref":
value = schema["$ref"]
if isinstance(value, str):
ref_name = schema["$ref"].split("/")[-1]
if ref_name in old_name_to_new_name_map:
new_name = old_name_to_new_name_map[ref_name]
new_schema["$ref"] = REF_TEMPLATE.format(model=new_name)
continue
if isinstance(value, dict):
new_schema[key] = _replace_refs(
schema=value,
old_name_to_new_name_map=old_name_to_new_name_map,
)
elif isinstance(value, list):
new_value = []
for item in value:
if isinstance(item, dict):
new_item = _replace_refs(
schema=item,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_value.append(new_item)
else:
new_value.append(item)
new_schema[key] = new_value
return new_schema
def _remap_definitions_and_field_mappings(
*,
model_name_map: ModelNameMap,
definitions: dict[str, Any],
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, Any],
]:
old_name_to_new_name_map = {}
for field_key, schema in field_mapping.items():
model = field_key[0].type_
if model not in model_name_map or "$ref" not in schema:
continue
new_name = model_name_map[model]
old_name = schema["$ref"].split("/")[-1]
if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
continue
old_name_to_new_name_map[old_name] = new_name
new_field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
] = {}
for field_key, schema in field_mapping.items():
new_schema = _replace_refs(
schema=schema,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_field_mapping[field_key] = new_schema
new_definitions = {}
for key, value in definitions.items():
if key in old_name_to_new_name_map:
new_key = old_name_to_new_name_map[key]
else:
new_key = key
new_value = _replace_refs(
schema=value,
old_name_to_new_name_map=old_name_to_new_name_map,
)
new_definitions[new_key] = new_value
return new_field_mapping, new_definitions
def is_scalar_field(field: ModelField) -> bool:
from fastapi import params
return shared.field_annotation_is_scalar(
field.field_info.annotation
) and not isinstance(field.field_info, params.Body)
def is_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_sequence(field.field_info.annotation)
def is_scalar_sequence_field(field: ModelField) -> bool:
return shared.field_annotation_is_scalar_sequence(field.field_info.annotation)
def is_bytes_field(field: ModelField) -> bool:
return shared.is_bytes_or_nonable_bytes_annotation(field.type_)
def is_bytes_sequence_field(field: ModelField) -> bool:
return shared.is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
cls = type(field_info)
merged_field_info = cls.from_annotation(annotation)
new_field_info = copy(field_info)
new_field_info.metadata = merged_field_info.metadata
new_field_info.annotation = merged_field_info.annotation
return new_field_info
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
if origin_type is Union or origin_type is UnionType: # Handle optional sequences
union_args = get_args(field.field_info.annotation)
for union_arg in union_args:
if union_arg is type(None):
continue
origin_type = get_origin(union_arg) or union_arg
break
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
error["input"] = None
return error # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
model_fields: list[ModelField] = []
for name, field_info in model.model_fields.items():
type_ = field_info.annotation
if lenient_issubclass(type_, (BaseModel, dict)) or is_dataclass(type_):
model_config = None
else:
model_config = model.model_config
model_fields.append(
ModelField(
field_info=field_info,
name=name,
config=model_config,
)
)
return model_fields
@lru_cache
def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]:
return get_model_fields(model) # type: ignore[return-value]
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
# Pydantic v2 and allow mixing the models
TypeModelOrEnum = Union[type["BaseModel"], type[Enum]]
TypeModelSet = set[TypeModelOrEnum]
def normalize_name(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str]:
name_model_map = {}
for model in unique_models:
model_name = normalize_name(model.__name__)
name_model_map[model_name] = model
return {v: k for k, v in name_model_map.items()}
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
all_flat_models = set()
v2_model_fields = [field for field in fields if isinstance(field, ModelField)]
v2_flat_models = get_flat_models_from_fields(v2_model_fields, known_models=set())
all_flat_models = all_flat_models.union(v2_flat_models) # type: ignore[arg-type]
model_name_map = get_model_name_map(all_flat_models) # type: ignore[arg-type]
return model_name_map
def get_flat_models_from_model(
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None
) -> TypeModelSet:
known_models = known_models or set()
fields = get_model_fields(model)
get_flat_models_from_fields(fields, known_models=known_models)
return known_models
def get_flat_models_from_annotation(
annotation: Any, known_models: TypeModelSet
) -> TypeModelSet:
origin = get_origin(annotation)
if origin is not None:
for arg in get_args(annotation):
if lenient_issubclass(arg, (BaseModel, Enum)) and arg not in known_models:
known_models.add(arg)
if lenient_issubclass(arg, BaseModel):
get_flat_models_from_model(arg, known_models=known_models)
else:
get_flat_models_from_annotation(arg, known_models=known_models)
return known_models
def get_flat_models_from_field(
field: ModelField, known_models: TypeModelSet
) -> TypeModelSet:
field_type = field.type_
if lenient_issubclass(field_type, BaseModel):
if field_type in known_models:
return known_models
known_models.add(field_type)
get_flat_models_from_model(field_type, known_models=known_models)
elif lenient_issubclass(field_type, Enum):
known_models.add(field_type)
else:
get_flat_models_from_annotation(field_type, known_models=known_models)
return known_models
def get_flat_models_from_fields(
fields: Sequence[ModelField], known_models: TypeModelSet
) -> TypeModelSet:
for field in fields:
get_flat_models_from_field(field, known_models=known_models)
return known_models
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: tuple[Union[str, int], ...]
) -> list[dict[str, Any]]:
updated_loc_errors: list[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())} for err in errors
]
return updated_loc_errors